chenpangpang / transformers · Commit 651bfb7a
"sims/nic/git@developer.sourcefind.cn:cnjsdfcy/simbricks.git" did not exist on "56775bc1685edb3c30c488bf1bd95db03ef38c72"
Commit 651bfb7a authored Sep 30, 2019 by LysandreJik

always_truncate by default

Parent: 5ed50a93
Showing 2 changed files with 5 additions and 39 deletions (+5 / -39)
transformers/tests/tokenization_tests_commons.py   (+0 / -18)
transformers/tokenization_utils.py                 (+5 / -21)
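Taken together, the two files tell one story: the opt-in always_truncate flag is removed and its behaviour becomes the default whenever max_length is set. A minimal before/after sketch of the caller-visible change; it assumes a concrete pretrained tokenizer (BertTokenizer here) and the 2019-era encode signature shown in this diff:

```python
from transformers import BertTokenizer

# Any PreTrainedTokenizer subclass would do; BertTokenizer is an assumption.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# Before this commit, capping an overflowing pair required opting in:
#   ids = tokenizer.encode(seq_a, seq_b, add_special_tokens=True,
#                          max_length=16, always_truncate=True)
# After this commit, max_length alone is enough; the kwarg no longer exists:
seq_a, seq_b = "first sequence", "a much longer second sequence to encode"
ids = tokenizer.encode(seq_a, seq_b, add_special_tokens=True, max_length=16)
assert len(ids) <= 16  # the output is hard-capped by default
```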
transformers/tests/tokenization_tests_commons.py

@@ -232,23 +232,6 @@ class CommonTestCases:
         assert len(truncated_sequence) == total_length - 2
         assert truncated_sequence == tokenizer.add_special_tokens_single_sequence(sequence[:-2])

-    def test_always_truncate(self):
-        tokenizer = self.get_tokenizer()
-
-        seq_0 = "This is a sentence to be encoded."
-        length_single_sequence = len(tokenizer.encode(seq_0))
-        length = len(tokenizer.encode(seq_0, seq_0, add_special_tokens=True))
-
-        not_truncated = tokenizer.encode(seq_0, seq_0, add_special_tokens=True, max_length=length_single_sequence)
-        truncated = tokenizer.encode(seq_0, seq_0, max_length=length_single_sequence, add_special_tokens=True, always_truncate=True)
-
-        assert truncated == not_truncated[:length_single_sequence - length]
-
     def test_maximum_encoding_length_pair_input(self):
         tokenizer = self.get_tokenizer()

@@ -329,7 +312,6 @@ class CommonTestCases:
         sequence_ids_orig = encoded_sequence_dict["sequence_ids"]
         sequence_ids = tokenizer.get_sequence_ids(encoded_sequence_w_special, special_tokens_present=True)
         assert len(sequence_ids) == len(encoded_sequence_w_special)
-        print(sequence_ids_orig, sequence_ids)
         assert sequence_ids_orig == sequence_ids
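The deleted test_always_truncate pinned down the flag's contract: with the flag set, the capped encoding equals the uncapped encoding cut at max_length. Since truncation is now unconditional, that equality should hold without any flag, which is presumably why the dedicated test was dropped. A hedged restatement follows; it assumes a concrete pretrained tokenizer (an assumption, the original test used its harness's get_tokenizer()), while the encode calls mirror the removed test:

```python
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

seq_0 = "This is a sentence to be encoded."
max_len = len(tokenizer.encode(seq_0))  # deliberately too small for the pair

full = tokenizer.encode(seq_0, seq_0, add_special_tokens=True)  # no cap
capped = tokenizer.encode(seq_0, seq_0, add_special_tokens=True,
                          max_length=max_len)                   # capped by default

# The old test asserted this only when always_truncate=True was passed;
# under the new default it should hold for every overflowing pair.
assert capped == full[:max_len]
```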
transformers/tokenization_utils.py

@@ -700,7 +700,6 @@ class PreTrainedTokenizer(object):
                stride=0,
                truncate_first_sequence=True,
                return_tensors=None,
-               always_truncate=False,
                **kwargs):
         """
         Converts a string in a sequence of ids (integer), using the tokenizer and vocabulary.

@@ -722,8 +721,6 @@ class PreTrainedTokenizer(object):
                 from the main sequence returned. The value of this argument defined the number of additional tokens.
             truncate_first_sequence: if there is a specified max_length, this flag will choose which sequence
                 will be truncated.
-            always_truncate: if set to True, will always truncate the sequences when overflowing, even if one of the
-                sequences may be lost in the process.
             return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
                 or PyTorch torch.Tensor instead of a list of python integers.
             **kwargs: passed to the `self.tokenize()` method

@@ -735,7 +732,6 @@ class PreTrainedTokenizer(object):
                stride=stride,
                truncate_first_sequence=truncate_first_sequence,
                return_tensors=return_tensors,
-               always_truncate=always_truncate,
                **kwargs)

         return encoded_inputs["input_ids"]

@@ -748,7 +744,6 @@ class PreTrainedTokenizer(object):
                stride=0,
                truncate_first_sequence=True,
                return_tensors=None,
-               always_truncate=False,
                **kwargs):
         """
         Returns a dictionary containing the encoded sequence or sequence pair and additional informations:

@@ -769,8 +764,6 @@ class PreTrainedTokenizer(object):
                 from the main sequence returned. The value of this argument defined the number of additional tokens.
             truncate_first_sequence: if there is a specified max_length, this flag will choose which sequence
                 will be truncated.
-            always_truncate: if set to True, will always truncate the sequences when overflowing, even if one of the
-                sequences may be lost in the process.
             return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
                 or PyTorch torch.Tensor instead of a list of python integers.
             **kwargs: passed to the `self.tokenize()` method

@@ -795,12 +788,10 @@ class PreTrainedTokenizer(object):
             add_special_tokens=add_special_tokens,
             stride=stride,
             truncate_first_sequence=truncate_first_sequence,
-            always_truncate=always_truncate,
             return_tensors=return_tensors)

     def prepare_for_model(self, ids, pair_ids=None, max_length=None, add_special_tokens=False, stride=0,
-                          truncate_first_sequence=True, always_truncate=False, return_tensors=None):
+                          truncate_first_sequence=True, return_tensors=None):
         """
         Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model.
         It adds special tokens, truncates

@@ -820,8 +811,6 @@ class PreTrainedTokenizer(object):
             truncate_first_sequence: if set to `True` and an optional second list of input ids is provided,
                 alongside a specified `max_length`, will truncate the first sequence if the total size is superior
                 than the specified `max_length`. If set to `False`, will truncate the second sequence instead.
-            always_truncate: if set to True, will always truncate the sequences when overflowing, even if one of the
-                sequences may be lost in the process.
             return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
                 or PyTorch torch.Tensor instead of a list of python integers.

@@ -850,14 +839,9 @@ class PreTrainedTokenizer(object):
         if max_length:
             n_added_tokens = self.num_added_tokens(pair=pair) if add_special_tokens else 0
             if pair and n_added_tokens + (len_pair_ids if truncate_first_sequence else len_ids) >= max_length:
-                if always_truncate:
-                    logger.warning(
-                        "You supplied a pair of sequence in which the sequence that will not be truncated is longer than the maximum specified length. "
-                        "This pair of sequences will be truncated with no regard to the special tokens")
-                else:
-                    logger.warning(
-                        "You supplied a pair of sequence in which the sequence that will not be truncated is longer than the maximum specified length. "
-                        "This pair of sequences will not be truncated.")
+                logger.warning(
+                    "You supplied a pair of sequence in which the sequence that will not be truncated is longer than the maximum specified length. "
+                    "This pair of sequences will be truncated but one of the sequences may not be present in the resulting list of ids.")
             else:
                 if n_added_tokens + len_ids + len_pair_ids > max_length:
                     if truncate_first_sequence or not pair:

@@ -890,7 +874,7 @@ class PreTrainedTokenizer(object):
         encoded_inputs["input_ids"] = sequence
         encoded_inputs["token_type_ids"] = token_type_ids

-        if always_truncate and len(encoded_inputs["input_ids"]) > max_length:
+        if max_length and len(encoded_inputs["input_ids"]) > max_length:
             encoded_inputs["input_ids"] = encoded_inputs["input_ids"][:max_length]
             encoded_inputs["token_type_ids"] = encoded_inputs["token_type_ids"][:max_length]
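The behavioural core of the commit is the last hunk: the final hard cut in prepare_for_model, previously guarded by always_truncate, now fires whenever max_length is set, so the assembled sequence (special tokens included) can never exceed it. Below is a self-contained sketch of that fallback, not the library code itself; hard_truncate is a hypothetical helper name, but its body mirrors the two slicing lines in the diff:

```python
def hard_truncate(encoded_inputs, max_length):
    """Mirror the new default: cap input_ids and token_type_ids at max_length."""
    if max_length and len(encoded_inputs["input_ids"]) > max_length:
        encoded_inputs["input_ids"] = encoded_inputs["input_ids"][:max_length]
        encoded_inputs["token_type_ids"] = encoded_inputs["token_type_ids"][:max_length]
    return encoded_inputs

# Example: a 12-token pair cut down to 8, token type ids trimmed in lockstep.
example = {"input_ids": list(range(12)), "token_type_ids": [0] * 6 + [1] * 6}
capped = hard_truncate(example, 8)
print(capped["input_ids"])       # [0, 1, 2, 3, 4, 5, 6, 7]
print(capped["token_type_ids"])  # [0, 0, 0, 0, 0, 0, 1, 1]
```

Because the cut is applied after special tokens are added, a trailing separator token can be sliced off along with the second sequence; that is the "one of the sequences may not be present" caveat in the new warning message.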