"...git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "e85f029bd4e4b1bdf3e679893fb6447e4d6b2c79"
Commit 197d74f9 authored by Joe Davison

Add get_vocab method to PretrainedTokenizer

parent ea8eba35
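
For orientation before the diff, here is a minimal usage sketch of the method this commit adds. The checkpoint name and the probe token are illustrative assumptions, not part of the change; any concrete tokenizer works once it implements `get_vocab()`.

```python
from transformers import BertTokenizer

# Illustrative checkpoint (assumption); any tokenizer with get_vocab() would do.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# get_vocab() returns the whole vocabulary as a {token: index} dict,
# including any tokens added after loading.
vocab = tokenizer.get_vocab()

# For tokens already in the vocab, the dict agrees with convert_tokens_to_ids,
# and its size matches len(tokenizer).
assert vocab["hello"] == tokenizer.convert_tokens_to_ids("hello")  # assumes "hello" is a vocab entry
assert len(vocab) == len(tokenizer)
```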
@@ -114,6 +114,11 @@ class AlbertTokenizer(PreTrainedTokenizer):
     def vocab_size(self):
         return len(self.sp_model)

+    def get_vocab(self):
+        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
+        vocab.update(self.added_tokens_encoder)
+        return vocab
+
     def __getstate__(self):
         state = self.__dict__.copy()
         state["sp_model"] = None

@@ -195,6 +195,9 @@ class BertTokenizer(PreTrainedTokenizer):
     def vocab_size(self):
         return len(self.vocab)

+    def get_vocab(self):
+        return dict(self.vocab, **self.added_tokens_encoder)
+
     def _tokenize(self, text):
         split_tokens = []
         if self.do_basic_tokenize:

@@ -147,6 +147,9 @@ class CTRLTokenizer(PreTrainedTokenizer):
     def vocab_size(self):
         return len(self.encoder)

+    def get_vocab(self):
+        return dict(self.encoder, **self.added_tokens_encoder)
+
     def bpe(self, token):
         if token in self.cache:
             return self.cache[token]

@@ -149,6 +149,9 @@ class GPT2Tokenizer(PreTrainedTokenizer):
     def vocab_size(self):
         return len(self.encoder)

+    def get_vocab(self):
+        return dict(self.encoder, **self.added_tokens_encoder)
+
     def bpe(self, token):
         if token in self.cache:
             return self.cache[token]

@@ -125,6 +125,9 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer):
     def vocab_size(self):
         return len(self.encoder)

+    def get_vocab(self):
+        return dict(self.encoder, **self.added_tokens_encoder)
+
     def bpe(self, token):
         word = tuple(token[:-1]) + (token[-1] + "</w>",)
         if token in self.cache:

@@ -119,6 +119,11 @@ class T5Tokenizer(PreTrainedTokenizer):
     def vocab_size(self):
         return self.sp_model.get_piece_size() + self._extra_ids

+    def get_vocab(self):
+        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
+        vocab.update(self.added_tokens_encoder)
+        return vocab
+
     def __getstate__(self):
         state = self.__dict__.copy()
         state["sp_model"] = None

@@ -273,6 +273,9 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
     def vocab_size(self):
         return len(self.idx2sym)

+    def get_vocab(self):
+        return dict(self.sym2idx, **self.added_tokens_encoder)
+
     def _tokenize(self, line, add_eos=False, add_double_eos=False):
         line = line.strip()
         # convert to lower case

@@ -286,6 +286,10 @@ class PreTrainedTokenizer(object):
         """ Ids of all the additional special tokens in the vocabulary (list of integers). Log an error if used while not having been set. """
         return self.convert_tokens_to_ids(self.additional_special_tokens)

+    def get_vocab(self):
+        """ Returns the vocabulary as a dict of {token: index} pairs. `tokenizer.get_vocab()[token]` is equivalent to `tokenizer.convert_tokens_to_ids(token)` when `token` is in the vocab. """
+        raise NotImplementedError()
+
     def __init__(self, max_len=None, **kwargs):
         self._bos_token = None
         self._eos_token = None

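The qualifier in the new docstring ("when `token` is in the vocab") is worth spelling out: for an out-of-vocabulary string the two lookups diverge. A small sketch of the difference; the checkpoint and probe strings are assumptions for illustration:

```python
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")  # illustrative checkpoint

# In-vocab token: both lookups agree, as the docstring states.
assert tokenizer.get_vocab()["the"] == tokenizer.convert_tokens_to_ids("the")

# Out-of-vocab string: convert_tokens_to_ids falls back to the unk token id,
# while the get_vocab() dict simply has no such key.
oov = "definitely_not_a_vocab_entry"  # assumed absent from the vocabulary
assert oov not in tokenizer.get_vocab()
assert tokenizer.convert_tokens_to_ids(oov) == tokenizer.unk_token_id
```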
@@ -662,6 +662,9 @@ class XLMTokenizer(PreTrainedTokenizer):
     def vocab_size(self):
         return len(self.encoder)

+    def get_vocab(self):
+        return dict(self.encoder, **self.added_tokens_encoder)
+
     def bpe(self, token):
         word = tuple(token[:-1]) + (token[-1] + "</w>",)
         if token in self.cache:

@@ -190,6 +190,11 @@ class XLMRobertaTokenizer(PreTrainedTokenizer):
     def vocab_size(self):
         return len(self.sp_model) + len(self.fairseq_tokens_to_ids)

+    def get_vocab(self):
+        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
+        vocab.update(self.added_tokens_encoder)
+        return vocab
+
     def _tokenize(self, text):
         return self.sp_model.EncodeAsPieces(text)

@@ -114,6 +114,11 @@ class XLNetTokenizer(PreTrainedTokenizer):
     def vocab_size(self):
         return len(self.sp_model)

+    def get_vocab(self):
+        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
+        vocab.update(self.added_tokens_encoder)
+        return vocab
+
     def __getstate__(self):
         state = self.__dict__.copy()
         state["sp_model"] = None

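Across the tokenizer subclasses above, the implementations fall into two patterns. The helpers below distill them as standalone functions for comparison; they are a sketch, not code from the commit.

```python
# Pattern 1: dict-backed tokenizers (BERT, CTRL, GPT-2, OpenAI GPT, Transfo-XL, XLM)
# already hold a {token: id} mapping, so they simply merge in the runtime-added tokens.
def dict_backed_get_vocab(encoder, added_tokens_encoder):
    return dict(encoder, **added_tokens_encoder)

# Pattern 2: SentencePiece-backed tokenizers (ALBERT, T5, XLM-RoBERTa, XLNet) keep the
# vocabulary inside sp_model, so they rebuild a dict by converting every id back to its token.
def sentencepiece_backed_get_vocab(convert_ids_to_tokens, vocab_size, added_tokens_encoder):
    vocab = {convert_ids_to_tokens(i): i for i in range(vocab_size)}
    vocab.update(added_tokens_encoder)
    return vocab
```

In both patterns the added tokens are merged last, so an entry from `added_tokens_encoder` takes precedence over a base-vocabulary entry with the same key.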
@@ -542,3 +542,23 @@ class TokenizerTesterMixin:
         print(new_tokenizer.init_kwargs)
         assert tokenizer.init_kwargs["random_argument"] is True
         assert new_tokenizer.init_kwargs["random_argument"] is False
+
+    def test_get_vocab(self):
+        tokenizer = self.get_tokenizer()
+        vocab = tokenizer.get_vocab()
+
+        self.assertIsInstance(vocab, dict)
+        self.assertEqual(len(vocab), len(tokenizer))
+
+        for word, ind in vocab.items():
+            self.assertEqual(tokenizer.convert_tokens_to_ids(word), ind)
+            self.assertEqual(tokenizer.convert_ids_to_tokens(ind), word)
+
+        tokenizer.add_tokens(["asdfasdfasdfasdf"])
+        vocab = tokenizer.get_vocab()
+        self.assertIsInstance(vocab, dict)
+        self.assertEqual(len(vocab), len(tokenizer))
+
+        for word, ind in vocab.items():
+            self.assertEqual(tokenizer.convert_tokens_to_ids(word), ind)
+            self.assertEqual(tokenizer.convert_ids_to_tokens(ind), word)
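
Outside the test harness, the round-trip the new test checks looks roughly like this; the tokenizer class, checkpoint, and added string are assumptions for illustration:

```python
from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")  # illustrative checkpoint

before = len(tokenizer.get_vocab())
tokenizer.add_tokens(["my_new_token"])  # hypothetical token, assumed absent from the base vocab

# The added token is merged in via added_tokens_encoder, so it appears in the
# dict and the size stays in sync with len(tokenizer).
vocab = tokenizer.get_vocab()
assert vocab["my_new_token"] == tokenizer.convert_tokens_to_ids("my_new_token")
assert len(vocab) == before + 1 == len(tokenizer)
```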