chenpangpang / transformers · Commits

Commit 651bfb7a
authored Sep 30, 2019 by LysandreJik
parent 5ed50a93

always_truncate by default
Showing 2 changed files with 5 additions and 39 deletions (+5 -39)

transformers/tests/tokenization_tests_commons.py   +0 -18
transformers/tokenization_utils.py                 +5 -21
transformers/tests/tokenization_tests_commons.py
@@ -232,23 +232,6 @@ class CommonTestCases:
         assert len(truncated_sequence) == total_length - 2
         assert truncated_sequence == tokenizer.add_special_tokens_single_sequence(sequence[:-2])
 
-    def test_always_truncate(self):
-        tokenizer = self.get_tokenizer()
-
-        seq_0 = "This is a sentence to be encoded."
-
-        length_single_sequence = len(tokenizer.encode(seq_0))
-        length = len(tokenizer.encode(seq_0, seq_0, add_special_tokens=True))
-
-        not_truncated = tokenizer.encode(seq_0, seq_0,
-                                         add_special_tokens=True,
-                                         max_length=length_single_sequence)
-        truncated = tokenizer.encode(seq_0, seq_0,
-                                     max_length=length_single_sequence,
-                                     add_special_tokens=True,
-                                     always_truncate=True)
-
-        assert truncated == not_truncated[:length_single_sequence - length]
 
     def test_maximum_encoding_length_pair_input(self):
         tokenizer = self.get_tokenizer()
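The assertion in the removed test relies on negative slicing: the full pair encoding (`length`) is longer than `length_single_sequence`, so `length_single_sequence - length` is negative and drops exactly the overflowing ids from the end. A library-free sketch of that arithmetic, with made-up token counts:

    # Hypothetical stand-ins for the values the removed test computed
    # with a real tokenizer.
    length_single_sequence = 5        # len(tokenizer.encode(seq_0))
    length = 8                        # len of the full, untruncated pair encoding
    not_truncated = list(range(length))

    # 5 - 8 == -3, so the slice drops the last 3 ids,
    # leaving exactly max_length (= length_single_sequence) ids.
    truncated = not_truncated[:length_single_sequence - length]
    assert truncated == [0, 1, 2, 3, 4]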
@@ -329,7 +312,6 @@ class CommonTestCases:
         sequence_ids_orig = encoded_sequence_dict["sequence_ids"]
         sequence_ids = tokenizer.get_sequence_ids(encoded_sequence_w_special, special_tokens_present=True)
 
         assert len(sequence_ids) == len(encoded_sequence_w_special)
-        print(sequence_ids_orig, sequence_ids)
         assert sequence_ids_orig == sequence_ids
transformers/tokenization_utils.py
@@ -700,7 +700,6 @@ class PreTrainedTokenizer(object):
                stride=0,
                truncate_first_sequence=True,
                return_tensors=None,
-               always_truncate=False,
                **kwargs):
         """
         Converts a string in a sequence of ids (integer), using the tokenizer and vocabulary.
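After this change, passing `max_length` to `encode` is enough to force truncation; there is no flag left to opt in. A usage sketch under the post-commit signature above (the `bert-base-uncased` checkpoint and the cap of 16 are illustrative assumptions, not part of the commit):

    from transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")  # assumed checkpoint

    # truncate_first_sequence=True (the default) trims the first sequence
    # when the pair overflows max_length.
    ids = tokenizer.encode("A first sentence to encode.",
                           "A second, noticeably longer sentence to encode.",
                           add_special_tokens=True,
                           max_length=16)
    assert len(ids) <= 16  # guaranteed by the unconditional clamp this commit adds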
@@ -722,8 +721,6 @@ class PreTrainedTokenizer(object):
                 from the main sequence returned. The value of this argument defined the number of additional tokens.
             truncate_first_sequence: if there is a specified max_length, this flag will choose which sequence
                 will be truncated.
-            always_truncate: if set to True, will always truncate the sequences when overflowing, even if one of the
-                sequences may be lost in the process.
             return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
                 or PyTorch torch.Tensor instead of a list of python integers.
             **kwargs: passed to the `self.tokenize()` method
@@ -735,7 +732,6 @@ class PreTrainedTokenizer(object):
                                           stride=stride,
                                           truncate_first_sequence=truncate_first_sequence,
                                           return_tensors=return_tensors,
-                                          always_truncate=always_truncate,
                                           **kwargs)
 
         return encoded_inputs["input_ids"]
@@ -748,7 +744,6 @@ class PreTrainedTokenizer(object):
                     stride=0,
                     truncate_first_sequence=True,
                     return_tensors=None,
-                    always_truncate=False,
                     **kwargs):
         """
         Returns a dictionary containing the encoded sequence or sequence pair and additional informations:
@@ -769,8 +764,6 @@ class PreTrainedTokenizer(object):
                 from the main sequence returned. The value of this argument defined the number of additional tokens.
             truncate_first_sequence: if there is a specified max_length, this flag will choose which sequence
                 will be truncated.
-            always_truncate: if set to True, will always truncate the sequences when overflowing, even if one of the
-                sequences may be lost in the process.
             return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
                 or PyTorch torch.Tensor instead of a list of python integers.
             **kwargs: passed to the `self.tokenize()` method
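`encode_plus` loses the flag in exactly the same way; as the delegation hunk above shows, `encode` is essentially `encode_plus(...)["input_ids"]`. A sketch of the dictionary it returns, under the same illustrative assumptions:

    from transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")  # assumed checkpoint

    enc = tokenizer.encode_plus("A first sentence.", "A second sentence.",
                                add_special_tokens=True, max_length=16)

    # Ids and segment ids are built together, so they are clamped together.
    assert len(enc["input_ids"]) == len(enc["token_type_ids"]) <= 16
    assert enc["input_ids"] == tokenizer.encode("A first sentence.", "A second sentence.",
                                                add_special_tokens=True, max_length=16)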
@@ -795,12 +788,10 @@ class PreTrainedTokenizer(object):
                                       add_special_tokens=add_special_tokens,
                                       stride=stride,
                                       truncate_first_sequence=truncate_first_sequence,
-                                      always_truncate=always_truncate,
                                       return_tensors=return_tensors)
 
     def prepare_for_model(self, ids, pair_ids=None, max_length=None, add_special_tokens=False, stride=0,
-                          truncate_first_sequence=True, always_truncate=False,
-                          return_tensors=None):
+                          truncate_first_sequence=True, return_tensors=None):
         """
         Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model.
         It adds special tokens, truncates
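`prepare_for_model` sits one level lower: it takes ids that have already been tokenized and converted, which is what `encode_plus` hands it. A sketch of calling it directly with the post-commit signature (same assumed checkpoint):

    from transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")  # assumed checkpoint

    ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize("A first sentence."))
    pair_ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize("A second sentence."))

    prepared = tokenizer.prepare_for_model(ids, pair_ids=pair_ids,
                                           add_special_tokens=True, max_length=16)
    assert len(prepared["input_ids"]) <= 16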
@@ -820,8 +811,6 @@ class PreTrainedTokenizer(object):
             truncate_first_sequence: if set to `True` and an optional second list of input ids is provided,
                 alongside a specified `max_length`, will truncate the first sequence if the total size is superior
                 than the specified `max_length`. If set to `False`, will truncate the second sequence instead.
-            always_truncate: if set to True, will always truncate the sequences when overflowing, even if one of the
-                sequences may be lost in the process.
             return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
                 or PyTorch torch.Tensor instead of a list of python integers.
 
@@ -850,14 +839,9 @@ class PreTrainedTokenizer(object):
         if max_length:
             n_added_tokens = self.num_added_tokens(pair=pair) if add_special_tokens else 0
             if pair and n_added_tokens + (len_pair_ids if truncate_first_sequence else len_ids) >= max_length:
-                if always_truncate:
-                    logger.warning(
-                        "You supplied a pair of sequence in which the sequence that will not be truncated is longer than the maximum specified length. "
-                        "This pair of sequences will be truncated but one of the sequences may not be present in the resulting list of ids.")
-                else:
-                    logger.warning(
-                        "You supplied a pair of sequence in which the sequence that will not be truncated is longer than the maximum specified length. "
-                        "This pair of sequences will not be truncated.")
+                logger.warning(
+                    "You supplied a pair of sequence in which the sequence that will not be truncated is longer than the maximum specified length. "
+                    "This pair of sequences will be truncated with no regard to the special tokens")
             else:
                 if n_added_tokens + len_ids + len_pair_ids > max_length:
                     if truncate_first_sequence or not pair:
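The guard fires when the sequence that is kept whole already fills `max_length` by itself once the special tokens are counted; trimming the other sequence can then never bring the pair under the cap, so the clamp at the end of the method will cut into the kept sequence and possibly its special tokens. A library-free sketch of the condition, with hypothetical lengths:

    len_ids, len_pair_ids = 4, 20   # hypothetical token counts for a pair
    n_added_tokens = 3              # e.g. [CLS] .. [SEP] .. [SEP] in a BERT-style pair
    max_length = 16
    truncate_first_sequence = True  # the second sequence is the one kept intact

    kept = len_pair_ids if truncate_first_sequence else len_ids
    if n_added_tokens + kept >= max_length:
        # 3 + 20 >= 16: even an empty first sequence yields 23 ids,
        # so only the final hard clamp can enforce max_length.
        print("pair will be truncated with no regard to the special tokens")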
@@ -890,7 +874,7 @@ class PreTrainedTokenizer(object):
         encoded_inputs["input_ids"] = sequence
         encoded_inputs["token_type_ids"] = token_type_ids
 
-        if always_truncate and len(encoded_inputs["input_ids"]) > max_length:
+        if max_length and len(encoded_inputs["input_ids"]) > max_length:
             encoded_inputs["input_ids"] = encoded_inputs["input_ids"][:max_length]
             encoded_inputs["token_type_ids"] = encoded_inputs["token_type_ids"][:max_length]
 
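This last hunk is the behavioral change in one line: the final clamp no longer asks for `always_truncate`, it runs whenever `max_length` is set. A minimal dictionary-level sketch of the same slicing:

    max_length = 16
    encoded_inputs = {
        "input_ids": list(range(20)),          # 4 ids over the cap
        "token_type_ids": [0] * 10 + [1] * 10,
    }

    # Mirrors the hunk above: clamp both lists whenever max_length is set.
    if max_length and len(encoded_inputs["input_ids"]) > max_length:
        encoded_inputs["input_ids"] = encoded_inputs["input_ids"][:max_length]
        encoded_inputs["token_type_ids"] = encoded_inputs["token_type_ids"][:max_length]

    assert len(encoded_inputs["input_ids"]) == len(encoded_inputs["token_type_ids"]) == 16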