"test/vscode:/vscode.git/clone" did not exist on "20b8d2306c3d9501b4d47399b6e56d63d30d53a3"
Commit 3d87991f authored by LysandreJik

Fixed error with encoding

parent 634a3172
@@ -81,11 +81,14 @@ class RobertaTokenizationTest(CommonTestCases.CommonTokenizerTester):
         text = tokenizer.encode("sequence builders")
         text_2 = tokenizer.encode("multi-sequence build")
 
+        encoded_text_from_decode = tokenizer.encode("sequence builders", add_special_tokens=True)
+        encoded_pair_from_decode = tokenizer.encode("sequence builders", "multi-sequence build", add_special_tokens=True)
+
         encoded_sentence = tokenizer.add_special_tokens_single_sentence(text)
         encoded_pair = tokenizer.add_special_tokens_sentences_pair(text, text_2)
 
-        assert encoded_sentence == [0] + text + [2]
-        assert encoded_pair == [0] + text + [2, 2] + text_2 + [2]
+        assert encoded_sentence == encoded_text_from_decode
+        assert encoded_pair == encoded_pair_from_decode
 
 if __name__ == '__main__':
...
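Note on the test change: the old assertions hard-coded RoBERTa's special-token ids (0 for <s>, 2 for </s>), so they only verified one vocabulary's layout. The new assertions instead compare against encode(..., add_special_tokens=True), checking that the two code paths agree. A minimal sketch of the equivalence being asserted (illustrative only; the tokenizer setup is assumed, and the 0/2 ids are specific to RoBERTa's vocabulary):

    text = tokenizer.encode("sequence builders")            # token ids, no special tokens
    encoded = tokenizer.add_special_tokens_single_sentence(text)
    # After this commit, equivalent to a single encode() call:
    assert encoded == tokenizer.encode("sequence builders", add_special_tokens=True)
    # For RoBERTa specifically this equals [0] + text + [2], i.e. <s> ... </s>.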
@@ -519,24 +519,19 @@ class PreTrainedTokenizer(object):
     def _convert_token_to_id(self, token):
         raise NotImplementedError
 
-    def encode(self, text, add_special_tokens=False, *sequences):
+    def encode(self, text, text_pair=None, add_special_tokens=False):
         """ Converts a string in a sequence of ids (integer), using the tokenizer and vocabulary.
 
             Same doing ``self.convert_tokens_to_ids(self.tokenize(text))``.
         """
-        if len(sequences) == 0:
+        if text_pair is None:
             if add_special_tokens:
                 return self.add_special_tokens_single_sentence(self.convert_tokens_to_ids(self.tokenize(text)))
             else:
                 return self.convert_tokens_to_ids(self.tokenize(text))
 
-        if len(sequences) > 1:
-            logger.warning("Tokenization currently only supports sentence pairs. Ignoring every string following the "
-                           "initial two.")
-
         first_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text)]
-        second_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(sequences[0])]
+        second_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text_pair)]
 
         if add_special_tokens:
            return self.add_special_tokens_sentences_pair(first_sentence_tokens, second_sentence_tokens)
...
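For context on the fix itself: with the old signature def encode(self, text, add_special_tokens=False, *sequences), a second sentence passed positionally, as in encode("a", "b"), was bound to add_special_tokens rather than being treated as the paired sequence, so the call silently encoded a single sentence with special tokens. The new text_pair keyword removes that trap. A minimal usage sketch of the fixed signature (illustrative, not part of the commit; assumes pytorch-transformers is installed and the roberta-base checkpoint can be downloaded):

    from pytorch_transformers import RobertaTokenizer

    tokenizer = RobertaTokenizer.from_pretrained("roberta-base")

    # Single sequence: plain token ids, no special tokens.
    ids = tokenizer.encode("sequence builders")

    # Single sequence wrapped in special tokens.
    ids_special = tokenizer.encode("sequence builders", add_special_tokens=True)

    # Sentence pair: the second string now binds to text_pair, not to
    # add_special_tokens as it would have under the old *sequences signature.
    pair_ids = tokenizer.encode("sequence builders", "multi-sequence build",
                                add_special_tokens=True)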