Commit 382e2d1e authored by thomwolf

splitting config and weight files for bert also

parent a6f25118
@@ -1432,6 +1432,25 @@
The results were similar to the above FP32 results (actually slightly higher):
{"exact_match": 84.65468306527909, "f1": 91.238669287002}
```
Here is an example with the recent `bert-large-uncased-whole-word-masking`:
```bash
python -m torch.distributed.launch --nproc_per_node=8 \
run_squad.py \
--bert_model bert-large-uncased-whole-word-masking \
--do_train \
--do_predict \
--do_lower_case \
--train_file $SQUAD_DIR/train-v1.1.json \
--predict_file $SQUAD_DIR/dev-v1.1.json \
--train_batch_size 12 \
--learning_rate 3e-5 \
--num_train_epochs 2.0 \
--max_seq_length 384 \
--doc_stride 128 \
--output_dir /tmp/debug_squad/
```
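After training, the same checkpoint name can be loaded directly in Python. A minimal sketch (not part of the original README), assuming the `pytorch_pretrained_bert` package from this repository is installed and using its `BertForQuestionAnswering` head:

```python
from pytorch_pretrained_bert import BertTokenizer, BertForQuestionAnswering

# The whole-word-masking weights and config are fetched as two separate files
# and cached locally on first use.
tokenizer = BertTokenizer.from_pretrained('bert-large-uncased-whole-word-masking',
                                          do_lower_case=True)
model = BertForQuestionAnswering.from_pretrained('bert-large-uncased-whole-word-masking')
model.eval()

# A directory containing the split weight/config files (for example the
# --output_dir written by the command above) can be passed instead of the name:
# model = BertForQuestionAnswering.from_pretrained('/tmp/debug_squad/')
```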
## Notebooks

We include [three Jupyter Notebooks](https://github.com/huggingface/pytorch-pretrained-BERT/tree/master/notebooks) that can be used to check that the predictions of the PyTorch model are identical to the predictions of the original TensorFlow model.
......
#!/usr/bin/env python3

import argparse
import logging

import numpy as np
import torch
import torch.nn.functional as F
from tqdm import trange

# GPT-2 tokenizer and LM head used throughout this script.
from pytorch_pretrained_bert import GPT2LMHeadModel, GPT2Tokenizer

logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                    datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.INFO)
logger = logging.getLogger(__name__)
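
# NOTE: `run_model()` below calls `sample_sequence`, which is not shown in this
# excerpt. The helpers that follow are a plausible reconstruction (standard
# top-k autoregressive sampling with the GPT-2 LM head), not necessarily the
# exact original code.

def top_k_logits(logits, k):
    """Mask everything except the k highest logits in each row."""
    if k == 0:
        return logits
    values, _ = torch.topk(logits, k)
    min_values = values[:, -1]
    return torch.where(logits < min_values,
                       torch.ones_like(logits, dtype=logits.dtype) * -1e10,
                       logits)


def sample_sequence(model, length, start_token=None, batch_size=None, context=None,
                    temperature=1.0, top_k=0, device='cuda', sample=True):
    """Sample `length` tokens autoregressively from `context` or a single `start_token`."""
    if start_token is None:
        assert context is not None, 'Specify exactly one of start_token and context!'
        context = torch.tensor(context, device=device, dtype=torch.long).unsqueeze(0).repeat(batch_size, 1)
    else:
        assert context is None, 'Specify exactly one of start_token and context!'
        context = torch.full((batch_size, 1), start_token, device=device, dtype=torch.long)
    prev = context
    output = context
    past = None
    with torch.no_grad():
        for _ in trange(length):
            # GPT2LMHeadModel returns (logits, presents) when no labels are given;
            # `past` caches attention keys/values so each step only feeds the new token.
            logits, past = model(prev, past=past)
            logits = logits[:, -1, :] / temperature
            logits = top_k_logits(logits, k=top_k)
            log_probs = F.softmax(logits, dim=-1)
            if sample:
                prev = torch.multinomial(log_probs, num_samples=1)
            else:
                _, prev = torch.topk(log_probs, k=1, dim=-1)
            output = torch.cat((output, prev), dim=1)
    return output
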
def run_model():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name_or_path', type=str, default='gpt2',
                        help='pretrained model name or path to local checkpoint')
    parser.add_argument("--seed", type=int, default=42)
    # Generation options referenced below (defaults here are best-guess values).
    parser.add_argument("--nsamples", type=int, default=1)
    parser.add_argument("--batch_size", type=int, default=-1)
    parser.add_argument("--length", type=int, default=-1)
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--top_k", type=int, default=0)
    parser.add_argument('--unconditional', action='store_true', help='If true, unconditional generation.')
    args = parser.parse_args()
    print(args)

    if args.batch_size == -1:
        args.batch_size = 1
    assert args.nsamples % args.batch_size == 0

    np.random.seed(args.seed)
    torch.random.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    enc = GPT2Tokenizer.from_pretrained(args.model_name_or_path)
    model = GPT2LMHeadModel.from_pretrained(args.model_name_or_path)
    model.to(device)
    model.eval()

    if args.length == -1:
        # Default to half of the model's context window.
        args.length = model.config.n_ctx // 2
    elif args.length > model.config.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" % model.config.n_ctx)

    while True:
        context_tokens = []
        if not args.unconditional:
            raw_text = input("Model prompt >>> ")
            while not raw_text:
                print('Prompt should not be empty!')
                raw_text = input("Model prompt >>> ")
            context_tokens = enc.encode(raw_text)
            generated = 0
            for _ in range(args.nsamples // args.batch_size):
                out = sample_sequence(
                    model=model, length=args.length,
                    context=context_tokens,
                    start_token=None,
                    batch_size=args.batch_size,
                    temperature=args.temperature, top_k=args.top_k, device=device
                )
                # Strip the prompt tokens before decoding.
                out = out[:, len(context_tokens):].tolist()
                for i in range(args.batch_size):
                    generated += 1
                    text = enc.decode(out[i])
                    print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
                    print(text)
            print("=" * 80)
        else:
            generated = 0
            for _ in range(args.nsamples // args.batch_size):
                out = sample_sequence(
                    model=model, length=args.length,
                    context=None,
                    start_token=enc.encoder['<|endoftext|>'],
                    batch_size=args.batch_size,
                    temperature=args.temperature, top_k=args.top_k, device=device
                )
                # Drop the start token before decoding.
                out = out[:, 1:].tolist()
                for i in range(args.batch_size):
                    generated += 1
                    text = enc.decode(out[i])
                    print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
                    print(text)
            print("=" * 80)


if __name__ == '__main__':
    run_model()
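
For quick experiments the same pieces can also be driven without the command-line wrapper. A small sketch (not from the repository), assuming the `sample_sequence` helper sketched above and the published `gpt2` shortcut name:

```python
from pytorch_pretrained_bert import GPT2LMHeadModel, GPT2Tokenizer

enc = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')
model.eval()

context = enc.encode("Configuration and weights now live in separate files,")
out = sample_sequence(model=model, length=40, context=context, start_token=None,
                      batch_size=1, temperature=1.0, top_k=10, device='cpu')
# Decode only the newly generated continuation.
print(enc.decode(out[0, len(context):].tolist()))
```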
@@ -22,9 +22,6 @@
 import json
 import logging
 import math
 import os
-import shutil
-import tarfile
-import tempfile
 import sys
 from io import open
@@ -37,16 +34,28 @@
 from .file_utils import cached_path, WEIGHTS_NAME, CONFIG_NAME
 logger = logging.getLogger(__name__)
 PRETRAINED_MODEL_ARCHIVE_MAP = {
-    'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz",
-    'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased.tar.gz",
-    'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased.tar.gz",
-    'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased.tar.gz",
-    'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased.tar.gz",
-    'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz",
-    'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz",
-    'bert-base-german-cased': "https://int-deepset-models-bert.s3.eu-central-1.amazonaws.com/pytorch/bert-base-german-cased.tar.gz",
-    'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking.tar.gz",
-    'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking.tar.gz",
+    'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-pytorch_model.bin",
+    'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-pytorch_model.bin",
+    'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-pytorch_model.bin",
+    'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-pytorch_model.bin",
+    'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-pytorch_model.bin",
+    'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-pytorch_model.bin",
+    'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-pytorch_model.bin",
+    'bert-base-german-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-pytorch_model.bin",
+    'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-pytorch_model.bin",
+    'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-pytorch_model.bin",
+}
+PRETRAINED_CONFIG_ARCHIVE_MAP = {
+    'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json",
+    'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-config.json",
+    'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json",
+    'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-config.json",
+    'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-config.json",
+    'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-config.json",
+    'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-config.json",
+    'bert-base-german-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-config.json",
+    'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-config.json",
+    'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-config.json",
+}
 BERT_CONFIG_NAME = 'bert_config.json'
 TF_WEIGHTS_NAME = 'model.ckpt'
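
As an aside (this is not part of the diff), the practical effect of the two maps above is that one shortcut name now resolves to two separate cached downloads instead of a single `.tar.gz` archive. A minimal sketch using the repository's `cached_path` helper:

```python
from pytorch_pretrained_bert.file_utils import cached_path
from pytorch_pretrained_bert.modeling import (PRETRAINED_MODEL_ARCHIVE_MAP,
                                               PRETRAINED_CONFIG_ARCHIVE_MAP)

name = 'bert-base-uncased'
# One download for the raw PyTorch state dict, one for the JSON config.
resolved_weights = cached_path(PRETRAINED_MODEL_ARCHIVE_MAP[name])   # ...-pytorch_model.bin
resolved_config = cached_path(PRETRAINED_CONFIG_ARCHIVE_MAP[name])   # ...-config.json
```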
@@ -642,11 +651,14 @@ class BertPreTrainedModel(nn.Module):
         if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
             archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
+            config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]
         else:
-            archive_file = pretrained_model_name_or_path
+            archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
+            config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
         # redirect to the cache, if necessary
         try:
             resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
+            resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
         except EnvironmentError:
             if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
                 logger.error(
@@ -661,22 +673,26 @@ class BertPreTrainedModel(nn.Module):
                     ', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()),
                     archive_file))
             return None
-        if resolved_archive_file == archive_file:
-            logger.info("loading archive file {}".format(archive_file))
+        if resolved_archive_file == archive_file and resolved_config_file == config_file:
+            logger.info("loading weights file {}".format(archive_file))
+            logger.info("loading configuration file {}".format(config_file))
         else:
-            logger.info("loading archive file {} from cache at {}".format(
+            logger.info("loading weights file {} from cache at {}".format(
                 archive_file, resolved_archive_file))
-        tempdir = None
-        if os.path.isdir(resolved_archive_file) or from_tf:
-            serialization_dir = resolved_archive_file
-        else:
-            # Extract archive to temp dir
-            tempdir = tempfile.mkdtemp()
-            logger.info("extracting archive file {} to temp dir {}".format(
-                resolved_archive_file, tempdir))
-            with tarfile.open(resolved_archive_file, 'r:gz') as archive:
-                archive.extractall(tempdir)
-            serialization_dir = tempdir
+            logger.info("loading configuration file {} from cache at {}".format(
+                config_file, resolved_config_file))
+        ### Switching to split config/weight files configuration
+        # tempdir = None
+        # if os.path.isdir(resolved_archive_file) or from_tf:
+        #     serialization_dir = resolved_archive_file
+        # else:
+        #     # Extract archive to temp dir
+        #     tempdir = tempfile.mkdtemp()
+        #     logger.info("extracting archive file {} to temp dir {}".format(
+        #         resolved_archive_file, tempdir))
+        #     with tarfile.open(resolved_archive_file, 'r:gz') as archive:
+        #         archive.extractall(tempdir)
+        #     serialization_dir = tempdir
         # Load config
         config_file = os.path.join(serialization_dir, CONFIG_NAME)
         if not os.path.exists(config_file):
@@ -689,9 +705,9 @@ class BertPreTrainedModel(nn.Module):
         if state_dict is None and not from_tf:
             weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
             state_dict = torch.load(weights_path, map_location='cpu')
-        if tempdir:
-            # Clean up temp dir
-            shutil.rmtree(tempdir)
+        # if tempdir:
+        #     # Clean up temp dir
+        #     shutil.rmtree(tempdir)
         if from_tf:
             # Directly load from a TensorFlow checkpoint
             weights_path = os.path.join(serialization_dir, TF_WEIGHTS_NAME)
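
Because `from_pretrained` now joins a local path with `WEIGHTS_NAME` and `CONFIG_NAME` instead of unpacking a `.tar.gz` archive, a checkpoint directory only needs those two files. An illustrative sketch (not part of the diff) of saving and reloading in that layout:

```python
import os
import torch
from pytorch_pretrained_bert import BertModel
from pytorch_pretrained_bert.file_utils import WEIGHTS_NAME, CONFIG_NAME

save_dir = '/tmp/my_bert'          # hypothetical output directory
os.makedirs(save_dir, exist_ok=True)

model = BertModel.from_pretrained('bert-base-uncased')

# Write the two files from_pretrained() now looks for inside a directory.
torch.save(model.state_dict(), os.path.join(save_dir, WEIGHTS_NAME))
with open(os.path.join(save_dir, CONFIG_NAME), 'w') as f:
    f.write(model.config.to_json_string())

# Reload straight from the directory (no archive extraction involved).
reloaded = BertModel.from_pretrained(save_dir)
```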
......
@@ -23,9 +23,6 @@
 import json
 import logging
 import math
 import os
-import shutil
-import tarfile
-import tempfile
 import sys
 from io import open
......
@@ -23,9 +23,6 @@
 import json
 import logging
 import math
 import os
-import shutil
-import tarfile
-import tempfile
 import sys
 from io import open
......
@@ -25,9 +25,6 @@ import copy
 import json
 import math
 import logging
-import tarfile
-import tempfile
-import shutil
 import collections
 import sys
 from io import open
......