Commit bfe93a5a authored by thomwolf's avatar thomwolf
Browse files

fix distilbert in auto tokenizer

parent 256086bc
...@@ -94,13 +94,13 @@ class AutoTokenizer(object): ...@@ -94,13 +94,13 @@ class AutoTokenizer(object):
Examples:: Examples::
config = AutoTokenizer.from_pretrained('bert-base-uncased') # Download vocabulary from S3 and cache. tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased') # Download vocabulary from S3 and cache.
config = AutoTokenizer.from_pretrained('./test/bert_saved_model/') # E.g. tokenizer was saved using `save_pretrained('./test/saved_model/')` tokenizer = AutoTokenizer.from_pretrained('./test/bert_saved_model/') # E.g. tokenizer was saved using `save_pretrained('./test/saved_model/')`
""" """
if 'distilbert' in pretrained_model_name_or_path: if 'distilbert' in pretrained_model_name_or_path:
return DistilBertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) return DistilBertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
if 'roberta' in pretrained_model_name_or_path: elif 'roberta' in pretrained_model_name_or_path:
return RobertaTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) return RobertaTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif 'bert' in pretrained_model_name_or_path: elif 'bert' in pretrained_model_name_or_path:
return BertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) return BertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment