Commit dcc9bb32 authored by LysandreJik

Modified encode to return only lists. Added a more complete encode_plus method

parent af23b626
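For context on what this commit changes, here is a minimal usage sketch of the two methods as they stand after this diff. The `pytorch_transformers` import path and the `bert-base-uncased` checkpoint are assumptions for illustration; only the return shapes and dict keys come from the code in the diff below.

    # Hedged sketch: import path and checkpoint name are assumed, not part of this commit.
    from pytorch_transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

    # encode() now always returns a plain list of token ids.
    ids = tokenizer.encode("This is a sequence", add_special_tokens=True)
    assert isinstance(ids, list)

    # encode_plus() returns a dict instead; per this diff it carries the encoded
    # "sequence", plus "mask" when output_mask=True, and "overflowing_tokens"
    # whenever max_length is set.
    info = tokenizer.encode_plus("This is a sequence", add_special_tokens=True,
                                 output_mask=True, max_length=8)
    print(info["sequence"], info["mask"])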
@@ -535,7 +535,7 @@ class PreTrainedTokenizer(object):
         """
         if pair:
-            initial_tokens_len = sum([len(encoded) for encoded in self.encode("This is a sequence", "This is another")])
+            initial_tokens_len = len(self.encode("This is a sequence") + self.encode("This is another"))
             final_tokens = self.encode("This is a sequence", "This is another", add_special_tokens=True)
             # In some models (e.g. GPT-2), there is no sequence pair encoding.
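The hunk above reflects the headline change: for a pair, `encode` no longer returns one encoded list per sequence, so the baseline length is now measured by encoding the two sequences separately and summing their lengths. The count of model-specific special tokens then presumably falls out as a difference, along these lines (a sketch; the surrounding method body is hidden in this diff):

    # Only the first two lines are verbatim from the hunk; the subtraction is an
    # assumption about the hidden remainder of num_added_tokens().
    initial_tokens_len = len(self.encode("This is a sequence") + self.encode("This is another"))
    final_tokens = self.encode("This is a sequence", "This is another", add_special_tokens=True)
    num_added = len(final_tokens) - initial_tokens_len  # e.g. 3 for a BERT pair: [CLS], [SEP], [SEP]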
@@ -693,10 +693,39 @@ class PreTrainedTokenizer(object):
     def _convert_token_to_id(self, token):
         raise NotImplementedError

-    def encode(self, text, text_pair=None, add_special_tokens=False, output_mask=False, max_length=None, **kwargs):
+    def encode(self, text, text_pair=None, add_special_tokens=False, **kwargs):
         """
         Converts a string in a sequence of ids (integer), using the tokenizer and vocabulary.
+        Same as doing ``self.convert_tokens_to_ids(self.tokenize(text))``.
+
+        Args:
+            text: The first sequence to be encoded.
+            text_pair: Optional second sequence to be encoded.
+            add_special_tokens: if set to ``True``, the sequences will be encoded with the special tokens relative
+                to their model.
+        """
+        if text_pair is None:
+            if add_special_tokens:
+                sequence_tokens = self.convert_tokens_to_ids(self.tokenize(text, **kwargs))
+                return self.add_special_tokens_single_sentence(sequence_tokens)
+            else:
+                ids = self.convert_tokens_to_ids(self.tokenize(text, **kwargs))
+                return ids
+
+        first_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text, **kwargs)]
+        second_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text_pair, **kwargs)]
+
+        if add_special_tokens:
+            return self.add_special_tokens_sentences_pair(first_sentence_tokens, second_sentence_tokens)
+        else:
+            logger.warning("No special tokens were added. The two sequences have been concatenated.")
+            return first_sentence_tokens + second_sentence_tokens
+
+    def encode_plus(self, text, text_pair=None, add_special_tokens=False, output_mask=False, max_length=None, **kwargs):
+        """
+        Converts a string in a sequence of ids (integer), using the tokenizer and vocabulary.
         Same as doing ``self.convert_tokens_to_ids(self.tokenize(text))``.

         Args:
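The reworked `encode` above keeps only the list-returning paths. Roughly, per branch (reusing the `tokenizer` from the sketch near the top; the variable names here are mine):

    # Single sequence: tokenize, convert to ids, optionally wrap with specials.
    ids = tokenizer.encode("first sequence", add_special_tokens=True)

    # Pair without add_special_tokens: the two id lists are concatenated and a
    # warning is logged ("No special tokens were added. ...").
    concat = tokenizer.encode("first sequence", "second sequence")

    # Pair with specials: delegated to add_special_tokens_sentences_pair(),
    # which each model's tokenizer subclass overrides
    # (e.g. [CLS] A [SEP] B [SEP] for BERT).
    pair = tokenizer.encode("first sequence", "second sequence", add_special_tokens=True)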
@@ -709,6 +738,69 @@ class PreTrainedTokenizer(object):
             max_length: if set to a number, will limit the total sequence returned so that it has a maximum length.
             **kwargs: passed to the `self.tokenize()` method
         """
+        information = {}
+
+        if text_pair is None:
+            n_added_tokens = self.num_added_tokens()
+            if add_special_tokens:
+                sequence_tokens = self.convert_tokens_to_ids(self.tokenize(text, **kwargs))
+                if max_length:
+                    information["overflowing_tokens"] = sequence_tokens[max_length - n_added_tokens:]
+                    sequence_tokens = sequence_tokens[:max_length - n_added_tokens]
+                sequence = self.add_special_tokens_single_sentence(sequence_tokens)
+            else:
+                sequence_tokens = self.convert_tokens_to_ids(self.tokenize(text, **kwargs))
+                if max_length:
+                    information["overflowing_tokens"] = sequence_tokens[max_length:]
+                    sequence_tokens = sequence_tokens[:max_length]
+                sequence = sequence_tokens
+
+            if output_mask:
+                information["mask"] = [0] * len(sequence)
+            information["sequence"] = sequence
+        else:
+            first_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text, **kwargs)]
+            second_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text_pair, **kwargs)]
+            f_len, s_len = len(first_sentence_tokens), len(second_sentence_tokens)
+            n_added_tokens = self.num_added_tokens(pair=True)
+
+            if add_special_tokens:
+                if max_length:
+                    if len(first_sentence_tokens) + n_added_tokens >= max_length:
+                        logger.warning("The first sequence is longer than the maximum specified length. This sequence will not be truncated.")
+                    else:
+                        if f_len + s_len + self.num_added_tokens(pair=True) > max_length:
+                            information["overflowing_tokens"] = second_sentence_tokens[max_length - f_len - n_added_tokens:]
+                            second_sentence_tokens = second_sentence_tokens[:max_length - f_len - n_added_tokens]
+                encoded_sequence = self.add_special_tokens_sentences_pair(
+                    first_sentence_tokens,
+                    second_sentence_tokens,
+                    output_mask
+                )
+                if output_mask:
+                    sequence, information["mask"] = encoded_sequence
+                else:
+                    sequence = encoded_sequence
+                information["sequence"] = sequence
+            else:
+                logger.warning("No special tokens were added. The two sequences have been concatenated.")
+                sequence = first_sentence_tokens + second_sentence_tokens
+                if max_length:
+                    information["overflowing_tokens"] = sequence[max_length:]
+                    sequence = sequence[:max_length]
+                if output_mask:
+                    information["mask"] = [0] * len(sequence)
+                information["sequence"] = sequence
+
+        return information
+
         if text_pair is None:
             if add_special_tokens:
                 sequence_tokens = self.convert_tokens_to_ids(self.tokenize(text, **kwargs))
@@ -725,12 +817,69 @@ class PreTrainedTokenizer(object):
         if add_special_tokens:
             if max_length:
                 if len(first_sentence_tokens) + self.num_added_tokens(pair=True) >= max_length:
-                    logger.warning("The first sequence is longer than the maximum specified length. This sequence will not be truncated.")
+                    logger.warning(
+                        "The first sequence is longer than the maximum specified length. This sequence will not be truncated.")
                 else:
-                    if len(second_sentence_tokens) + len(first_sentence_tokens) + self.num_added_tokens(pair=True) > max_length:
-                        second_sentence_tokens = second_sentence_tokens[:max_length - len(first_sentence_tokens) - self.num_added_tokens(pair=True)]
-            return self.add_special_tokens_sentences_pair(first_sentence_tokens, second_sentence_tokens, output_mask)
+                    if len(second_sentence_tokens) + len(first_sentence_tokens) + self.num_added_tokens(
+                            pair=True) > max_length:
+                        second_sentence_tokens = second_sentence_tokens[
+                            :max_length - len(first_sentence_tokens) - self.num_added_tokens(
+                                pair=True)]
+            return self.add_special_tokens_sentences_pair(first_sentence_tokens, second_sentence_tokens,
+                                                          output_mask)
         else:
             if max_length:
                 first_sentence_tokens = first_sentence_tokens[:max_length]
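The `max_length` handling in the new `encode_plus` reduces to a small truncation rule for pairs: the first sequence is never cut, and any ids dropped from the second are reported back as overflow. A standalone sketch of that rule (the function name and shape are mine, not the library's):

    def truncate_pair(first_ids, second_ids, n_added_tokens, max_length):
        # Mirrors the pair branch of encode_plus above: only the second
        # sequence is truncated, and the cut-off ids are returned as overflow.
        overflow = []
        if len(first_ids) + n_added_tokens >= max_length:
            # encode_plus only logs a warning here; nothing is truncated.
            print("warning: first sequence alone exceeds max_length")
        elif len(first_ids) + len(second_ids) + n_added_tokens > max_length:
            keep = max_length - len(first_ids) - n_added_tokens
            overflow = second_ids[keep:]
            second_ids = second_ids[:keep]
        return first_ids, second_ids, overflow

    # 4 first ids + 5 second ids + 3 special tokens against a budget of 10:
    # 10 - 4 - 3 = 3 second ids survive, 2 overflow.
    f, s, over = truncate_pair([1, 2, 3, 4], [5, 6, 7, 8, 9], 3, 10)
    assert s == [5, 6, 7] and over == [8, 9]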