chenpangpang / transformers · Commits

Commit af23b626, authored Sep 11, 2019 by LysandreJik
Max encoding length + corresponding tests
parent c4d4f3ec
Showing 2 changed files with 48 additions and 3 deletions:

  pytorch_transformers/tests/tokenization_tests_commons.py   +29  -0
  pytorch_transformers/tokenization_utils.py                  +19  -3
pytorch_transformers/tests/tokenization_tests_commons.py
@@ -211,3 +211,32 @@ class CommonTestCases:
             # Method is implemented (e.g. not GPT-2)
             if len(attached_sequences) != 2:
                 assert tokenizer.num_added_tokens(pair=True) == len(attached_sequences) - sum([len(seq) for seq in sequences])
+
+        def test_maximum_encoding_length_single_input(self):
+            tokenizer = self.get_tokenizer()
+
+            seq_0 = "This is a sentence to be encoded."
+            sequence = tokenizer.encode(seq_0)
+            num_added_tokens = tokenizer.num_added_tokens()
+            total_length = len(sequence) + num_added_tokens
+
+            truncated_sequence = tokenizer.encode(seq_0, max_length=total_length - 2, add_special_tokens=True)
+            assert len(truncated_sequence) == total_length - 2
+            assert truncated_sequence == tokenizer.add_special_tokens_single_sentence(sequence[:-2])
+
+        def test_maximum_encoding_length_pair_input(self):
+            tokenizer = self.get_tokenizer()
+
+            seq_0 = "This is a sentence to be encoded."
+            seq_1 = "This is another sentence to be encoded."
+
+            sequence = tokenizer.encode(seq_0, seq_1, add_special_tokens=True)
+            truncated_second_sequence = tokenizer.add_special_tokens_sentences_pair(
+                tokenizer.encode(seq_0),
+                tokenizer.encode(seq_1)[:-2]
+            )
+
+            truncated_sequence = tokenizer.encode(seq_0, seq_1, max_length=len(sequence) - 2, add_special_tokens=True)
+
+            assert len(truncated_sequence) == len(sequence) - 2
+            assert truncated_sequence == truncated_second_sequence
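These tests pin down the truncation contract through the abstract `get_tokenizer()`. As a concrete illustration (not part of the commit), here is a minimal sketch of the single-sequence behavior, assuming the `pytorch_transformers` package at this revision and the stock `bert-base-uncased` checkpoint:

    from pytorch_transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    # Raw ids, no special tokens added.
    ids = tokenizer.encode("This is a sentence to be encoded.")
    # BERT wraps a single sequence in [CLS] ... [SEP], so this is 2.
    num_added = tokenizer.num_added_tokens()
    full_length = len(ids) + num_added

    # Cap the total (special tokens included) two ids below the full length.
    # encode() trims the raw ids to max_length - num_added first, then wraps
    # them, so the special tokens always survive the truncation.
    capped = tokenizer.encode("This is a sentence to be encoded.",
                              max_length=full_length - 2, add_special_tokens=True)
    assert len(capped) == full_length - 2
    assert capped == tokenizer.add_special_tokens_single_sentence(ids[:-2])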
pytorch_transformers/tokenization_utils.py
@@ -693,7 +693,7 @@ class PreTrainedTokenizer(object):
     def _convert_token_to_id(self, token):
         raise NotImplementedError
 
-    def encode(self, text, text_pair=None, add_special_tokens=False, output_mask=False, **kwargs):
+    def encode(self, text, text_pair=None, add_special_tokens=False, output_mask=False, max_length=None, **kwargs):
         """
         Converts a string in a sequence of ids (integer), using the tokenizer and vocabulary.
@@ -706,20 +706,36 @@ class PreTrainedTokenizer(object):
                 to their model.
             output_mask: if set to ``True``, returns the text pair corresponding mask with 0 for the first sequence,
                 and 1 for the second.
+            max_length: if set to a number, will limit the total sequence returned so that it has a maximum length.
             **kwargs: passed to the `self.tokenize()` method
         """
         if text_pair is None:
             if add_special_tokens:
-                return self.add_special_tokens_single_sentence(self.convert_tokens_to_ids(self.tokenize(text, **kwargs)))
+                sequence_tokens = self.convert_tokens_to_ids(self.tokenize(text, **kwargs))
+                if max_length:
+                    sequence_tokens = sequence_tokens[:max_length - self.num_added_tokens()]
+                return self.add_special_tokens_single_sentence(sequence_tokens)
             else:
-                return self.convert_tokens_to_ids(self.tokenize(text, **kwargs))
+                ids = self.convert_tokens_to_ids(self.tokenize(text, **kwargs))
+                return ids[:max_length] if max_length != -1 else ids
 
         first_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text, **kwargs)]
         second_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text_pair, **kwargs)]
 
         if add_special_tokens:
+            if max_length:
+                if len(first_sentence_tokens) + self.num_added_tokens(pair=True) >= max_length:
+                    logger.warning("The first sequence is longer than the maximum specified length. This sequence will not be truncated.")
+                else:
+                    if len(second_sentence_tokens) + len(first_sentence_tokens) + self.num_added_tokens(pair=True) > max_length:
+                        second_sentence_tokens = second_sentence_tokens[:max_length - len(first_sentence_tokens) - self.num_added_tokens(pair=True)]
             return self.add_special_tokens_sentences_pair(first_sentence_tokens, second_sentence_tokens, output_mask)
         else:
+            if max_length:
+                first_sentence_tokens = first_sentence_tokens[:max_length]
+                second_sentence_tokens = second_sentence_tokens[:max_length]
             if output_mask:
                 logger.warning("Can't output mask if you're not joining two sequences.")
             return first_sentence_tokens, second_sentence_tokens
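For pairs with special tokens the truncation is deliberately asymmetric: the first sequence is never cut (the method only warns if it already exceeds the budget on its own), and the second sequence absorbs the entire reduction. A small sketch of that arithmetic, using made-up lengths rather than a real tokenizer (all numbers below are illustrative, not from the commit):

    # Budget bookkeeping mirroring the pair branch of encode() above.
    max_length = 20
    first_len = 9        # len(first_sentence_tokens)
    second_len = 12      # len(second_sentence_tokens)
    num_added_pair = 3   # num_added_tokens(pair=True); BERT: [CLS] A [SEP] B [SEP]

    if first_len + num_added_pair >= max_length:
        # 9 + 3 = 12 < 20, so this warning branch is not taken here.
        print("first sequence alone exceeds max_length; nothing is truncated")
    elif first_len + second_len + num_added_pair > max_length:
        # 9 + 12 + 3 = 24 > 20: the second sequence keeps only
        # max_length - first_len - num_added_pair = 20 - 9 - 3 = 8 ids.
        second_len = max_length - first_len - num_added_pair

    assert first_len + second_len + num_added_pair == max_length  # 9 + 8 + 3 == 20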