Commit 78ef1a99 authored by thomwolf's avatar thomwolf
Browse files

fixes

parent 6c1d0bc0
......@@ -336,7 +336,6 @@ def convert_examples_to_features(
text_b,
add_special_tokens=True,
max_length=max_length,
truncate_both_sequences=True
)
if 'num_truncated_tokens' in inputs and inputs['num_truncated_tokens'] > 0:
logger.info('Attention! you are cropping tokens (swag task is ok). '
......
......@@ -86,7 +86,6 @@ def glue_convert_examples_to_features(examples, tokenizer,
example.text_b,
add_special_tokens=True,
max_length=max_length,
truncate_first_sequence=True # We're truncating the first sequence in priority
)
input_ids, token_type_ids = inputs["input_ids"], inputs["token_type_ids"]
......
......@@ -249,10 +249,10 @@ class CommonTestCases:
)
information = tokenizer.encode_plus(seq_0, seq_1, max_length=len(sequence) - 2, add_special_tokens=True,
stride=stride, truncate_first_sequence=False)
stride=stride, truncation_strategy='only_second')
information_first_truncated = tokenizer.encode_plus(seq_0, seq_1, max_length=len(sequence) - 2,
add_special_tokens=True, stride=stride,
truncate_first_sequence=True)
truncation_strategy='only_first')
truncated_sequence = information["input_ids"]
overflowing_tokens = information["overflowing_tokens"]
......
......@@ -692,8 +692,7 @@ class PreTrainedTokenizer(object):
add_special_tokens=False,
max_length=None,
stride=0,
truncate_first_sequence=True,
truncate_both_sequences=False,
truncation_strategy='longest_first',
return_tensors=None,
**kwargs):
"""
......@@ -714,8 +713,12 @@ class PreTrainedTokenizer(object):
If there are overflowing tokens, those will be added to the returned dictionary
stride: if set to a number along with max_length, the overflowing tokens returned will contain some tokens
from the main sequence returned. The value of this argument defines the number of additional tokens.
truncate_first_sequence: if there is a specified max_length, this flag will choose which sequence
will be truncated.
truncation_strategy: string selected in the following options:
- 'longest_first' (default) Iteratively reduce the inputs sequence until the input is under max_length
starting from the longest one at each token (when there is a pair of input sequences)
- 'only_first': Only truncate the first sequence
- 'only_second': Only truncate the second sequence
- 'do_not_truncate': Does not truncate (raise an error if the input sequence is longer than max_length)
return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
or PyTorch torch.Tensor instead of a list of python integers.
**kwargs: passed to the `self.tokenize()` method
......@@ -725,8 +728,7 @@ class PreTrainedTokenizer(object):
max_length=max_length,
add_special_tokens=add_special_tokens,
stride=stride,
truncate_first_sequence=truncate_first_sequence,
truncate_both_sequences=truncate_both_sequences,
truncation_strategy=truncation_strategy,
return_tensors=return_tensors,
**kwargs)
......@@ -738,8 +740,7 @@ class PreTrainedTokenizer(object):
add_special_tokens=False,
max_length=None,
stride=0,
truncate_first_sequence=True,
truncate_both_sequences=False,
truncation_strategy='longest_first',
return_tensors=None,
**kwargs):
"""
......@@ -759,8 +760,12 @@ class PreTrainedTokenizer(object):
If there are overflowing tokens, those will be added to the returned dictionary
stride: if set to a number along with max_length, the overflowing tokens returned will contain some tokens
from the main sequence returned. The value of this argument defines the number of additional tokens.
truncate_first_sequence: if there is a specified max_length, this flag will choose which sequence
will be truncated.
truncation_strategy: string selected in the following options:
- 'longest_first' (default) Iteratively reduce the inputs sequence until the input is under max_length
starting from the longest one at each token (when there is a pair of input sequences)
- 'only_first': Only truncate the first sequence
- 'only_second': Only truncate the second sequence
- 'do_not_truncate': Does not truncate (raise an error if the input sequence is longer than max_length)
return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
or PyTorch torch.Tensor instead of a list of python integers.
**kwargs: passed to the `self.tokenize()` method
......@@ -784,8 +789,7 @@ class PreTrainedTokenizer(object):
max_length=max_length,
add_special_tokens=add_special_tokens,
stride=stride,
truncate_first_sequence=truncate_first_sequence,
truncate_both_sequences=truncate_both_sequences,
truncation_strategy=truncation_strategy,
return_tensors=return_tensors)
def prepare_for_model(self, ids, pair_ids=None, max_length=None, add_special_tokens=False, stride=0,
......@@ -812,9 +816,6 @@ class PreTrainedTokenizer(object):
- 'only_first': Only truncate the first sequence
- 'only_second': Only truncate the second sequence
- 'do_not_truncate': Does not truncate (raise an error if the input sequence is longer than max_length)
truncate_first_sequence: if set to `True` and an optional second list of input ids is provided,
alongside a specified `max_length`, will truncate the first sequence if the total size is superior
than the specified `max_length`. If set to `False`, will truncate the second sequence instead.
return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
or PyTorch torch.Tensor instead of a list of python integers.
......@@ -844,7 +845,8 @@ class PreTrainedTokenizer(object):
if max_length and total_len > max_length:
ids, pair_ids, overflowing_tokens = self.truncate_sequences(ids, pair_ids=pair_ids,
num_tokens_to_remove=total_len-max_length,
truncation_strategy=truncation_strategy)
truncation_strategy=truncation_strategy,
stride=stride)
encoded_inputs["overflowing_tokens"] = overflowing_tokens
encoded_inputs["num_truncated_tokens"] = total_len - max_length
......@@ -875,7 +877,7 @@ class PreTrainedTokenizer(object):
return encoded_inputs
def truncate_sequences(self, ids, pair_ids=None, num_tokens_to_remove=0, truncation_strategy='longest_first'):
def truncate_sequences(self, ids, pair_ids=None, num_tokens_to_remove=0, truncation_strategy='longest_first', stride=0):
"""Truncates a sequence pair in place to the maximum length.
truncation_strategy: string selected in the following options:
- 'longest_first' (default) Iteratively reduce the inputs sequence until the input is under max_length
......@@ -892,17 +894,22 @@ class PreTrainedTokenizer(object):
overflowing_tokens = []
for _ in range(num_tokens_to_remove):
if pair_ids is None or len(ids) > len(pair_ids):
overflowing_tokens.append(ids[-1])
overflowing_tokens = [ids[-1]] + overflowing_tokens
ids = ids[:-1]
else:
pair_ids = pair_ids[:-1]
window_len = min(len(ids), stride)
if window_len > 0:
overflowing_tokens = ids[-window_len:] + overflowing_tokens
elif truncation_strategy == 'only_first':
assert len(ids) > num_tokens_to_remove
overflowing_tokens = ids[-num_tokens_to_remove:]
window_len = min(len(ids), stride + num_tokens_to_remove)
overflowing_tokens = ids[-window_len:]
ids = ids[:-num_tokens_to_remove]
elif truncation_strategy == 'only_second':
assert pair_ids is not None and len(pair_ids) > num_tokens_to_remove
overflowing_tokens = pair_ids[-num_tokens_to_remove:]
window_len = min(len(pair_ids), stride + num_tokens_to_remove)
overflowing_tokens = pair_ids[-window_len:]
pair_ids = pair_ids[:-num_tokens_to_remove]
elif truncation_strategy == 'do_not_truncate':
raise ValueError("Input sequence are too long for max_length. Please select a truncation strategy.")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment