Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
7c789c33
Commit
7c789c33
authored
Sep 30, 2019
by
LysandreJik
Browse files
Always truncate argument in the encode method
parent
7af07779
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
48 additions
and
12 deletions
+48
-12
transformers/tests/tokenization_tests_commons.py
transformers/tests/tokenization_tests_commons.py
+17
-0
transformers/tokenization_utils.py
transformers/tokenization_utils.py
+31
-12
No files found.
transformers/tests/tokenization_tests_commons.py
View file @
7c789c33
...
@@ -232,6 +232,23 @@ class CommonTestCases:
...
@@ -232,6 +232,23 @@ class CommonTestCases:
assert
len
(
truncated_sequence
)
==
total_length
-
2
assert
len
(
truncated_sequence
)
==
total_length
-
2
assert
truncated_sequence
==
tokenizer
.
add_special_tokens_single_sequence
(
sequence
[:
-
2
])
assert
truncated_sequence
==
tokenizer
.
add_special_tokens_single_sequence
(
sequence
[:
-
2
])
def test_always_truncate(self):
    """Verify that `always_truncate=True` forces an overflowing sequence pair
    down to `max_length`, whereas the default leaves it unshortened."""
    tokenizer = self.get_tokenizer()
    seq_0 = "This is a sentence to be encoded."

    # Length of the single encoded sequence; reused below as the max_length cap.
    length_single_sequence = len(tokenizer.encode(seq_0))
    # Full length of the encoded pair with special tokens included.
    length = len(tokenizer.encode(seq_0, seq_0, add_special_tokens=True))

    # Without always_truncate the pair overflows max_length and is returned whole.
    not_truncated = tokenizer.encode(
        seq_0, seq_0, add_special_tokens=True, max_length=length_single_sequence
    )
    # With always_truncate the encoder must cut the result to max_length.
    truncated = tokenizer.encode(
        seq_0,
        seq_0,
        max_length=length_single_sequence,
        add_special_tokens=True,
        always_truncate=True,
    )

    # length > length_single_sequence, so the slice bound is negative:
    # the truncated output equals the untruncated one with the overflow removed.
    assert truncated == not_truncated[:length_single_sequence - length]
def
test_maximum_encoding_length_pair_input
(
self
):
def
test_maximum_encoding_length_pair_input
(
self
):
tokenizer
=
self
.
get_tokenizer
()
tokenizer
=
self
.
get_tokenizer
()
...
...
transformers/tokenization_utils.py
View file @
7c789c33
...
@@ -700,6 +700,7 @@ class PreTrainedTokenizer(object):
...
@@ -700,6 +700,7 @@ class PreTrainedTokenizer(object):
stride
=
0
,
stride
=
0
,
truncate_first_sequence
=
True
,
truncate_first_sequence
=
True
,
return_tensors
=
None
,
return_tensors
=
None
,
always_truncate
=
False
,
**
kwargs
):
**
kwargs
):
"""
"""
Converts a string into a sequence of ids (integers), using the tokenizer and vocabulary.
Converts a string into a sequence of ids (integers), using the tokenizer and vocabulary.
...
@@ -721,6 +722,8 @@ class PreTrainedTokenizer(object):
...
@@ -721,6 +722,8 @@ class PreTrainedTokenizer(object):
from the main sequence returned. The value of this argument defines the number of additional tokens.
from the main sequence returned. The value of this argument defines the number of additional tokens.
truncate_first_sequence: if there is a specified max_length, this flag will choose which sequence
truncate_first_sequence: if there is a specified max_length, this flag will choose which sequence
will be truncated.
will be truncated.
always_truncate: if set to True, will always truncate the sequences when overflowing, even if one of the
sequences may be lost in the process.
return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
or PyTorch torch.Tensor instead of a list of python integers.
or PyTorch torch.Tensor instead of a list of python integers.
**kwargs: passed to the `self.tokenize()` method
**kwargs: passed to the `self.tokenize()` method
...
@@ -732,6 +735,7 @@ class PreTrainedTokenizer(object):
...
@@ -732,6 +735,7 @@ class PreTrainedTokenizer(object):
stride
=
stride
,
stride
=
stride
,
truncate_first_sequence
=
truncate_first_sequence
,
truncate_first_sequence
=
truncate_first_sequence
,
return_tensors
=
return_tensors
,
return_tensors
=
return_tensors
,
always_truncate
=
always_truncate
,
**
kwargs
)
**
kwargs
)
return
encoded_inputs
[
"input_ids"
]
return
encoded_inputs
[
"input_ids"
]
...
@@ -744,6 +748,7 @@ class PreTrainedTokenizer(object):
...
@@ -744,6 +748,7 @@ class PreTrainedTokenizer(object):
stride
=
0
,
stride
=
0
,
truncate_first_sequence
=
True
,
truncate_first_sequence
=
True
,
return_tensors
=
None
,
return_tensors
=
None
,
always_truncate
=
False
,
**
kwargs
):
**
kwargs
):
"""
"""
Returns a dictionary containing the encoded sequence or sequence pair and additional information:
Returns a dictionary containing the encoded sequence or sequence pair and additional information:
...
@@ -764,6 +769,8 @@ class PreTrainedTokenizer(object):
...
@@ -764,6 +769,8 @@ class PreTrainedTokenizer(object):
from the main sequence returned. The value of this argument defines the number of additional tokens.
from the main sequence returned. The value of this argument defines the number of additional tokens.
truncate_first_sequence: if there is a specified max_length, this flag will choose which sequence
truncate_first_sequence: if there is a specified max_length, this flag will choose which sequence
will be truncated.
will be truncated.
always_truncate: if set to True, will always truncate the sequences when overflowing, even if one of the
sequences may be lost in the process.
return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
or PyTorch torch.Tensor instead of a list of python integers.
or PyTorch torch.Tensor instead of a list of python integers.
**kwargs: passed to the `self.tokenize()` method
**kwargs: passed to the `self.tokenize()` method
...
@@ -788,11 +795,12 @@ class PreTrainedTokenizer(object):
...
@@ -788,11 +795,12 @@ class PreTrainedTokenizer(object):
add_special_tokens
=
add_special_tokens
,
add_special_tokens
=
add_special_tokens
,
stride
=
stride
,
stride
=
stride
,
truncate_first_sequence
=
truncate_first_sequence
,
truncate_first_sequence
=
truncate_first_sequence
,
always_truncate
=
always_truncate
,
return_tensors
=
return_tensors
)
return_tensors
=
return_tensors
)
def
prepare_for_model
(
self
,
ids
,
pair_ids
=
None
,
max_length
=
None
,
add_special_tokens
=
False
,
stride
=
0
,
def
prepare_for_model
(
self
,
ids
,
pair_ids
=
None
,
max_length
=
None
,
add_special_tokens
=
False
,
stride
=
0
,
truncate_first_sequence
=
True
,
return_tensors
=
None
):
truncate_first_sequence
=
True
,
always_truncate
=
False
,
return_tensors
=
None
):
"""
"""
Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model.
Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model.
It adds special tokens, truncates
It adds special tokens, truncates
...
@@ -812,6 +820,8 @@ class PreTrainedTokenizer(object):
...
@@ -812,6 +820,8 @@ class PreTrainedTokenizer(object):
truncate_first_sequence: if set to `True` and an optional second list of input ids is provided,
truncate_first_sequence: if set to `True` and an optional second list of input ids is provided,
alongside a specified `max_length`, will truncate the first sequence if the total size is greater
alongside a specified `max_length`, will truncate the first sequence if the total size is greater
than the specified `max_length`. If set to `False`, will truncate the second sequence instead.
than the specified `max_length`. If set to `False`, will truncate the second sequence instead.
always_truncate: if set to True, will always truncate the sequences when overflowing, even if one of the
sequences may be lost in the process.
return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
or PyTorch torch.Tensor instead of a list of python integers.
or PyTorch torch.Tensor instead of a list of python integers.
...
@@ -826,8 +836,13 @@ class PreTrainedTokenizer(object):
...
@@ -826,8 +836,13 @@ class PreTrainedTokenizer(object):
if
max_length
:
if
max_length
:
n_added_tokens
=
self
.
num_added_tokens
(
pair
=
pair
)
if
add_special_tokens
else
0
n_added_tokens
=
self
.
num_added_tokens
(
pair
=
pair
)
if
add_special_tokens
else
0
if
pair
and
n_added_tokens
+
(
len_pair_ids
if
truncate_first_sequence
else
len_ids
)
>=
max_length
:
if
pair
and
n_added_tokens
+
(
len_pair_ids
if
truncate_first_sequence
else
len_ids
)
>=
max_length
:
if
always_truncate
:
logger
.
warning
(
logger
.
warning
(
"You supplied a pair of sequence in which the sequence that will not be truncated is longer than the maximum specified length."
"You supplied a pair of sequence in which the sequence that will not be truncated is longer than the maximum specified length. "
"This pair of sequences will be truncated but one of the sequences may not be present in the resulting list of ids."
)
else
:
logger
.
warning
(
"You supplied a pair of sequence in which the sequence that will not be truncated is longer than the maximum specified length. "
"This pair of sequences will not be truncated."
)
"This pair of sequences will not be truncated."
)
else
:
else
:
if
n_added_tokens
+
len_ids
+
len_pair_ids
>
max_length
:
if
n_added_tokens
+
len_ids
+
len_pair_ids
>
max_length
:
...
@@ -860,6 +875,10 @@ class PreTrainedTokenizer(object):
...
@@ -860,6 +875,10 @@ class PreTrainedTokenizer(object):
encoded_inputs
[
"input_ids"
]
=
sequence
encoded_inputs
[
"input_ids"
]
=
sequence
encoded_inputs
[
"token_type_ids"
]
=
token_type_ids
encoded_inputs
[
"token_type_ids"
]
=
token_type_ids
if
always_truncate
and
len
(
encoded_inputs
[
"input_ids"
])
>
max_length
:
encoded_inputs
[
"input_ids"
]
=
encoded_inputs
[
"input_ids"
][:
max_length
]
encoded_inputs
[
"token_type_ids"
]
=
encoded_inputs
[
"token_type_ids"
][:
max_length
]
return
encoded_inputs
return
encoded_inputs
def
create_token_type_ids_from_sequences
(
self
,
token_ids_0
,
token_ids_1
):
def
create_token_type_ids_from_sequences
(
self
,
token_ids_0
,
token_ids_1
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment