Unverified commit deceb001 authored by Thomas Wolf, committed by GitHub

Merge pull request #2177 from mandubian/issue-2106

:zip: #2106 tokenizer.tokenize speed improvement (3-8x) by caching added_tokens in a Set
parents ed9b8481 cc013513
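
The diff below replaces two membership tests per token (against the `added_tokens_encoder` dict and the `all_special_tokens` list, with the lookup list rebuilt on every `tokenize` call) by a single lookup in a set computed once when tokens are added. A minimal standalone sketch of the idea, not the library code; the variable names only mirror the tokenizer attributes:

# Illustration only: names mirror PreTrainedTokenizer attributes.
added_tokens_encoder = {"<new_tok>": 30522}        # user-added tokens -> ids
all_special_tokens = ["[CLS]", "[SEP]", "[PAD]"]   # model special tokens
token = "[CLS]"

# Before: lookup list rebuilt on every tokenize() call, each `in` is a linear scan.
added_tokens_list = list(added_tokens_encoder.keys()) + all_special_tokens
slow = token in added_tokens_list

# After: the union is cached as a set, each `in` is an O(1) hash lookup.
unique_added_tokens_encoder = set(added_tokens_encoder.keys()).union(set(all_special_tokens))
fast = token in unique_added_tokens_encoder

assert slow == fast  # same result, cheaper membership test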
@@ -231,6 +231,7 @@ class PreTrainedTokenizer(object):
         # Added tokens
         self.added_tokens_encoder = {}
+        self.unique_added_tokens_encoder = set()
         self.added_tokens_decoder = {}

         # inputs and kwargs for saving and re-loading (see ``from_pretrained`` and ``save_pretrained``)
@@ -554,6 +555,7 @@ class PreTrainedTokenizer(object):
         added_tok_encoder = dict((tok, len(self) + i) for i, tok in enumerate(to_add_tokens))
         added_tok_decoder = {v:k for k, v in added_tok_encoder.items()}
         self.added_tokens_encoder.update(added_tok_encoder)
+        self.unique_added_tokens_encoder = set(self.added_tokens_encoder.keys()).union(set(self.all_special_tokens))
         self.added_tokens_decoder.update(added_tok_decoder)

         return len(to_add_tokens)
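
With this hunk, the cached set is rebuilt inside `add_tokens`, so `tokenize` never has to recompute the union of added and special tokens. A hedged usage example (assumes a transformers release that includes this PR and network access to fetch the `bert-base-uncased` vocabulary):

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
num_added = tokenizer.add_tokens(["<new_tok>"])   # triggers the set rebuild above
print(num_added)                                  # expected: 1

# The cache now holds the added token plus every special token.
print("<new_tok>" in tokenizer.unique_added_tokens_encoder)   # expected: True
print(tokenizer.tokenize("hello <new_tok> world"))
# expected: ['hello', '<new_tok>', 'world'] -- the added token is kept whole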
@@ -631,6 +633,7 @@ class PreTrainedTokenizer(object):
         return added_tokens

     def tokenize(self, text, **kwargs):
         """ Converts a string in a sequence of tokens (string), using the tokenizer.
             Split in words for word-based vocabulary or sub-words for sub-word-based
@@ -685,18 +688,17 @@ class PreTrainedTokenizer(object):
             for tok in tok_list:
                 tokenized_text = []
                 for sub_text in text_list:
-                    if sub_text not in self.added_tokens_encoder \
-                            and sub_text not in all_special_tokens:
+                    if sub_text not in self.unique_added_tokens_encoder:
                         tokenized_text += split_on_token(tok, sub_text)
                     else:
                         tokenized_text += [sub_text]
                 text_list = tokenized_text

-            return list(itertools.chain.from_iterable((self._tokenize(token, **kwargs) if token not \
-                in self.added_tokens_encoder and token not in all_special_tokens \
-                else [token] for token in tokenized_text)))
+            return list(itertools.chain.from_iterable((self._tokenize(token, **kwargs) \
+                if token not in self.unique_added_tokens_encoder
+                else [token] for token in tokenized_text)))

-        added_tokens = list(self.added_tokens_encoder.keys()) + all_special_tokens
+        added_tokens = self.unique_added_tokens_encoder
         tokenized_text = split_on_tokens(added_tokens, text)
         return tokenized_text
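
The 3-8x figure in the PR title comes from `tokenize` no longer paying a list rebuild plus a linear scan per token on every call. A rough standalone timing sketch of just the membership-test difference (pure Python, no transformers dependency; the token counts are made up, and real gains depend on vocabulary size and input length):

import timeit

added = {"<tok%d>" % i: i for i in range(1000)}    # stand-in for added_tokens_encoder
specials = ["[CLS]", "[SEP]", "[PAD]", "[UNK]"]    # stand-in for all_special_tokens
tokens = ["hello", "world"] * 500                  # stand-in for pre-split text

def old_way():
    lookup = list(added.keys()) + specials         # rebuilt on every call
    return [t for t in tokens if t not in lookup]  # linear scan per token

cached = set(added.keys()).union(specials)         # built once, as in this PR

def new_way():
    return [t for t in tokens if t not in cached]  # hash lookup per token

print(timeit.timeit(old_way, number=100))
print(timeit.timeit(new_way, number=100))          # far smaller on typical machines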