Commit baa74326 authored by LysandreJik

Stride + tests + small fixes

parent c10c7d59
@@ -217,16 +217,18 @@ class CommonTestCases:
         tokenizer = self.get_tokenizer()

         seq_0 = "This is a sentence to be encoded."
+        stride = 2

         sequence = tokenizer.encode(seq_0)
         num_added_tokens = tokenizer.num_added_tokens()
         total_length = len(sequence) + num_added_tokens
-        information = tokenizer.encode_plus(seq_0, max_length=total_length - 2, add_special_tokens=True)
+        information = tokenizer.encode_plus(seq_0, max_length=total_length - 2, add_special_tokens=True, stride=stride)

         truncated_sequence = information["sequence"]
         overflowing_tokens = information["overflowing_tokens"]

-        assert len(overflowing_tokens) == 2
+        assert len(overflowing_tokens) == 2 + stride
+        assert overflowing_tokens == sequence[-(2 + stride):]

         assert len(truncated_sequence) == total_length - 2
         assert truncated_sequence == tokenizer.add_special_tokens_single_sequence(sequence[:-2])
...
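A quick way to check the arithmetic behind the new assertions, using a toy token list in place of a real tokenizer (illustrative values only, not part of the commit):

# `sequence` stands in for tokenizer.encode(seq_0)
sequence = list(range(10))
stride = 2
cut = 2  # tokens trimmed off by max_length = total_length - 2

overflowing_tokens = sequence[-(cut + stride):]  # what the test expects
truncated = sequence[:-cut]                      # the kept main sequence

assert len(overflowing_tokens) == cut + stride
# the first `stride` overflow tokens repeat the tail of the kept sequence
assert overflowing_tokens[:stride] == truncated[-stride:]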
@@ -76,6 +76,5 @@ class DistilBertTokenizer(BertTokenizer):
             | first sequence    | second sequence
         """

         sep = [self.sep_token_id]
-        cls = [self.cls_token_id]
         return len(self.encode(sequence_0) + sep) * [0] + len(self.encode(sequence_1)) * [1]
@@ -722,7 +722,7 @@ class PreTrainedTokenizer(object):
            logger.warning("No special tokens were added. The two sequences have been concatenated.")
        return first_sentence_tokens + second_sentence_tokens

-    def encode_plus(self, text, text_pair=None, add_special_tokens=False, output_mask=False, max_length=None, **kwargs):
+    def encode_plus(self, text, text_pair=None, add_special_tokens=False, output_mask=False, max_length=None, stride=0, **kwargs):
        """
        Returns a dictionary containing the encoded sequence or sequence pair. Other values can be returned by this
        method: the mask for sequence classification and the overflowing elements if a ``max_length`` is specified.
@@ -735,6 +735,9 @@ class PreTrainedTokenizer(object):
            output_mask: if set to ``True``, returns the text pair corresponding mask with 0 for the first sequence,
                and 1 for the second.
            max_length: if set to a number, will limit the total sequence returned so that it has a maximum length.
+                If there are overflowing tokens, those will be added to the returned dictionary.
+            stride: if set to a number along with max_length, the overflowing tokens returned will contain some tokens
+                from the main sequence returned. The value of this argument defines the number of additional tokens.
            **kwargs: passed to the `self.tokenize()` method
        """
@@ -745,13 +748,13 @@ class PreTrainedTokenizer(object):
        if add_special_tokens:
            sequence_tokens = self.convert_tokens_to_ids(self.tokenize(text, **kwargs))
            if max_length:
-                information["overflowing_tokens"] = sequence_tokens[max_length - n_added_tokens:]
+                information["overflowing_tokens"] = sequence_tokens[max_length - n_added_tokens - stride:]
                sequence_tokens = sequence_tokens[:max_length - n_added_tokens]
            sequence = self.add_special_tokens_single_sequence(sequence_tokens)
        else:
            sequence_tokens = self.convert_tokens_to_ids(self.tokenize(text, **kwargs))
            if max_length:
-                information["overflowing_tokens"] = sequence_tokens[max_length:]
+                information["overflowing_tokens"] = sequence_tokens[max_length - stride:]
                sequence_tokens = sequence_tokens[:max_length]
            sequence = sequence_tokens
@@ -788,7 +791,7 @@ class PreTrainedTokenizer(object):
        sequence = first_sentence_tokens + second_sentence_tokens
        if max_length:
-            information["overflowing_tokens"] = sequence[max_length:]
+            information["overflowing_tokens"] = sequence[max_length - stride:]
            sequence = sequence[:max_length]

        if output_mask:
            information["mask"] = [0] * len(sequence)
...
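The `max_length - stride` slicing above means each overflow chunk begins with the last `stride` tokens of the sequence that was kept, which is what makes sliding-window processing of long inputs possible. A standalone sketch of that windowing idea (plain Python; a generalization of the commit's single-overflow behavior, not code from the commit):

def sliding_windows(tokens, max_length, stride):
    # Yield windows of up to `max_length` tokens; consecutive windows
    # overlap by `stride` tokens, mirroring the overflow slicing above.
    start = 0
    while start < len(tokens):
        yield tokens[start:start + max_length]
        if start + max_length >= len(tokens):
            break
        start += max_length - stride

print(list(sliding_windows(list(range(10)), max_length=4, stride=2)))
# [[0, 1, 2, 3], [2, 3, 4, 5], [4, 5, 6, 7], [6, 7, 8, 9]]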