Commit d6f06c03 authored by thomwolf's avatar thomwolf
Browse files

fixed loading pre-trained tokenizer from directory

parent 532a81d3
...@@ -478,7 +478,7 @@ class PreTrainedBertModel(nn.Module): ...@@ -478,7 +478,7 @@ class PreTrainedBertModel(nn.Module):
"associated to this path or url.".format( "associated to this path or url.".format(
pretrained_model_name, pretrained_model_name,
', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()), ', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()),
pretrained_model_name)) archive_file))
return None return None
if resolved_archive_file == archive_file: if resolved_archive_file == archive_file:
logger.info("loading archive file {}".format(archive_file)) logger.info("loading archive file {}".format(archive_file))
......
...@@ -39,6 +39,7 @@ PRETRAINED_VOCAB_ARCHIVE_MAP = { ...@@ -39,6 +39,7 @@ PRETRAINED_VOCAB_ARCHIVE_MAP = {
'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt", 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt",
'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt", 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt",
} }
VOCAB_NAME = 'vocab.txt'
def load_vocab(vocab_file): def load_vocab(vocab_file):
...@@ -100,7 +101,7 @@ class BertTokenizer(object): ...@@ -100,7 +101,7 @@ class BertTokenizer(object):
return tokens return tokens
@classmethod @classmethod
def from_pretrained(cls, pretrained_model_name, do_lower_case=True): def from_pretrained(cls, pretrained_model_name, cache_dir=None, *inputs, **kwargs):
""" """
Instantiate a PreTrainedBertModel from a pre-trained model file. Instantiate a PreTrainedBertModel from a pre-trained model file.
Download and cache the pre-trained model file if needed. Download and cache the pre-trained model file if needed.
...@@ -109,16 +110,11 @@ class BertTokenizer(object): ...@@ -109,16 +110,11 @@ class BertTokenizer(object):
vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name] vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name]
else: else:
vocab_file = pretrained_model_name vocab_file = pretrained_model_name
if os.path.isdir(vocab_file):
vocab_file = os.path.join(vocab_file, VOCAB_NAME)
# redirect to the cache, if necessary # redirect to the cache, if necessary
try: try:
resolved_vocab_file = cached_path(vocab_file) resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
if resolved_vocab_file == vocab_file:
logger.info("loading vocabulary file {}".format(vocab_file))
else:
logger.info("loading vocabulary file {} from cache at {}".format(
vocab_file, resolved_vocab_file))
# Instantiate tokenizer.
tokenizer = cls(resolved_vocab_file, do_lower_case)
except FileNotFoundError: except FileNotFoundError:
logger.error( logger.error(
"Model name '{}' was not found in model name list ({}). " "Model name '{}' was not found in model name list ({}). "
...@@ -126,8 +122,15 @@ class BertTokenizer(object): ...@@ -126,8 +122,15 @@ class BertTokenizer(object):
"associated to this path or url.".format( "associated to this path or url.".format(
pretrained_model_name, pretrained_model_name,
', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()), ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
pretrained_model_name)) vocab_file))
tokenizer = None return None
if resolved_vocab_file == vocab_file:
logger.info("loading vocabulary file {}".format(vocab_file))
else:
logger.info("loading vocabulary file {} from cache at {}".format(
vocab_file, resolved_vocab_file))
# Instantiate tokenizer.
tokenizer = cls(resolved_vocab_file, *inputs, **kwargs)
return tokenizer return tokenizer
......
Markdown is supported
Attach a file by dragging & dropping or selecting it.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment