Unverified commit 3763f894, authored by Thomas Wolf, committed by GitHub

Merge pull request #696 from huggingface/split_config_weights

Split config weights
parents a6f25118 f9647530
@@ -516,7 +516,9 @@ Here is a detailed documentation of the classes in the package and how to use them
### Loading Google AI or OpenAI pre-trained weights or PyTorch dump

### `from_pretrained()` method

To load one of Google AI's or OpenAI's pre-trained models, or a PyTorch model saved with `torch.save()` (e.g. an instance of `BertForPreTraining`), the PyTorch model classes and the tokenizer can be instantiated using the `from_pretrained()` method:

```python
model = BERT_CLASS.from_pretrained(PRE_TRAINED_MODEL_NAME_OR_PATH, cache_dir=None, from_tf=False, state_dict=None, *input, **kwargs)
```
@@ -581,6 +583,22 @@

```python
model = GPT2Model.from_pretrained('gpt2')
```
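For instance, a minimal sketch of this call (using shortcut names that appear in the archive maps further below; the files are downloaded to the cache directory described in the next section):

```python
from pytorch_pretrained_bert import BertModel, BertTokenizer

# Instantiate the tokenizer and the model from a shortcut name;
# weights and configuration are downloaded and cached on first use.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')
```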
#### Cache directory

`pytorch_pretrained_bert` saves the pretrained weights in a cache directory which is located at (in this order of priority):

- the `cache_dir` optional argument of the `from_pretrained()` method (see above),
- the shell environment variable `PYTORCH_PRETRAINED_BERT_CACHE`,
- the PyTorch cache home + `/pytorch_pretrained_bert/`,

where the PyTorch cache home is defined by (in this order):

- the shell environment variable `TORCH_HOME`,
- the shell environment variable `XDG_CACHE_HOME` + `/torch/`,
- the default: `~/.cache/torch/`.

Usually, if you don't set any specific environment variable, the `pytorch_pretrained_bert` cache will be at `~/.cache/torch/pytorch_pretrained_bert/`.

You can always safely delete the `pytorch_pretrained_bert` cache, but the pretrained model weights and vocabulary files will then have to be re-downloaded from our S3.
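As a sketch of this resolution order (the helper below is ours, not part of the library; it mirrors the logic in `file_utils.py`):

```python
import os

def resolve_cache_dir(cache_dir=None):
    # 1. an explicit cache_dir argument wins
    if cache_dir is not None:
        return cache_dir
    # 2. the dedicated environment variable
    if 'PYTORCH_PRETRAINED_BERT_CACHE' in os.environ:
        return os.environ['PYTORCH_PRETRAINED_BERT_CACHE']
    # 3. PyTorch cache home: TORCH_HOME, else XDG_CACHE_HOME/torch, else ~/.cache/torch
    torch_home = os.path.expanduser(
        os.getenv('TORCH_HOME',
                  os.path.join(os.getenv('XDG_CACHE_HOME', '~/.cache'), 'torch')))
    return os.path.join(torch_home, 'pytorch_pretrained_bert')

print(resolve_cache_dir())  # e.g. /home/user/.cache/torch/pytorch_pretrained_bert
```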
### Serialization best-practices

This section explains how you can save and re-load a fine-tuned model (BERT, GPT, GPT-2 and Transformer-XL).
@@ -590,6 +608,13 @@

There are three types of files you need to save to be able to reload a fine-tuned model:

- the model weights themselves, saved with `torch.save()`,
- the configuration file of the model, which is saved as a JSON file, and
- the vocabulary (and the merges for the BPE-based models GPT and GPT-2).
The default names of these files are as follows:

- the model weights file: `pytorch_model.bin`,
- the configuration file: `config.json`,
- the vocabulary file: `vocab.txt` for BERT and Transformer-XL, `vocab.json` for GPT/GPT-2 (BPE vocabulary),
- the additional merges file for GPT/GPT-2 (BPE vocabulary): `merges.txt`.
Here is the recommended way of saving the model, configuration and vocabulary to an `output_dir` directory and reloading the model and tokenizer afterwards:

```python
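# Hedged sketch: the diff view elides the body of this example here.
# The save/re-load pattern of this era looks roughly like the following;
# `BertForQuestionAnswering` and `output_dir` are illustrative, not from the diff.
model_to_save = model.module if hasattr(model, 'module') else model  # unwrap DataParallel

output_model_file = os.path.join(output_dir, "pytorch_model.bin")
output_config_file = os.path.join(output_dir, "config.json")

# Save the weights, the configuration and the vocabulary as three separate files
torch.save(model_to_save.state_dict(), output_model_file)
with open(output_config_file, 'w') as f:
    f.write(model_to_save.config.to_json_string())
tokenizer.save_vocabulary(output_dir)

# Re-load the fine-tuned model and vocabulary afterwards
model = BertForQuestionAnswering.from_pretrained(output_dir)
tokenizer = BertTokenizer.from_pretrained(output_dir)
```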
@@ -1432,6 +1457,25 @@

The results were similar to the above FP32 results (actually slightly higher):

```
{"exact_match": 84.65468306527909, "f1": 91.238669287002}
```
Here is an example with the recent `bert-large-uncased-whole-word-masking`:
```bash
python -m torch.distributed.launch --nproc_per_node=8 \
run_squad.py \
--bert_model bert-large-uncased-whole-word-masking \
--do_train \
--do_predict \
--do_lower_case \
--train_file $SQUAD_DIR/train-v1.1.json \
--predict_file $SQUAD_DIR/dev-v1.1.json \
--train_batch_size 12 \
--learning_rate 3e-5 \
--num_train_epochs 2.0 \
--max_seq_length 384 \
--doc_stride 128 \
--output_dir /tmp/debug_squad/
```
## Notebooks

We include [three Jupyter Notebooks](https://github.com/huggingface/pytorch-pretrained-BERT/tree/master/notebooks) that can be used to check that the predictions of the PyTorch model are identical to the predictions of the original TensorFlow model.
...
#!/usr/bin/env python3
# GPT-2 text generation example: interactive conditional or unconditional sampling.
import argparse
import logging

import numpy as np
import torch
import torch.nn.functional as F
from tqdm import trange

from pytorch_pretrained_bert import GPT2LMHeadModel, GPT2Tokenizer

logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                    datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.INFO)
logger = logging.getLogger(__name__)


def top_k_logits(logits, k):
    """Mask all logits outside the top k of each row with a large negative value."""
    if k == 0:
        return logits
    values, _ = torch.topk(logits, k)
    min_values = values[:, -1].unsqueeze(1)  # per-row threshold, broadcast over the vocabulary
    return torch.where(logits < min_values,
                       torch.ones_like(logits) * -1e10,
                       logits)


def sample_sequence(model, length, start_token=None, batch_size=None, context=None,
                    temperature=1.0, top_k=0, device='cuda', sample=True):
    """Autoregressively sample `length` tokens from the model."""
    if start_token is None:
        assert context is not None, 'Specify exactly one of start_token and context!'
        context = torch.tensor(context, device=device, dtype=torch.long).unsqueeze(0).repeat(batch_size, 1)
    else:
        assert context is None, 'Specify exactly one of start_token and context!'
        context = torch.full((batch_size, 1), start_token, device=device, dtype=torch.long)
    prev = context
    output = context
    past = None
    with torch.no_grad():
        for _ in trange(length):
            logits, past = model(prev, past=past)       # reuse cached attention keys/values
            logits = logits[:, -1, :] / temperature
            logits = top_k_logits(logits, k=top_k)
            log_probs = F.softmax(logits, dim=-1)
            if sample:
                prev = torch.multinomial(log_probs, num_samples=1)
            else:
                _, prev = torch.topk(log_probs, k=1, dim=-1)
            output = torch.cat((output, prev), dim=1)
    return output


def run_model():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name_or_path', type=str, default='gpt2',
                        help='pretrained model name or path to local checkpoint')
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--nsamples", type=int, default=1)
    parser.add_argument("--batch_size", type=int, default=-1)
    parser.add_argument("--length", type=int, default=-1)
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--top_k", type=int, default=0)
    parser.add_argument('--unconditional', action='store_true', help='If true, unconditional generation.')
    args = parser.parse_args()
    print(args)

    if args.batch_size == -1:
        args.batch_size = 1
    assert args.nsamples % args.batch_size == 0

    np.random.seed(args.seed)
    torch.random.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    enc = GPT2Tokenizer.from_pretrained(args.model_name_or_path)
    model = GPT2LMHeadModel.from_pretrained(args.model_name_or_path)
    model.to(device)
    model.eval()

    if args.length == -1:
        args.length = model.config.n_ctx // 2
    elif args.length > model.config.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" % model.config.n_ctx)

    while True:
        context_tokens = []
        if not args.unconditional:
            raw_text = input("Model prompt >>> ")
            while not raw_text:
                print('Prompt should not be empty!')
                raw_text = input("Model prompt >>> ")
            context_tokens = enc.encode(raw_text)
            generated = 0
            for _ in range(args.nsamples // args.batch_size):
                out = sample_sequence(
                    model=model, length=args.length,
                    context=context_tokens,
                    start_token=None,
                    batch_size=args.batch_size,
                    temperature=args.temperature, top_k=args.top_k, device=device
                )
                out = out[:, len(context_tokens):].tolist()  # strip the prompt tokens
                for i in range(args.batch_size):
                    generated += 1
                    text = enc.decode(out[i])
                    print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
                    print(text)
                    print("=" * 80)
        else:
            generated = 0
            for _ in range(args.nsamples // args.batch_size):
                out = sample_sequence(
                    model=model, length=args.length,
                    context=None,
                    start_token=enc.encoder['<|endoftext|>'],
                    batch_size=args.batch_size,
                    temperature=args.temperature, top_k=args.top_k, device=device
                )
                out = out[:, 1:].tolist()  # strip the <|endoftext|> start token
                for i in range(args.batch_size):
                    generated += 1
                    text = enc.decode(out[i])
                    print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
                    print(text)
                    print("=" * 80)
            break  # unconditional generation does not loop for new prompts


if __name__ == '__main__':
    run_model()
@@ -22,9 +22,6 @@ import json
import logging
import math
import os
import sys
from io import open
@@ -37,16 +34,28 @@ from .file_utils import cached_path, WEIGHTS_NAME, CONFIG_NAME

logger = logging.getLogger(__name__)
PRETRAINED_MODEL_ARCHIVE_MAP = {
    'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-pytorch_model.bin",
    'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-pytorch_model.bin",
    'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-pytorch_model.bin",
    'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-pytorch_model.bin",
    'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-pytorch_model.bin",
    'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-pytorch_model.bin",
    'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-pytorch_model.bin",
    'bert-base-german-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-pytorch_model.bin",
    'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-pytorch_model.bin",
    'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-pytorch_model.bin",
}
PRETRAINED_CONFIG_ARCHIVE_MAP = {
    'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json",
    'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-config.json",
    'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json",
    'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-config.json",
    'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-config.json",
    'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-config.json",
    'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-config.json",
    'bert-base-german-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-config.json",
    'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-config.json",
    'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-config.json",
}
BERT_CONFIG_NAME = 'bert_config.json'
TF_WEIGHTS_NAME = 'model.ckpt'
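The split into two maps means a shortcut name now resolves to two separate URLs instead of one archive. A minimal hedged illustration (names as above):

```python
# Each file is fetched (and cached) independently by from_pretrained(),
# replacing the old single .tar.gz archive per model.
name = 'bert-base-uncased'
weights_url = PRETRAINED_MODEL_ARCHIVE_MAP[name]   # ...bert-base-uncased-pytorch_model.bin
config_url = PRETRAINED_CONFIG_ARCHIVE_MAP[name]   # ...bert-base-uncased-config.json
```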
@@ -642,8 +651,15 @@ class BertPreTrainedModel(nn.Module):
        if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
            archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
            config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]
        else:
            if from_tf:
                # Directly load from a TensorFlow checkpoint
                archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME)
                config_file = os.path.join(pretrained_model_name_or_path, BERT_CONFIG_NAME)
            else:
                archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
                config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
        # redirect to the cache, if necessary
        try:
            resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
@@ -661,40 +677,60 @@ class BertPreTrainedModel(nn.Module):
                        ', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()),
                        archive_file))
            return None
        try:
            resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
        except EnvironmentError:
            if pretrained_model_name_or_path in PRETRAINED_CONFIG_ARCHIVE_MAP:
                logger.error(
                    "Couldn't reach server at '{}' to download pretrained model configuration file.".format(
                        config_file))
            else:
                logger.error(
                    "Model name '{}' was not found in model name list ({}). "
                    "We assumed '{}' was a path or url but couldn't find any file "
                    "associated to this path or url.".format(
                        pretrained_model_name_or_path,
                        ', '.join(PRETRAINED_CONFIG_ARCHIVE_MAP.keys()),
                        config_file))
            return None
        if resolved_archive_file == archive_file and resolved_config_file == config_file:
            logger.info("loading weights file {}".format(archive_file))
            logger.info("loading configuration file {}".format(config_file))
        else:
            logger.info("loading weights file {} from cache at {}".format(
                archive_file, resolved_archive_file))
            logger.info("loading configuration file {} from cache at {}".format(
                config_file, resolved_config_file))
        ### Switching to split config/weight files configuration
        # tempdir = None
        # if os.path.isdir(resolved_archive_file) or from_tf:
        #     serialization_dir = resolved_archive_file
        # else:
        #     # Extract archive to temp dir
        #     tempdir = tempfile.mkdtemp()
        #     logger.info("extracting archive file {} to temp dir {}".format(
        #         resolved_archive_file, tempdir))
        #     with tarfile.open(resolved_archive_file, 'r:gz') as archive:
        #         archive.extractall(tempdir)
        #     serialization_dir = tempdir
        # config_file = os.path.join(serialization_dir, CONFIG_NAME)
        # if not os.path.exists(config_file):
        #     # Backward compatibility with old naming format
        #     config_file = os.path.join(serialization_dir, BERT_CONFIG_NAME)
        # Load config
        config = BertConfig.from_json_file(resolved_config_file)
        logger.info("Model config {}".format(config))
        # Instantiate model.
        model = cls(config, *inputs, **kwargs)
        if state_dict is None and not from_tf:
            # weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
            state_dict = torch.load(resolved_archive_file, map_location='cpu')
        # if tempdir:
        #     # Clean up temp dir
        #     shutil.rmtree(tempdir)
        if from_tf:
            # Directly load from a TensorFlow checkpoint:
            # resolved_archive_file already points at the TF checkpoint here,
            # since archive_file was built from TF_WEIGHTS_NAME above.
            return load_tf_weights_in_bert(model, resolved_archive_file)
        # Load from a PyTorch state_dict
        old_keys = []
...
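In practice, the split-file layout above means a local checkpoint directory now holds the weights and the configuration side by side. A hedged sketch (the directory name is illustrative):

```python
from pytorch_pretrained_bert import BertModel

# my_model/
#   pytorch_model.bin   <- weights (WEIGHTS_NAME)
#   config.json         <- configuration (CONFIG_NAME)
model = BertModel.from_pretrained('./my_model')  # resolves both files independently
```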
@@ -23,9 +23,6 @@ import json
import logging
import math
import os
import sys
from io import open
@@ -496,7 +493,6 @@ class GPT2PreTrainedModel(nn.Module):
        # redirect to the cache, if necessary
        try:
            resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
        except EnvironmentError:
            if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
                logger.error(
@@ -505,10 +501,27 @@ class GPT2PreTrainedModel(nn.Module):
            else:
                logger.error(
                    "Model name '{}' was not found in model name list ({}). "
                    "We assumed '{}' was a path or url but couldn't find file {} "
                    "at this path or url.".format(
                        pretrained_model_name_or_path, ", ".join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()), pretrained_model_name_or_path,
                        archive_file
                    )
                )
            return None
        try:
            resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
        except EnvironmentError:
            if pretrained_model_name_or_path in PRETRAINED_CONFIG_ARCHIVE_MAP:
                logger.error(
                    "Couldn't reach server at '{}' to download pretrained model configuration file.".format(
                        config_file))
            else:
                logger.error(
                    "Model name '{}' was not found in model name list ({}). "
                    "We assumed '{}' was a path or url but couldn't find file {} "
                    "at this path or url.".format(
                        pretrained_model_name_or_path, ", ".join(PRETRAINED_CONFIG_ARCHIVE_MAP.keys()), pretrained_model_name_or_path,
                        config_file
                    )
                )
            return None
...
@@ -23,9 +23,6 @@ import json
import logging
import math
import os
import sys
from io import open
@@ -499,7 +496,6 @@ class OpenAIGPTPreTrainedModel(nn.Module):
        # redirect to the cache, if necessary
        try:
            resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
        except EnvironmentError:
            if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
                logger.error(
@@ -508,10 +504,27 @@ class OpenAIGPTPreTrainedModel(nn.Module):
            else:
                logger.error(
                    "Model name '{}' was not found in model name list ({}). "
                    "We assumed '{}' was a path or url but couldn't find file {} "
                    "at this path or url.".format(
                        pretrained_model_name_or_path, ", ".join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()), pretrained_model_name_or_path,
                        archive_file
                    )
                )
            return None
        try:
            resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
        except EnvironmentError:
            if pretrained_model_name_or_path in PRETRAINED_CONFIG_ARCHIVE_MAP:
                logger.error(
                    "Couldn't reach server at '{}' to download pretrained model configuration file.".format(
                        config_file))
            else:
                logger.error(
                    "Model name '{}' was not found in model name list ({}). "
                    "We assumed '{}' was a path or url but couldn't find file {} "
                    "at this path or url.".format(
                        pretrained_model_name_or_path, ", ".join(PRETRAINED_CONFIG_ARCHIVE_MAP.keys()), pretrained_model_name_or_path,
                        config_file
                    )
                )
            return None
...
@@ -25,9 +25,6 @@ import copy
import json
import math
import logging
import collections
import sys
from io import open
@@ -924,7 +921,6 @@ class TransfoXLPreTrainedModel(nn.Module):
        # redirect to the cache, if necessary
        try:
            resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
        except EnvironmentError:
            if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
                logger.error(
@@ -933,12 +929,29 @@ class TransfoXLPreTrainedModel(nn.Module):
            else:
                logger.error(
                    "Model name '{}' was not found in model name list ({}). "
                    "We assumed '{}' was a path or url but couldn't find file {} "
                    "at this path or url.".format(
                        pretrained_model_name_or_path, ", ".join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()), pretrained_model_name_or_path,
                        archive_file
                    )
                )
            return None
        try:
            resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
        except EnvironmentError:
            if pretrained_model_name_or_path in PRETRAINED_CONFIG_ARCHIVE_MAP:
                logger.error(
                    "Couldn't reach server at '{}' to download pretrained model configuration file.".format(
                        config_file))
            else:
                logger.error(
                    "Model name '{}' was not found in model name list ({}). "
                    "We assumed '{}' was a path or url but couldn't find file {} "
                    "at this path or url.".format(
                        pretrained_model_name_or_path, ", ".join(PRETRAINED_CONFIG_ARCHIVE_MAP.keys()), pretrained_model_name_or_path,
                        config_file
                    )
                )
            return None
        if resolved_archive_file == archive_file and resolved_config_file == config_file:
            logger.info("loading weights file {}".format(archive_file))
...