chenpangpang / transformers / Commits / baa74326

Commit baa74326, authored Sep 19, 2019 by LysandreJik

    Stride + tests + small fixes

Parent: c10c7d59
Showing 3 changed files with 11 additions and 7 deletions (+11 -7):

    pytorch_transformers/tests/tokenization_tests_commons.py   +4  -2
    pytorch_transformers/tokenization_distilbert.py            +0  -1
    pytorch_transformers/tokenization_utils.py                  +7  -4
pytorch_transformers/tests/tokenization_tests_commons.py

@@ -217,16 +217,18 @@ class CommonTestCases:

         tokenizer = self.get_tokenizer()

         seq_0 = "This is a sentence to be encoded."
+        stride = 2
         sequence = tokenizer.encode(seq_0)
         num_added_tokens = tokenizer.num_added_tokens()
         total_length = len(sequence) + num_added_tokens
-        information = tokenizer.encode_plus(seq_0, max_length=total_length - 2, add_special_tokens=True)
+        information = tokenizer.encode_plus(seq_0, max_length=total_length - 2, add_special_tokens=True, stride=stride)

         truncated_sequence = information["sequence"]
         overflowing_tokens = information["overflowing_tokens"]

-        assert len(overflowing_tokens) == 2
+        assert len(overflowing_tokens) == 2 + stride
+        assert overflowing_tokens == sequence[-(2 + stride):]
         assert len(truncated_sequence) == total_length - 2
         assert truncated_sequence == tokenizer.add_special_tokens_single_sequence(sequence[:-2])
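For context, a minimal sketch of the behavior this test pins down, assuming a BERT-style tokenizer from this library and the same calls the test itself uses (the concrete ids depend on the vocabulary):

    from pytorch_transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    text = "This is a sentence to be encoded."
    sequence = tokenizer.encode(text)

    # Request 2 fewer tokens than the full encoding needs, with stride=2:
    # the overflow then starts 2 ids *before* the truncation point, so it
    # repeats the last 2 kept ids and the two windows overlap.
    information = tokenizer.encode_plus(
        text,
        max_length=len(sequence) + tokenizer.num_added_tokens() - 2,
        add_special_tokens=True,
        stride=2,
    )
    assert information["overflowing_tokens"] == sequence[-(2 + 2):]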
pytorch_transformers/tokenization_distilbert.py

@@ -76,6 +76,5 @@ class DistilBertTokenizer(BertTokenizer):

             | first sequence | second sequence
         """
         sep = [self.sep_token_id]
-        cls = [self.cls_token_id]
         return len(self.encode(sequence_0) + sep) * [0] + len(self.encode(sequence_1)) * [1]
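The deleted cls list was never used by the return expression, so this is one of the commit's "small fixes" (dead-code removal). For illustration only, a pure-Python sketch of the 0/1 mask that expression builds, with made-up token counts:

    # Hypothetical lengths: 5 ids for sequence_0, 1 separator, 3 ids for sequence_1.
    len_first, len_sep, len_second = 5, 1, 3

    # Same list arithmetic as the return statement above:
    mask = (len_first + len_sep) * [0] + len_second * [1]
    print(mask)  # [0, 0, 0, 0, 0, 0, 1, 1, 1]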
pytorch_transformers/tokenization_utils.py

@@ -722,7 +722,7 @@ class PreTrainedTokenizer(object):

             logger.warning("No special tokens were added. The two sequences have been concatenated.")
             return first_sentence_tokens + second_sentence_tokens

-    def encode_plus(self, text, text_pair=None, add_special_tokens=False, output_mask=False, max_length=None, **kwargs):
+    def encode_plus(self, text, text_pair=None, add_special_tokens=False, output_mask=False, max_length=None, stride=0, **kwargs):
         """
         Returns a dictionary containing the encoded sequence or sequence pair. Other values can be returned by this
         method: the mask for sequence classification and the overflowing elements if a ``max_length`` is specified.
@@ -735,6 +735,9 @@ class PreTrainedTokenizer(object):

             output_mask: if set to ``True``, returns the text pair corresponding mask with 0 for the first sequence,
                 and 1 for the second.
             max_length: if set to a number, will limit the total sequence returned so that it has a maximum length.
+                If there are overflowing tokens, those will be added to the returned dictionary.
+            stride: if set to a number along with max_length, the overflowing tokens returned will contain some tokens
+                from the main sequence returned. The value of this argument defines the number of additional tokens.
             **kwargs: passed to the `self.tokenize()` method
         """
@@ -745,13 +748,13 @@ class PreTrainedTokenizer(object):

         if add_special_tokens:
             sequence_tokens = self.convert_tokens_to_ids(self.tokenize(text, **kwargs))
             if max_length:
-                information["overflowing_tokens"] = sequence_tokens[max_length - n_added_tokens:]
+                information["overflowing_tokens"] = sequence_tokens[max_length - n_added_tokens - stride:]
                 sequence_tokens = sequence_tokens[:max_length - n_added_tokens]
             sequence = self.add_special_tokens_single_sequence(sequence_tokens)
         else:
             sequence_tokens = self.convert_tokens_to_ids(self.tokenize(text, **kwargs))
             if max_length:
-                information["overflowing_tokens"] = sequence_tokens[max_length:]
+                information["overflowing_tokens"] = sequence_tokens[max_length - stride:]
                 sequence_tokens = sequence_tokens[:max_length]
             sequence = sequence_tokens
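Note the asymmetry between the two branches: when special tokens will be added afterwards, the cut point moves back by n_added_tokens so the finished sequence still fits in max_length, and the overflow slice shifts with it. A standalone re-creation of that logic (a sketch under assumed names, not this library's API):

    def truncate_with_stride(ids, max_length, stride=0, n_added_tokens=0):
        # Budget for real tokens shrinks by the special tokens added later;
        # the overflow starts `stride` ids before the cut so windows overlap.
        cut = max_length - n_added_tokens
        return ids[:cut], ids[cut - stride:]

    kept, overflow = truncate_with_stride(list(range(10)), max_length=8,
                                          stride=2, n_added_tokens=2)
    print(kept)      # [0, 1, 2, 3, 4, 5]
    print(overflow)  # [4, 5, 6, 7, 8, 9] -- overlaps kept by stride=2 ids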
@@ -788,7 +791,7 @@ class PreTrainedTokenizer(object):

         sequence = first_sentence_tokens + second_sentence_tokens
         if max_length:
-            information["overflowing_tokens"] = sequence[max_length:]
+            information["overflowing_tokens"] = sequence[max_length - stride:]
             sequence = sequence[:max_length]
         if output_mask:
             information["mask"] = [0] * len(sequence)