Commit 27b0f86d authored by thomwolf's avatar thomwolf
Browse files

clean up pretrained

parent 57e54ec0
...@@ -152,11 +152,13 @@ class PreTrainedTokenizer(object): ...@@ -152,11 +152,13 @@ class PreTrainedTokenizer(object):
@classmethod @classmethod
def _from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs): def _from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
""" """
Instantiate a PreTrainedTokenizer from pre-trained vocabulary files. Instantiate a PreTrainedTokenizer from pre-trained vocabulary files.
Download and cache the vocabulary files if needed. Download and cache the vocabulary files if needed.
""" """
cache_dir = kwargs.pop('cache_dir', None)
s3_models = list(cls.max_model_input_sizes.keys()) s3_models = list(cls.max_model_input_sizes.keys())
vocab_files = {} vocab_files = {}
if pretrained_model_name_or_path in s3_models: if pretrained_model_name_or_path in s3_models:
...@@ -308,7 +310,8 @@ class PreTrainedTokenizer(object): ...@@ -308,7 +310,8 @@ class PreTrainedTokenizer(object):
to_add_tokens = [] to_add_tokens = []
for token in new_tokens: for token in new_tokens:
if self.convert_tokens_to_ids(token) == self.convert_tokens_to_ids(self.unk_token): if token != self.unk_token and \
self.convert_tokens_to_ids(token) == self.convert_tokens_to_ids(self.unk_token):
to_add_tokens.append(token) to_add_tokens.append(token)
logger.info("Adding %s to the vocabulary", token) logger.info("Adding %s to the vocabulary", token)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment