Commit 14e970c2 authored by LysandreJik

Tokenization encode/decode class-based sequence handling

parent fbd746bd
@@ -105,7 +105,7 @@ class CommonTestCases:
         self.assertEqual(added_toks, len(new_toks))
         self.assertEqual(all_size_2, all_size + len(new_toks))

-        tokens = tokenizer.encode("aaaaabbbbbb low cccccccccdddddddd l", no_sep_cls_tokens=True)
+        tokens = tokenizer.encode("aaaaabbbbbb low cccccccccdddddddd l")
         self.assertGreaterEqual(len(tokens), 4)
         self.assertGreater(tokens[0], tokenizer.vocab_size - 1)
         self.assertGreater(tokens[-2], tokenizer.vocab_size - 1)
@@ -121,8 +121,7 @@ class CommonTestCases:
         self.assertEqual(added_toks_2, len(new_toks_2))
         self.assertEqual(all_size_3, all_size_2 + len(new_toks_2))

-        tokens = tokenizer.encode(">>>>|||<||<<|<< aaaaabbbbbb low cccccccccdddddddd <<<<<|||>|>>>>|> l",
-                                  no_sep_cls_tokens=True)
+        tokens = tokenizer.encode(">>>>|||<||<<|<< aaaaabbbbbb low cccccccccdddddddd <<<<<|||>|>>>>|> l")
         self.assertGreaterEqual(len(tokens), 6)
         self.assertGreater(tokens[0], tokenizer.vocab_size - 1)
...
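A minimal sketch of what the updated assertions exercise, now that the `no_sep_cls_tokens` flag is gone: `encode` returns bare vocabulary ids by default, and ids of newly added tokens land above the original `vocab_size`. The `tokenizer` fixture and the `add_tokens` call are assumptions inferred from the surrounding (unshown) test code.

```python
# Hedged sketch of the updated test's expectations; `tokenizer` and the
# add_tokens call are assumed from the surrounding test, not shown in this hunk.
new_toks = ["aaaaabbbbbb", "cccccccccdddddddd"]
added = tokenizer.add_tokens(new_toks)  # assumed helper used by the common test cases
assert added == len(new_toks)

# No special-token flags anymore: encode returns plain ids by default.
tokens = tokenizer.encode("aaaaabbbbbb low cccccccccdddddddd l")
assert len(tokens) >= 4
# Added tokens are appended after the original vocabulary, so their ids
# exceed vocab_size - 1 at the positions the test checks.
assert tokens[0] > tokenizer.vocab_size - 1
assert tokens[-2] > tokenizer.vocab_size - 1
```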
@@ -166,6 +166,14 @@ class BertTokenizer(PreTrainedTokenizer):
         out_string = ' '.join(tokens).replace(' ##', '').strip()
         return out_string

+    def add_special_tokens_single_sentence(self, token_ids):
+        return [self._convert_token_to_id(self.cls_token)] + token_ids + [self._convert_token_to_id(self.sep_token)]
+
+    def add_special_tokens_sentences_pair(self, *token_ids):
+        sep = [self._convert_token_to_id(self.sep_token)]
+        cls = [self._convert_token_to_id(self.cls_token)]
+        return cls + token_ids[0] + sep + token_ids[1] + sep
+
     def save_vocabulary(self, vocab_path):
         """Save the tokenizer vocabulary to a directory or file."""
         index = 0
...
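A hedged usage sketch of the two new BertTokenizer hooks. The `pytorch_transformers` import name and the `bert-base-uncased` identifier are assumptions about the surrounding library at this commit, not part of the diff.

```python
# Sketch only: package name and checkpoint id are assumptions.
from pytorch_transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

ids_a = tokenizer.convert_tokens_to_ids(tokenizer.tokenize("hello world"))
ids_b = tokenizer.convert_tokens_to_ids(tokenizer.tokenize("how are you?"))

single = tokenizer.add_special_tokens_single_sentence(ids_a)
# Expected layout: [CLS] ids_a [SEP]
pair = tokenizer.add_special_tokens_sentences_pair(ids_a, ids_b)
# Expected layout: [CLS] ids_a [SEP] ids_b [SEP]

print(tokenizer.convert_ids_to_tokens(single))
print(tokenizer.convert_ids_to_tokens(pair))
```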
@@ -495,7 +495,7 @@ class PreTrainedTokenizer(object):
         """
         raise NotImplementedError

-    def convert_tokens_to_ids(self, tokens, **kwargs):
+    def convert_tokens_to_ids(self, tokens):
         """ Converts a single token, or a sequence of tokens, (str/unicode) in a single integer id
             (resp. a sequence of ids), using the vocabulary.
         """
@@ -519,31 +519,35 @@ class PreTrainedTokenizer(object):
     def _convert_token_to_id(self, token):
         raise NotImplementedError

-    def encode(self, *text, cls_token_at_end=False, double_sep_token=False, no_sep_cls_tokens=False):
+    def encode(self, text, add_special_tokens=False, *sequences):
         """ Converts a string in a sequence of ids (integer), using the tokenizer and vocabulary.
             Same doing ``self.convert_tokens_to_ids(self.tokenize(text))``.
         """
-        if len(text) == 1:
-            return self.convert_tokens_to_ids(self.tokenize(text[0]), no_sep_cls_tokens=no_sep_cls_tokens)
+        if len(sequences) == 0:
+            if add_special_tokens:
+                return self.add_special_tokens_single_sentence(self.convert_tokens_to_ids(self.tokenize(text)))
+            else:
+                return self.convert_tokens_to_ids(self.tokenize(text))

-        if len(text) > 2:
+        if len(sequences) > 1:
             logger.warning("Tokenization currently only supports sentence pairs. Ignoring every string following the "
                            "initial two.")

-        first_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text[0])]
-        second_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text[1])]
+        first_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text)]
+        second_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(sequences[0])]

-        sep = [self._convert_token_to_id(self.sep_token)]
-        cls = [self._convert_token_to_id(self.cls_token)]
-        n_sep_token = 2 if double_sep_token else 1
-
-        tokens = first_sentence_tokens + sep * n_sep_token + second_sentence_tokens + sep
-        tokens = (tokens + cls) if cls_token_at_end else (cls + tokens)
-
-        return tokens
+        if add_special_tokens:
+            return self.add_special_tokens_sentences_pair(first_sentence_tokens, second_sentence_tokens)
+        else:
+            return first_sentence_tokens, second_sentence_tokens
+
+    def add_special_tokens_single_sentence(self, token_ids):
+        raise NotImplementedError
+
+    def add_special_tokens_sentences_pair(self, *token_ids):
+        raise NotImplementedError

     def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
         """ Converts a single index or a sequence of indices (integers) in a token "
@@ -577,8 +581,7 @@ class PreTrainedTokenizer(object):
         """
         return ' '.join(self.convert_ids_to_tokens(tokens))

-    def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True, cls_token_at_end=False,
-               double_sep_token=False):
+    def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True):
         """ Converts a sequence of ids (integer) in a string, using the tokenizer and vocabulary
             with options to remove special tokens and clean up tokenization spaces.
...
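A short sketch of the new class-based dispatch in `PreTrainedTokenizer.encode`, assuming a concrete subclass (for example a BERT tokenizer) is already instantiated as `tokenizer`; the base class itself leaves both `add_special_tokens_*` hooks as `NotImplementedError`.

```python
# Single sequence: bare ids by default, model-specific special tokens on request.
ids_plain = tokenizer.encode("The quick brown fox")
ids_marked = tokenizer.encode("The quick brown fox", add_special_tokens=True)
# ids_marked goes through tokenizer.add_special_tokens_single_sentence(...)

# Sentence pairs: the per-model layout comes from add_special_tokens_sentences_pair.
a = tokenizer.convert_tokens_to_ids(tokenizer.tokenize("First sentence."))
b = tokenizer.convert_tokens_to_ids(tokenizer.tokenize("Second sentence."))
pair_ids = tokenizer.add_special_tokens_sentences_pair(a, b)
```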
@@ -214,6 +214,14 @@ class XLMTokenizer(PreTrainedTokenizer):
         out_string = ''.join(tokens).replace('</w>', ' ').strip()
         return out_string

+    def add_special_tokens_single_sentence(self, token_ids):
+        return [self._convert_token_to_id(self.cls_token)] + token_ids + [self._convert_token_to_id(self.sep_token)]
+
+    def add_special_tokens_sentences_pair(self, *token_ids):
+        sep = [self._convert_token_to_id(self.sep_token)]
+        cls = [self._convert_token_to_id(self.cls_token)]
+        return cls + token_ids[0] + sep + token_ids[1] + sep
+
     def save_vocabulary(self, save_directory):
         """Save the tokenizer vocabulary and merge files to a directory."""
         if not os.path.isdir(save_directory):
...
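XLM adopts the same layout as BERT, so a quick hedged check of the pair hook reduces to comparing against the cls/sep ids. An already-loaded `XLMTokenizer` instance named `tokenizer` is assumed.

```python
# Assumes `tokenizer` is an XLMTokenizer instance with its usual cls/sep tokens.
a = tokenizer.convert_tokens_to_ids(tokenizer.tokenize("first segment"))
b = tokenizer.convert_tokens_to_ids(tokenizer.tokenize("second segment"))

cls_id = tokenizer.convert_tokens_to_ids(tokenizer.cls_token)
sep_id = tokenizer.convert_tokens_to_ids(tokenizer.sep_token)

# Same sentence-pair layout as BertTokenizer: cls + A + sep + B + sep.
assert tokenizer.add_special_tokens_sentences_pair(a, b) == [cls_id] + a + [sep_id] + b + [sep_id]
```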
@@ -177,6 +177,16 @@ class XLNetTokenizer(PreTrainedTokenizer):
         out_string = ''.join(tokens).replace(SPIECE_UNDERLINE, ' ').strip()
         return out_string

+    def add_special_tokens_single_sentence(self, token_ids):
+        logger.warning("No method was defined for special tokens and single sentence streams in XLNet. "
+                       "Returning token_ids")
+        return token_ids
+
+    def add_special_tokens_sentences_pair(self, *token_ids):
+        sep = [self._convert_token_to_id(self.sep_token)]
+        cls = [self._convert_token_to_id(self.cls_token)]
+        return token_ids[0] + sep + token_ids[1] + sep + cls
+
     def save_vocabulary(self, save_directory):
         """ Save the sentencepiece vocabulary (copy original file) and special tokens file
             to a directory.
...
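For XLNet the special tokens trail the pair rather than lead it, and single sequences are passed through with only a warning. A hedged sketch follows; the `pytorch_transformers` import name and the `xlnet-base-cased` identifier are assumptions, not part of the diff.

```python
# Sketch only: package name and checkpoint id are assumptions.
from pytorch_transformers import XLNetTokenizer

tokenizer = XLNetTokenizer.from_pretrained("xlnet-base-cased")

a = tokenizer.convert_tokens_to_ids(tokenizer.tokenize("First segment."))
b = tokenizer.convert_tokens_to_ids(tokenizer.tokenize("Second segment."))

# Pair layout: A [SEP] B [SEP] [CLS] -- special tokens come last,
# unlike BERT/XLM which lead with [CLS].
pair = tokenizer.add_special_tokens_sentences_pair(a, b)

# Single sequences are returned unchanged; the method only logs a warning.
assert tokenizer.add_special_tokens_single_sentence(a) == a
```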