Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
7c789c33
Commit
7c789c33
authored
Sep 30, 2019
by
LysandreJik
Browse files
Always truncate argument in the encode method
parent
7af07779
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
48 additions
and
12 deletions
+48
-12
transformers/tests/tokenization_tests_commons.py
transformers/tests/tokenization_tests_commons.py
+17
-0
transformers/tokenization_utils.py
transformers/tokenization_utils.py
+31
-12
No files found.
transformers/tests/tokenization_tests_commons.py
View file @
7c789c33
...
...
@@ -232,6 +232,23 @@ class CommonTestCases:
assert
len
(
truncated_sequence
)
==
total_length
-
2
assert
truncated_sequence
==
tokenizer
.
add_special_tokens_single_sequence
(
sequence
[:
-
2
])
def test_always_truncate(self):
    """Check that ``always_truncate=True`` forces truncation of an encoded pair.

    When the pair overflows ``max_length`` so badly that the sequence which
    would normally be kept is itself longer than the limit, a plain ``encode``
    leaves the result untruncated; with ``always_truncate=True`` the output
    must instead equal the untruncated encoding clipped down to ``max_length``.
    """
    tokenizer = self.get_tokenizer()
    sentence = "This is a sentence to be encoded."

    # Length of the single encoded sequence: deliberately used as a
    # max_length that is too small to hold the encoded pair.
    single_len = len(tokenizer.encode(sentence))
    # Full length of the encoded pair including special tokens.
    pair_len = len(tokenizer.encode(sentence, sentence, add_special_tokens=True))

    # Without always_truncate the overflowing pair is returned as-is.
    untruncated = tokenizer.encode(
        sentence,
        sentence,
        add_special_tokens=True,
        max_length=single_len,
    )
    # With always_truncate the pair must be cut down regardless.
    forced = tokenizer.encode(
        sentence,
        sentence,
        max_length=single_len,
        add_special_tokens=True,
        always_truncate=True,
    )

    # single_len - pair_len is negative, so this drops exactly the
    # overflow (pair_len - single_len tokens) from the end.
    assert forced == untruncated[: single_len - pair_len]
def
test_maximum_encoding_length_pair_input
(
self
):
tokenizer
=
self
.
get_tokenizer
()
...
...
transformers/tokenization_utils.py
View file @
7c789c33
...
...
@@ -700,6 +700,7 @@ class PreTrainedTokenizer(object):
stride
=
0
,
truncate_first_sequence
=
True
,
return_tensors
=
None
,
always_truncate
=
False
,
**
kwargs
):
"""
Converts a string in a sequence of ids (integer), using the tokenizer and vocabulary.
...
...
@@ -721,6 +722,8 @@ class PreTrainedTokenizer(object):
from the main sequence returned. The value of this argument defined the number of additional tokens.
truncate_first_sequence: if there is a specified max_length, this flag will choose which sequence
will be truncated.
always_truncate: if set to True, will always truncate the sequences when overflowing, even if one of the
sequences may be lost in the process.
return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
or PyTorch torch.Tensor instead of a list of python integers.
**kwargs: passed to the `self.tokenize()` method
...
...
@@ -732,6 +735,7 @@ class PreTrainedTokenizer(object):
stride
=
stride
,
truncate_first_sequence
=
truncate_first_sequence
,
return_tensors
=
return_tensors
,
always_truncate
=
always_truncate
,
**
kwargs
)
return
encoded_inputs
[
"input_ids"
]
...
...
@@ -744,6 +748,7 @@ class PreTrainedTokenizer(object):
stride
=
0
,
truncate_first_sequence
=
True
,
return_tensors
=
None
,
always_truncate
=
False
,
**
kwargs
):
"""
Returns a dictionary containing the encoded sequence or sequence pair and additional informations:
...
...
@@ -764,6 +769,8 @@ class PreTrainedTokenizer(object):
from the main sequence returned. The value of this argument defined the number of additional tokens.
truncate_first_sequence: if there is a specified max_length, this flag will choose which sequence
will be truncated.
always_truncate: if set to True, will always truncate the sequences when overflowing, even if one of the
sequences may be lost in the process.
return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
or PyTorch torch.Tensor instead of a list of python integers.
**kwargs: passed to the `self.tokenize()` method
...
...
@@ -788,11 +795,12 @@ class PreTrainedTokenizer(object):
add_special_tokens
=
add_special_tokens
,
stride
=
stride
,
truncate_first_sequence
=
truncate_first_sequence
,
always_truncate
=
always_truncate
,
return_tensors
=
return_tensors
)
def
prepare_for_model
(
self
,
ids
,
pair_ids
=
None
,
max_length
=
None
,
add_special_tokens
=
False
,
stride
=
0
,
truncate_first_sequence
=
True
,
return_tensors
=
None
):
truncate_first_sequence
=
True
,
always_truncate
=
False
,
return_tensors
=
None
):
"""
Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model.
It adds special tokens, truncates
...
...
@@ -812,6 +820,8 @@ class PreTrainedTokenizer(object):
truncate_first_sequence: if set to `True` and an optional second list of input ids is provided,
alongside a specified `max_length`, will truncate the first sequence if the total size is superior
than the specified `max_length`. If set to `False`, will truncate the second sequence instead.
always_truncate: if set to True, will always truncate the sequences when overflowing, even if one of the
sequences may be lost in the process.
return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
or PyTorch torch.Tensor instead of a list of python integers.
...
...
@@ -826,8 +836,13 @@ class PreTrainedTokenizer(object):
if
max_length
:
n_added_tokens
=
self
.
num_added_tokens
(
pair
=
pair
)
if
add_special_tokens
else
0
if
pair
and
n_added_tokens
+
(
len_pair_ids
if
truncate_first_sequence
else
len_ids
)
>=
max_length
:
if
always_truncate
:
logger
.
warning
(
"You supplied a pair of sequence in which the sequence that will not be truncated is longer than the maximum specified length."
"You supplied a pair of sequence in which the sequence that will not be truncated is longer than the maximum specified length. "
"This pair of sequences will be truncated but one of the sequences may not be present in the resulting list of ids."
)
else
:
logger
.
warning
(
"You supplied a pair of sequence in which the sequence that will not be truncated is longer than the maximum specified length. "
"This pair of sequences will not be truncated."
)
else
:
if
n_added_tokens
+
len_ids
+
len_pair_ids
>
max_length
:
...
...
@@ -860,6 +875,10 @@ class PreTrainedTokenizer(object):
encoded_inputs
[
"input_ids"
]
=
sequence
encoded_inputs
[
"token_type_ids"
]
=
token_type_ids
if
always_truncate
and
len
(
encoded_inputs
[
"input_ids"
])
>
max_length
:
encoded_inputs
[
"input_ids"
]
=
encoded_inputs
[
"input_ids"
][:
max_length
]
encoded_inputs
[
"token_type_ids"
]
=
encoded_inputs
[
"token_type_ids"
][:
max_length
]
return
encoded_inputs
def
create_token_type_ids_from_sequences
(
self
,
token_ids_0
,
token_ids_1
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment