Commit baa74326 authored by LysandreJik

Stride + tests + small fixes

parent c10c7d59
@@ -217,16 +217,18 @@ class CommonTestCases:
         tokenizer = self.get_tokenizer()
         seq_0 = "This is a sentence to be encoded."
+        stride = 2
         sequence = tokenizer.encode(seq_0)
         num_added_tokens = tokenizer.num_added_tokens()
         total_length = len(sequence) + num_added_tokens
-        information = tokenizer.encode_plus(seq_0, max_length=total_length - 2, add_special_tokens=True)
+        information = tokenizer.encode_plus(seq_0, max_length=total_length - 2, add_special_tokens=True, stride=stride)
         truncated_sequence = information["sequence"]
         overflowing_tokens = information["overflowing_tokens"]
-        assert len(overflowing_tokens) == 2
+        assert len(overflowing_tokens) == 2 + stride
+        assert overflowing_tokens == sequence[-(2 + stride):]
         assert len(truncated_sequence) == total_length - 2
         assert truncated_sequence == tokenizer.add_special_tokens_single_sequence(sequence[:-2])
......
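To see the arithmetic the new assertions rely on, here is a minimal standalone sketch of the same slicing with hypothetical ids (10 encoded ids and 2 special tokens are illustrative numbers, not values taken from the test):

    # Hypothetical ids, for illustration only.
    sequence = list(range(10))          # stand-in for tokenizer.encode(seq_0), no special tokens
    num_added_tokens = 2                # e.g. one [CLS] and one [SEP] for a BERT-style tokenizer
    stride = 2
    total_length = len(sequence) + num_added_tokens   # 12
    max_length = total_length - 2                     # 10

    # Mirrors the add_special_tokens=True branch of encode_plus:
    overflowing = sequence[max_length - num_added_tokens - stride:]   # [6, 7, 8, 9]
    kept = sequence[:max_length - num_added_tokens]                   # [0, 1, ..., 7]

    assert len(overflowing) == 2 + stride            # 2 truncated ids plus `stride` ids of overlap
    assert overflowing == sequence[-(2 + stride):]
    assert len(kept) + num_added_tokens == total_length - 2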
@@ -76,6 +76,5 @@ class DistilBertTokenizer(BertTokenizer):
         | first sequence | second sequence
         """
         sep = [self.sep_token_id]
-        cls = [self.cls_token_id]
         return len(self.encode(sequence_0) + sep) * [0] + len(self.encode(sequence_1)) * [1]
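For intuition, the mask built by that return statement looks as follows for hypothetical sequence lengths (4 and 3 ids with one separator; the numbers are illustrative only):

    # Illustrative only: token type ids for a pair of sequences.
    first_len, second_len = 4, 3                                  # hypothetical lengths of the encoded sequences
    token_type_ids = (first_len + 1) * [0] + second_len * [1]    # +1 accounts for the [SEP] id
    # -> [0, 0, 0, 0, 0, 1, 1, 1]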
@@ -722,7 +722,7 @@ class PreTrainedTokenizer(object):
             logger.warning("No special tokens were added. The two sequences have been concatenated.")
             return first_sentence_tokens + second_sentence_tokens

-    def encode_plus(self, text, text_pair=None, add_special_tokens=False, output_mask=False, max_length=None, **kwargs):
+    def encode_plus(self, text, text_pair=None, add_special_tokens=False, output_mask=False, max_length=None, stride=0, **kwargs):
         """
         Returns a dictionary containing the encoded sequence or sequence pair. Other values can be returned by this
         method: the mask for sequence classification and the overflowing elements if a ``max_length`` is specified.
@@ -735,6 +735,9 @@ class PreTrainedTokenizer(object):
             output_mask: if set to ``True``, returns the text pair corresponding mask with 0 for the first sequence,
                 and 1 for the second.
             max_length: if set to a number, will limit the total sequence returned so that it has a maximum length.
+                If there are overflowing tokens, those will be added to the returned dictionary.
+            stride: if set to a number along with max_length, the overflowing tokens returned will contain some tokens
+                from the main sequence. The value of this argument defines the number of additional tokens.
             **kwargs: passed to the `self.tokenize()` method
         """
@@ -745,13 +748,13 @@ class PreTrainedTokenizer(object):
         if add_special_tokens:
             sequence_tokens = self.convert_tokens_to_ids(self.tokenize(text, **kwargs))
             if max_length:
-                information["overflowing_tokens"] = sequence_tokens[max_length - n_added_tokens:]
+                information["overflowing_tokens"] = sequence_tokens[max_length - n_added_tokens - stride:]
                 sequence_tokens = sequence_tokens[:max_length - n_added_tokens]
             sequence = self.add_special_tokens_single_sequence(sequence_tokens)
         else:
             sequence_tokens = self.convert_tokens_to_ids(self.tokenize(text, **kwargs))
             if max_length:
-                information["overflowing_tokens"] = sequence_tokens[max_length:]
+                information["overflowing_tokens"] = sequence_tokens[max_length - stride:]
                 sequence_tokens = sequence_tokens[:max_length]
             sequence = sequence_tokens
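The offset by stride makes the overflow overlap the kept window, which is what a sliding-window consumer needs; a standalone sketch with hypothetical ids (not library code):

    # Hypothetical ids, no special tokens, for illustration only.
    sequence_tokens = list(range(12))
    max_length, stride = 8, 2

    kept = sequence_tokens[:max_length]                    # [0, 1, ..., 7]
    overflowing = sequence_tokens[max_length - stride:]    # [6, 7, 8, 9, 10, 11]

    # The overflow repeats the last `stride` kept ids ([6, 7]), so the next
    # window starts with some context from the previous one.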
@@ -788,7 +791,7 @@ class PreTrainedTokenizer(object):
             sequence = first_sentence_tokens + second_sentence_tokens

             if max_length:
-                information["overflowing_tokens"] = sequence[max_length:]
+                information["overflowing_tokens"] = sequence[max_length - stride:]
                 sequence = sequence[:max_length]

             if output_mask:
                 information["mask"] = [0] * len(sequence)
......