Commit edcb56fd authored by thomwolf

more explicit variable name

parent 6bc082da
@@ -512,14 +512,14 @@ class BertPreTrainedModel(nn.Module):
             module.bias.data.zero_()
 
     @classmethod
-    def from_pretrained(cls, pretrained_model_name, state_dict=None, cache_dir=None,
+    def from_pretrained(cls, pretrained_model_name_or_path, state_dict=None, cache_dir=None,
                         from_tf=False, *inputs, **kwargs):
         """
         Instantiate a BertPreTrainedModel from a pre-trained model file or a pytorch state dict.
         Download and cache the pre-trained model file if needed.
 
         Params:
-            pretrained_model_name: either:
+            pretrained_model_name_or_path: either:
                 - a str with the name of a pre-trained model to load selected in the list of:
                     . `bert-base-uncased`
                     . `bert-large-uncased`
@@ -540,10 +540,10 @@ class BertPreTrainedModel(nn.Module):
             *inputs, **kwargs: additional input for the specific Bert class
                 (ex: num_labels for BertForSequenceClassification)
         """
-        if pretrained_model_name in PRETRAINED_MODEL_ARCHIVE_MAP:
-            archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name]
+        if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
+            archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
         else:
-            archive_file = pretrained_model_name
+            archive_file = pretrained_model_name_or_path
         # redirect to the cache, if necessary
         try:
             resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
@@ -552,7 +552,7 @@ class BertPreTrainedModel(nn.Module):
                 "Model name '{}' was not found in model name list ({}). "
                 "We assumed '{}' was a path or url but couldn't find any file "
                 "associated to this path or url.".format(
-                    pretrained_model_name,
+                    pretrained_model_name_or_path,
                     ', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()),
                     archive_file))
             return None
...
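These hunks only rename the argument; the resolution logic is unchanged: a key of PRETRAINED_MODEL_ARCHIVE_MAP is mapped to a download URL and cached, anything else is treated as a path or url. A minimal usage sketch, assuming the pytorch_pretrained_bert package layout of this repo; the local path below is hypothetical:

    from pytorch_pretrained_bert import BertModel

    # a shortcut name listed in PRETRAINED_MODEL_ARCHIVE_MAP resolves to a
    # download URL and is cached on first use
    model = BertModel.from_pretrained("bert-base-uncased")

    # anything else is assumed to be a path or url to a model archive
    # (hypothetical local path, shown only to illustrate the renamed argument)
    model = BertModel.from_pretrained("/path/to/my_finetuned_bert")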
@@ -418,14 +418,14 @@ class OpenAIGPTPreTrainedModel(nn.Module):
     @classmethod
     def from_pretrained(
-        cls, pretrained_model_name, num_special_tokens=None, state_dict=None, cache_dir=None, from_tf=False, *inputs, **kwargs
+        cls, pretrained_model_name_or_path, num_special_tokens=None, state_dict=None, cache_dir=None, from_tf=False, *inputs, **kwargs
     ):
         """
         Instantiate a OpenAIGPTPreTrainedModel from a pre-trained model file or a pytorch state dict.
         Download and cache the pre-trained model file if needed.
 
         Params:
-            pretrained_model_name: either:
+            pretrained_model_name_or_path: either:
                 - a str with the name of a pre-trained model to load selected in the list of:
                     . `openai-gpt`
                 - a path or url to a pretrained model archive containing:
@@ -440,11 +440,11 @@ class OpenAIGPTPreTrainedModel(nn.Module):
             *inputs, **kwargs: additional input for the specific Bert class
                 (ex: num_labels for BertForSequenceClassification)
         """
-        if pretrained_model_name in PRETRAINED_MODEL_ARCHIVE_MAP:
-            archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name]
+        if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
+            archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
             config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]
         else:
-            archive_file = pretrained_model_name
+            archive_file = pretrained_model_name_or_path
             config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
         # redirect to the cache, if necessary
         try:
@@ -455,7 +455,7 @@ class OpenAIGPTPreTrainedModel(nn.Module):
                 "Model name '{}' was not found in model name list ({}). "
                 "We assumed '{}' was a path or url but couldn't find files {} and {} "
                 "at this path or url.".format(
-                    pretrained_model_name, ", ".join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()), pretrained_model_name_or_path,
+                    pretrained_model_name_or_path, ", ".join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()), pretrained_model_name_or_path,
                     archive_file, config_file
                 )
             )
...
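The OpenAI GPT variant resolves two files instead of one: the weights archive and a config file, so a non-shortcut value is expected to be a directory containing both. A sketch under the same assumptions as above (OpenAIGPTModel as exported by this repo; the directory path is hypothetical):

    from pytorch_pretrained_bert import OpenAIGPTModel

    # shortcut name: archive_file and config_file both come from the archive maps
    model = OpenAIGPTModel.from_pretrained("openai-gpt")

    # path: the directory should contain the weights archive and the CONFIG_NAME file
    # (hypothetical directory)
    model = OpenAIGPTModel.from_pretrained("/path/to/my_gpt_checkpoint")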
@@ -116,15 +116,15 @@ class BertTokenizer(object):
         return tokens
 
     @classmethod
-    def from_pretrained(cls, pretrained_model_name, cache_dir=None, *inputs, **kwargs):
+    def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
         """
         Instantiate a PreTrainedBertModel from a pre-trained model file.
         Download and cache the pre-trained model file if needed.
         """
-        if pretrained_model_name in PRETRAINED_VOCAB_ARCHIVE_MAP:
-            vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name]
+        if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
+            vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
         else:
-            vocab_file = pretrained_model_name
+            vocab_file = pretrained_model_name_or_path
         if os.path.isdir(vocab_file):
             vocab_file = os.path.join(vocab_file, VOCAB_NAME)
         # redirect to the cache, if necessary
@@ -135,7 +135,7 @@ class BertTokenizer(object):
                 "Model name '{}' was not found in model name list ({}). "
                 "We assumed '{}' was a path or url but couldn't find any file "
                 "associated to this path or url.".format(
-                    pretrained_model_name,
+                    pretrained_model_name_or_path,
                     ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
                     vocab_file))
             return None
@@ -144,10 +144,10 @@ class BertTokenizer(object):
         else:
             logger.info("loading vocabulary file {} from cache at {}".format(
                 vocab_file, resolved_vocab_file))
-        if pretrained_model_name in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
+        if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
             # if we're using a pretrained model, ensure the tokenizer wont index sequences longer
             # than the number of positional embeddings
-            max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name]
+            max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
             kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
         # Instantiate tokenizer.
         tokenizer = cls(resolved_vocab_file, *inputs, **kwargs)
...
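Same rename on the tokenizer side. Note that the max_len cap is only applied when the argument is a known shortcut name, since only those appear in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP. A sketch; the directory path and the explicit max_len are illustrative:

    from pytorch_pretrained_bert import BertTokenizer

    # shortcut name: the vocabulary is downloaded and max_len is capped to the
    # model's positional embedding size
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

    # path: the directory is expected to contain the vocabulary file (VOCAB_NAME);
    # pass max_len explicitly, since no cap is applied for custom paths
    tokenizer = BertTokenizer.from_pretrained("/path/to/vocab_dir", max_len=512)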