"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "187554366f3a2401f638c40482ccb3b3c8adadf4"
Commit 3df208c9 authored by LysandreJik's avatar LysandreJik
Browse files

Tokenizer accepts token list as well as string

parent 66ea76b8
...@@ -263,3 +263,10 @@ class CommonTestCases: ...@@ -263,3 +263,10 @@ class CommonTestCases:
assert overflowing_tokens_first_truncated == sequence_0_no_special_tokens[-(2 + stride):] assert overflowing_tokens_first_truncated == sequence_0_no_special_tokens[-(2 + stride):]
assert len(truncated_sequence) == len(sequence) - 2 assert len(truncated_sequence) == len(sequence) - 2
assert truncated_sequence == truncated_second_sequence assert truncated_sequence == truncated_second_sequence
def test_tokens_sent_to_encode(self):
    """Check that encode() accepts an already-encoded id list, not only a raw string.

    Encodes a string once, then feeds the resulting ids straight back into
    encode(); the call must complete without raising.
    """
    tok = self.get_tokenizer()
    text = "Let's encode this sequence"
    encoded_ids = tok.encode(text)
    # Round-trip: passing the id list back in exercises the non-str input path.
    tok.encode(encoded_ids, add_special_tokens=True)
...@@ -707,14 +707,14 @@ class PreTrainedTokenizer(object): ...@@ -707,14 +707,14 @@ class PreTrainedTokenizer(object):
""" """
if text_pair is None: if text_pair is None:
if add_special_tokens: if add_special_tokens:
sequence_tokens = self.convert_tokens_to_ids(self.tokenize(text, **kwargs)) sequence_tokens = self.convert_tokens_to_ids(self.tokenize(text, **kwargs)) if isinstance(text, str) else text
return self.add_special_tokens_single_sequence(sequence_tokens) return self.add_special_tokens_single_sequence(sequence_tokens)
else: else:
ids = self.convert_tokens_to_ids(self.tokenize(text, **kwargs)) ids = self.convert_tokens_to_ids(self.tokenize(text, **kwargs)) if isinstance(text, str) else text
return ids return ids
first_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text, **kwargs)] first_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text, **kwargs)] if isinstance(text, str) else text
second_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text_pair, **kwargs)] second_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text_pair, **kwargs)] if isinstance(text_pair, str) else text_pair
if add_special_tokens: if add_special_tokens:
return self.add_special_tokens_sequence_pair(first_sentence_tokens, second_sentence_tokens) return self.add_special_tokens_sequence_pair(first_sentence_tokens, second_sentence_tokens)
...@@ -754,7 +754,7 @@ class PreTrainedTokenizer(object): ...@@ -754,7 +754,7 @@ class PreTrainedTokenizer(object):
information = {} information = {}
if text_pair is None: if text_pair is None:
sequence_tokens = self.convert_tokens_to_ids(self.tokenize(text, **kwargs)) sequence_tokens = self.convert_tokens_to_ids(self.tokenize(text, **kwargs)) if isinstance(text, str) else text
if add_special_tokens: if add_special_tokens:
information = self.prepare_for_model(sequence_tokens, max_length, stride) information = self.prepare_for_model(sequence_tokens, max_length, stride)
else: else:
...@@ -766,8 +766,8 @@ class PreTrainedTokenizer(object): ...@@ -766,8 +766,8 @@ class PreTrainedTokenizer(object):
if output_mask: if output_mask:
information["mask"] = [0] * len(information["sequence"]) information["mask"] = [0] * len(information["sequence"])
else: else:
first_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text, **kwargs)] first_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text, **kwargs)] if isinstance(text, str) else text
second_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text_pair, **kwargs)] second_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text_pair, **kwargs)] if isinstance(text_pair, str) else text_pair
if add_special_tokens: if add_special_tokens:
information = self.prepare_pair_for_model( information = self.prepare_pair_for_model(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment