Commit 66ea76b8 authored by LysandreJik


prepare_for_model and prepare_pair_for_model methods. Added an option to select which sequence will be truncated.
parent 60414f31
@@ -237,16 +237,29 @@ class CommonTestCases:
         seq_0 = "This is a sentence to be encoded."
         seq_1 = "This is another sentence to be encoded."
+        stride = 2
+        sequence_0_no_special_tokens = tokenizer.encode(seq_0)
+        sequence_1_no_special_tokens = tokenizer.encode(seq_1)
 
         sequence = tokenizer.encode(seq_0, seq_1, add_special_tokens=True)
         truncated_second_sequence = tokenizer.add_special_tokens_sequence_pair(
             tokenizer.encode(seq_0),
             tokenizer.encode(seq_1)[:-2]
         )
 
-        information = tokenizer.encode_plus(seq_0, seq_1, max_length=len(sequence) - 2, add_special_tokens=True)
+        information = tokenizer.encode_plus(seq_0, seq_1, max_length=len(sequence) - 2, add_special_tokens=True,
+                                            stride=stride)
+        information_first_truncated = tokenizer.encode_plus(seq_0, seq_1, max_length=len(sequence) - 2,
+                                                            add_special_tokens=True, stride=stride,
+                                                            truncate_second_sequence_first=False)
 
         truncated_sequence = information["sequence"]
         overflowing_tokens = information["overflowing_tokens"]
+        overflowing_tokens_first_truncated = information_first_truncated["overflowing_tokens"]
 
+        assert len(overflowing_tokens) == 2 + stride
+        assert overflowing_tokens == sequence_1_no_special_tokens[-(2 + stride):]
+        assert overflowing_tokens_first_truncated == sequence_0_no_special_tokens[-(2 + stride):]
 
         assert len(truncated_sequence) == len(sequence) - 2
         assert truncated_sequence == truncated_second_sequence
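
The new assertions above encode a small piece of arithmetic: when max_length removes two tokens and stride=2, the overflow window is the two removed ids plus two ids of context from the end of what was kept. Below is a minimal sketch of that arithmetic on plain Python lists, with made-up ids standing in for tokenizer.encode output; no tokenizer is involved, so it only models the behaviour the test asserts.

# Sketch of the stride/overflow arithmetic asserted in the test above.
# The ids are made up and only stand in for tokenizer.encode(seq_1) output.
second_sequence_ids = [10, 11, 12, 13, 14, 15]
num_truncated = 2                               # max_length is len(sequence) - 2
stride = 2

kept = second_sequence_ids[:len(second_sequence_ids) - num_truncated]
overflowing = second_sequence_ids[-(num_truncated + stride):]

assert kept == [10, 11, 12, 13]
assert overflowing == [12, 13, 14, 15]          # removed ids plus stride ids of context
assert len(overflowing) == num_truncated + stride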
@@ -722,7 +722,15 @@ class PreTrainedTokenizer(object):
             logger.warning("No special tokens were added. The two sequences have been concatenated.")
             return first_sentence_tokens + second_sentence_tokens
 
-    def encode_plus(self, text, text_pair=None, add_special_tokens=False, output_mask=False, max_length=None, stride=0, **kwargs):
+    def encode_plus(self,
+                    text,
+                    text_pair=None,
+                    add_special_tokens=False,
+                    output_mask=False,
+                    max_length=None,
+                    stride=0,
+                    truncate_second_sequence_first=True,
+                    **kwargs):
         """
         Returns a dictionary containing the encoded sequence or sequence pair. Other values can be returned by this
         method: the mask for sequence classification and the overflowing elements if a ``max_length`` is specified.
@@ -738,54 +746,40 @@ class PreTrainedTokenizer(object):
                 If there are overflowing tokens, those will be added to the returned dictionary
             stride: if set to a number along with max_length, the overflowing tokens returned will contain some tokens
                 from the main sequence returned. The value of this argument defined the number of additional tokens.
+            truncate_second_sequence_first: if there is a specified max_length, this flag will choose which sequence
+                will be truncated.
             **kwargs: passed to the `self.tokenize()` method
         """
         information = {}
 
         if text_pair is None:
-            n_added_tokens = self.num_added_tokens()
+            sequence_tokens = self.convert_tokens_to_ids(self.tokenize(text, **kwargs))
             if add_special_tokens:
-                sequence_tokens = self.convert_tokens_to_ids(self.tokenize(text, **kwargs))
-                if max_length:
-                    information["overflowing_tokens"] = sequence_tokens[max_length - n_added_tokens - stride:]
-                    sequence_tokens = sequence_tokens[:max_length - n_added_tokens]
-                sequence = self.add_special_tokens_single_sequence(sequence_tokens)
+                information = self.prepare_for_model(sequence_tokens, max_length, stride)
             else:
-                sequence_tokens = self.convert_tokens_to_ids(self.tokenize(text, **kwargs))
                 if max_length:
                     information["overflowing_tokens"] = sequence_tokens[max_length - stride:]
                     sequence_tokens = sequence_tokens[:max_length]
-                sequence = sequence_tokens
+                information["sequence"] = sequence_tokens
 
             if output_mask:
-                information["mask"] = [0] * len(sequence)
-            information["sequence"] = sequence
+                information["mask"] = [0] * len(information["sequence"])
         else:
             first_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text, **kwargs)]
             second_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text_pair, **kwargs)]
-            f_len, s_len = len(first_sentence_tokens), len(second_sentence_tokens)
-            n_added_tokens = self.num_added_tokens(pair=True)
 
             if add_special_tokens:
-                if max_length:
-                    if len(first_sentence_tokens) + n_added_tokens >= max_length:
-                        logger.warning("The first sequence is longer than the maximum specified length. This sequence will not be truncated.")
-                    else:
-                        if f_len + s_len + self.num_added_tokens(pair=True) > max_length:
-                            information["overflowing_tokens"] = second_sentence_tokens[max_length - f_len - n_added_tokens:]
-                            second_sentence_tokens = second_sentence_tokens[:max_length - f_len - n_added_tokens]
-                sequence = self.add_special_tokens_sequence_pair(
+                information = self.prepare_pair_for_model(
                     first_sentence_tokens,
-                    second_sentence_tokens
+                    second_sentence_tokens,
+                    max_length,
+                    truncate_second_sequence_first,
+                    stride
                 )
                 if output_mask:
                     information["mask"] = self.create_mask_from_sequences(text, text_pair)
-                information["sequence"] = sequence
             else:
                 logger.warning("No special tokens were added. The two sequences have been concatenated.")
                 sequence = first_sentence_tokens + second_sentence_tokens
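
Given the signature and docstring in the hunk above, a call using the new flag would look roughly like the sketch below. The tokenizer instance, the max_length value, and the wrapper function name are assumptions for illustration; the dictionary keys are the ones encode_plus populates in this commit.

# Hypothetical usage of the updated encode_plus signature; `tokenizer` is assumed
# to be any PreTrainedTokenizer subclass instance and max_length is arbitrary.
def encode_pair_example(tokenizer):
    outputs = tokenizer.encode_plus(
        "This is a sentence to be encoded.",
        "This is another sentence to be encoded.",
        add_special_tokens=True,
        output_mask=True,
        max_length=16,
        stride=2,
        truncate_second_sequence_first=False,  # truncate the first sequence instead of the second
    )
    token_ids = outputs["sequence"]               # ids with special tokens, capped at max_length
    segment_mask = outputs["mask"]                # 0 for first-sequence positions, 1 for second
    overflow = outputs.get("overflowing_tokens")  # only present when truncation occurred
    return token_ids, segment_mask, overflow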
@@ -800,6 +794,39 @@ class PreTrainedTokenizer(object):
 
         return information
 
+    def prepare_for_model(self, ids, max_length=None, stride=0):
+        information = {}
+        n_added_tokens = self.num_added_tokens()
+
+        if max_length:
+            information["overflowing_tokens"] = ids[max_length - n_added_tokens - stride:]
+            ids = ids[:max_length - n_added_tokens]
+
+        information["sequence"] = self.add_special_tokens_single_sequence(ids)
+
+        return information
+
+    def prepare_pair_for_model(self, ids_0, ids_1, max_length=None, truncate_second_sequence_first=True, stride=0):
+        f_len, s_len = len(ids_0), len(ids_1)
+        n_added_tokens = self.num_added_tokens(pair=True)
+        information = {}
+
+        if max_length:
+            if len(ids_0) + n_added_tokens >= max_length:
+                logger.warning(
+                    "The first sequence is longer than the maximum specified length. This sequence will not be truncated.")
+            else:
+                if f_len + s_len + self.num_added_tokens(pair=True) > max_length:
+                    if truncate_second_sequence_first:
+                        information["overflowing_tokens"] = ids_1[max_length - f_len - n_added_tokens - stride:]
+                        ids_1 = ids_1[:max_length - f_len - n_added_tokens]
+                    else:
+                        information["overflowing_tokens"] = ids_0[max_length - s_len - n_added_tokens - stride:]
+                        ids_0 = ids_0[:max_length - s_len - n_added_tokens]
+
+        sequence = self.add_special_tokens_sequence_pair(ids_0, ids_1)
+        information["sequence"] = sequence
+
+        return information
+
     def create_mask_from_sequences(self, sequence_0, sequence_1):
         logger.warning("This tokenizer does not make use of special tokens.")
         return [0] * len(self.encode(sequence_0)) + [1] * len(self.encode(sequence_1))
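
To make the side-selection logic of prepare_pair_for_model easier to follow, here is a standalone model of its truncation bookkeeping written against plain lists. The truncate_pair helper and n_added_tokens=3 (a BERT-style [CLS] ... [SEP] ... [SEP] layout) are assumptions for this sketch only, and the warning path for an over-long first sequence is left out.

# Standalone model of the truncation bookkeeping in prepare_pair_for_model above.
# `truncate_pair` is a hypothetical helper; n_added_tokens=3 assumes a BERT-style
# special-token layout. The over-long-first-sequence warning branch is omitted.
def truncate_pair(ids_0, ids_1, max_length, truncate_second_sequence_first=True,
                  stride=0, n_added_tokens=3):
    overflowing = []
    if len(ids_0) + len(ids_1) + n_added_tokens > max_length:
        if truncate_second_sequence_first:
            cut = max_length - len(ids_0) - n_added_tokens
            overflowing = ids_1[cut - stride:]
            ids_1 = ids_1[:cut]
        else:
            cut = max_length - len(ids_1) - n_added_tokens
            overflowing = ids_0[cut - stride:]
            ids_0 = ids_0[:cut]
    return ids_0, ids_1, overflowing


# 5 + 5 ids with 3 special tokens and max_length=11 means two ids must be dropped.
first, second, over = truncate_pair(list(range(5)), list(range(10, 15)), max_length=11,
                                    truncate_second_sequence_first=False, stride=2)
assert first == [0, 1, 2]              # the flag makes the first sequence absorb the cut
assert second == [10, 11, 12, 13, 14]
assert over == [1, 2, 3, 4]            # the two dropped ids plus stride=2 ids of context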