Commit 78ef1a99 authored by thomwolf's avatar thomwolf
Browse files

fixes

parent 6c1d0bc0
...@@ -336,7 +336,6 @@ def convert_examples_to_features( ...@@ -336,7 +336,6 @@ def convert_examples_to_features(
text_b, text_b,
add_special_tokens=True, add_special_tokens=True,
max_length=max_length, max_length=max_length,
truncate_both_sequences=True
) )
if 'num_truncated_tokens' in inputs and inputs['num_truncated_tokens'] > 0: if 'num_truncated_tokens' in inputs and inputs['num_truncated_tokens'] > 0:
logger.info('Attention! you are cropping tokens (swag task is ok). ' logger.info('Attention! you are cropping tokens (swag task is ok). '
......
...@@ -86,7 +86,6 @@ def glue_convert_examples_to_features(examples, tokenizer, ...@@ -86,7 +86,6 @@ def glue_convert_examples_to_features(examples, tokenizer,
example.text_b, example.text_b,
add_special_tokens=True, add_special_tokens=True,
max_length=max_length, max_length=max_length,
truncate_first_sequence=True # We're truncating the first sequence in priority
) )
input_ids, token_type_ids = inputs["input_ids"], inputs["token_type_ids"] input_ids, token_type_ids = inputs["input_ids"], inputs["token_type_ids"]
......
...@@ -249,10 +249,10 @@ class CommonTestCases: ...@@ -249,10 +249,10 @@ class CommonTestCases:
) )
information = tokenizer.encode_plus(seq_0, seq_1, max_length=len(sequence) - 2, add_special_tokens=True, information = tokenizer.encode_plus(seq_0, seq_1, max_length=len(sequence) - 2, add_special_tokens=True,
stride=stride, truncate_first_sequence=False) stride=stride, truncation_strategy='only_second')
information_first_truncated = tokenizer.encode_plus(seq_0, seq_1, max_length=len(sequence) - 2, information_first_truncated = tokenizer.encode_plus(seq_0, seq_1, max_length=len(sequence) - 2,
add_special_tokens=True, stride=stride, add_special_tokens=True, stride=stride,
truncate_first_sequence=True) truncation_strategy='only_first')
truncated_sequence = information["input_ids"] truncated_sequence = information["input_ids"]
overflowing_tokens = information["overflowing_tokens"] overflowing_tokens = information["overflowing_tokens"]
......
...@@ -692,8 +692,7 @@ class PreTrainedTokenizer(object): ...@@ -692,8 +692,7 @@ class PreTrainedTokenizer(object):
add_special_tokens=False, add_special_tokens=False,
max_length=None, max_length=None,
stride=0, stride=0,
truncate_first_sequence=True, truncation_strategy='longest_first',
truncate_both_sequences=False,
return_tensors=None, return_tensors=None,
**kwargs): **kwargs):
""" """
...@@ -714,8 +713,12 @@ class PreTrainedTokenizer(object): ...@@ -714,8 +713,12 @@ class PreTrainedTokenizer(object):
If there are overflowing tokens, those will be added to the returned dictionary If there are overflowing tokens, those will be added to the returned dictionary
stride: if set to a number along with max_length, the overflowing tokens returned will contain some tokens stride: if set to a number along with max_length, the overflowing tokens returned will contain some tokens
from the main sequence returned. The value of this argument defines the number of additional tokens. from the main sequence returned. The value of this argument defines the number of additional tokens.
truncate_first_sequence: if there is a specified max_length, this flag will choose which sequence truncation_strategy: string selected in the following options:
will be truncated. - 'longest_first' (default) Iteratively reduce the inputs sequence until the input is under max_length
starting from the longest one at each token (when there is a pair of input sequences)
- 'only_first': Only truncate the first sequence
- 'only_second': Only truncate the second sequence
- 'do_not_truncate': Does not truncate (raise an error if the input sequence is longer than max_length)
return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
or PyTorch torch.Tensor instead of a list of python integers. or PyTorch torch.Tensor instead of a list of python integers.
**kwargs: passed to the `self.tokenize()` method **kwargs: passed to the `self.tokenize()` method
...@@ -725,8 +728,7 @@ class PreTrainedTokenizer(object): ...@@ -725,8 +728,7 @@ class PreTrainedTokenizer(object):
max_length=max_length, max_length=max_length,
add_special_tokens=add_special_tokens, add_special_tokens=add_special_tokens,
stride=stride, stride=stride,
truncate_first_sequence=truncate_first_sequence, truncation_strategy=truncation_strategy,
truncate_both_sequences=truncate_both_sequences,
return_tensors=return_tensors, return_tensors=return_tensors,
**kwargs) **kwargs)
...@@ -738,8 +740,7 @@ class PreTrainedTokenizer(object): ...@@ -738,8 +740,7 @@ class PreTrainedTokenizer(object):
add_special_tokens=False, add_special_tokens=False,
max_length=None, max_length=None,
stride=0, stride=0,
truncate_first_sequence=True, truncation_strategy='longest_first',
truncate_both_sequences=False,
return_tensors=None, return_tensors=None,
**kwargs): **kwargs):
""" """
...@@ -759,8 +760,12 @@ class PreTrainedTokenizer(object): ...@@ -759,8 +760,12 @@ class PreTrainedTokenizer(object):
If there are overflowing tokens, those will be added to the returned dictionary If there are overflowing tokens, those will be added to the returned dictionary
stride: if set to a number along with max_length, the overflowing tokens returned will contain some tokens stride: if set to a number along with max_length, the overflowing tokens returned will contain some tokens
from the main sequence returned. The value of this argument defines the number of additional tokens. from the main sequence returned. The value of this argument defines the number of additional tokens.
truncate_first_sequence: if there is a specified max_length, this flag will choose which sequence truncation_strategy: string selected in the following options:
will be truncated. - 'longest_first' (default) Iteratively reduce the inputs sequence until the input is under max_length
starting from the longest one at each token (when there is a pair of input sequences)
- 'only_first': Only truncate the first sequence
- 'only_second': Only truncate the second sequence
- 'do_not_truncate': Does not truncate (raise an error if the input sequence is longer than max_length)
return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
or PyTorch torch.Tensor instead of a list of python integers. or PyTorch torch.Tensor instead of a list of python integers.
**kwargs: passed to the `self.tokenize()` method **kwargs: passed to the `self.tokenize()` method
...@@ -784,8 +789,7 @@ class PreTrainedTokenizer(object): ...@@ -784,8 +789,7 @@ class PreTrainedTokenizer(object):
max_length=max_length, max_length=max_length,
add_special_tokens=add_special_tokens, add_special_tokens=add_special_tokens,
stride=stride, stride=stride,
truncate_first_sequence=truncate_first_sequence, truncation_strategy=truncation_strategy,
truncate_both_sequences=truncate_both_sequences,
return_tensors=return_tensors) return_tensors=return_tensors)
def prepare_for_model(self, ids, pair_ids=None, max_length=None, add_special_tokens=False, stride=0, def prepare_for_model(self, ids, pair_ids=None, max_length=None, add_special_tokens=False, stride=0,
...@@ -812,9 +816,6 @@ class PreTrainedTokenizer(object): ...@@ -812,9 +816,6 @@ class PreTrainedTokenizer(object):
- 'only_first': Only truncate the first sequence - 'only_first': Only truncate the first sequence
- 'only_second': Only truncate the second sequence - 'only_second': Only truncate the second sequence
- 'do_not_truncate': Does not truncate (raise an error if the input sequence is longer than max_length) - 'do_not_truncate': Does not truncate (raise an error if the input sequence is longer than max_length)
truncate_first_sequence: if set to `True` and an optional second list of input ids is provided,
alongside a specified `max_length`, will truncate the first sequence if the total size is superior
than the specified `max_length`. If set to `False`, will truncate the second sequence instead.
return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
or PyTorch torch.Tensor instead of a list of python integers. or PyTorch torch.Tensor instead of a list of python integers.
...@@ -844,7 +845,8 @@ class PreTrainedTokenizer(object): ...@@ -844,7 +845,8 @@ class PreTrainedTokenizer(object):
if max_length and total_len > max_length: if max_length and total_len > max_length:
ids, pair_ids, overflowing_tokens = self.truncate_sequences(ids, pair_ids=pair_ids, ids, pair_ids, overflowing_tokens = self.truncate_sequences(ids, pair_ids=pair_ids,
num_tokens_to_remove=total_len-max_length, num_tokens_to_remove=total_len-max_length,
truncation_strategy=truncation_strategy) truncation_strategy=truncation_strategy,
stride=stride)
encoded_inputs["overflowing_tokens"] = overflowing_tokens encoded_inputs["overflowing_tokens"] = overflowing_tokens
encoded_inputs["num_truncated_tokens"] = total_len - max_length encoded_inputs["num_truncated_tokens"] = total_len - max_length
...@@ -875,7 +877,7 @@ class PreTrainedTokenizer(object): ...@@ -875,7 +877,7 @@ class PreTrainedTokenizer(object):
return encoded_inputs return encoded_inputs
def truncate_sequences(self, ids, pair_ids=None, num_tokens_to_remove=0, truncation_strategy='longest_first'): def truncate_sequences(self, ids, pair_ids=None, num_tokens_to_remove=0, truncation_strategy='longest_first', stride=0):
"""Truncates a sequence pair in place to the maximum length. """Truncates a sequence pair in place to the maximum length.
truncation_strategy: string selected in the following options: truncation_strategy: string selected in the following options:
- 'longest_first' (default) Iteratively reduce the inputs sequence until the input is under max_length - 'longest_first' (default) Iteratively reduce the inputs sequence until the input is under max_length
...@@ -892,17 +894,22 @@ class PreTrainedTokenizer(object): ...@@ -892,17 +894,22 @@ class PreTrainedTokenizer(object):
overflowing_tokens = [] overflowing_tokens = []
for _ in range(num_tokens_to_remove): for _ in range(num_tokens_to_remove):
if pair_ids is None or len(ids) > len(pair_ids): if pair_ids is None or len(ids) > len(pair_ids):
overflowing_tokens.append(ids[-1]) overflowing_tokens = [ids[-1]] + overflowing_tokens
ids = ids[:-1] ids = ids[:-1]
else: else:
pair_ids = pair_ids[:-1] pair_ids = pair_ids[:-1]
window_len = min(len(ids), stride)
if window_len > 0:
overflowing_tokens = ids[-window_len:] + overflowing_tokens
elif truncation_strategy == 'only_first': elif truncation_strategy == 'only_first':
assert len(ids) > num_tokens_to_remove assert len(ids) > num_tokens_to_remove
overflowing_tokens = ids[-num_tokens_to_remove:] window_len = min(len(ids), stride + num_tokens_to_remove)
overflowing_tokens = ids[-window_len:]
ids = ids[:-num_tokens_to_remove] ids = ids[:-num_tokens_to_remove]
elif truncation_strategy == 'only_second': elif truncation_strategy == 'only_second':
assert pair_ids is not None and len(pair_ids) > num_tokens_to_remove assert pair_ids is not None and len(pair_ids) > num_tokens_to_remove
overflowing_tokens = pair_ids[-num_tokens_to_remove:] window_len = min(len(pair_ids), stride + num_tokens_to_remove)
overflowing_tokens = pair_ids[-window_len:]
pair_ids = pair_ids[:-num_tokens_to_remove] pair_ids = pair_ids[:-num_tokens_to_remove]
elif truncation_strategy == 'do_not_truncate': elif truncation_strategy == 'do_not_truncate':
raise ValueError("Input sequence are too long for max_length. Please select a truncation strategy.") raise ValueError("Input sequence are too long for max_length. Please select a truncation strategy.")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment