Commit 382e2d1e authored by thomwolf's avatar thomwolf
Browse files

spliting config and weight files for bert also

parent a6f25118
......@@ -1432,6 +1432,25 @@ The results were similar to the above FP32 results (actually slightly higher):
{"exact_match": 84.65468306527909, "f1": 91.238669287002}
```
Here is an example with the recent `bert-large-uncased-whole-word-masking`:
```bash
python -m torch.distributed.launch --nproc_per_node=8 \
run_squad.py \
--bert_model bert-large-uncased-whole-word-masking \
--do_train \
--do_predict \
--do_lower_case \
--train_file $SQUAD_DIR/train-v1.1.json \
--predict_file $SQUAD_DIR/dev-v1.1.json \
--train_batch_size 12 \
--learning_rate 3e-5 \
--num_train_epochs 2.0 \
--max_seq_length 384 \
--doc_stride 128 \
--output_dir /tmp/debug_squad/
```
## Notebooks
We include [three Jupyter Notebooks](https://github.com/huggingface/pytorch-pretrained-BERT/tree/master/notebooks) that can be used to check that the predictions of the PyTorch model are identical to the predictions of the original TensorFlow model.
......
#!/usr/bin/env python3
import argparse
import logging
from tqdm import trange
import torch
import torch.nn.functional as F
import numpy as np
from pytorch_pretrained_bert import BertModel, BertTokenizer
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt = '%m/%d/%Y %H:%M:%S',
level = logging.INFO)
logger = logging.getLogger(__name__)
def run_model():
parser = argparse.ArgumentParser()
parser.add_argument('--model_name_or_path', type=str, default='bert-base-uncased',
help='pretrained model name or path to local checkpoint')
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--batch_size", type=int, default=-1)
parser.add_argument('--unconditional', action='store_true', help='If true, unconditional generation.')
args = parser.parse_args()
print(args)
if args.batch_size == -1:
args.batch_size = 1
assert args.nsamples % args.batch_size == 0
np.random.seed(args.seed)
torch.random.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
enc = GPT2Tokenizer.from_pretrained(args.model_name_or_path)
model = GPT2LMHeadModel.from_pretrained(args.model_name_or_path)
model.to(device)
model.eval()
if args.length == -1:
args.length = model.config.n_ctx // 2
elif args.length > model.config.n_ctx:
raise ValueError("Can't get samples longer than window size: %s" % model.config.n_ctx)
while True:
context_tokens = []
if not args.unconditional:
raw_text = input("Model prompt >>> ")
while not raw_text:
print('Prompt should not be empty!')
raw_text = input("Model prompt >>> ")
context_tokens = enc.encode(raw_text)
generated = 0
for _ in range(args.nsamples // args.batch_size):
out = sample_sequence(
model=model, length=args.length,
context=context_tokens,
start_token=None,
batch_size=args.batch_size,
temperature=args.temperature, top_k=args.top_k, device=device
)
out = out[:, len(context_tokens):].tolist()
for i in range(args.batch_size):
generated += 1
text = enc.decode(out[i])
print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
print(text)
print("=" * 80)
else:
generated = 0
for _ in range(args.nsamples // args.batch_size):
out = sample_sequence(
model=model, length=args.length,
context=None,
start_token=enc.encoder['<|endoftext|>'],
batch_size=args.batch_size,
temperature=args.temperature, top_k=args.top_k, device=device
)
out = out[:,1:].tolist()
for i in range(args.batch_size):
generated += 1
text = enc.decode(out[i])
print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
print(text)
print("=" * 80)
if __name__ == '__main__':
run_model()
......@@ -22,9 +22,6 @@ import json
import logging
import math
import os
import shutil
import tarfile
import tempfile
import sys
from io import open
......@@ -37,16 +34,28 @@ from .file_utils import cached_path, WEIGHTS_NAME, CONFIG_NAME
logger = logging.getLogger(__name__)
PRETRAINED_MODEL_ARCHIVE_MAP = {
'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz",
'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased.tar.gz",
'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased.tar.gz",
'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased.tar.gz",
'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased.tar.gz",
'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz",
'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz",
'bert-base-german-cased': "https://int-deepset-models-bert.s3.eu-central-1.amazonaws.com/pytorch/bert-base-german-cased.tar.gz",
'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking.tar.gz",
'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking.tar.gz",
'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-pytorch_model.bin",
'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-pytorch_model.bin",
'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-pytorch_model.bin",
'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-pytorch_model.bin",
'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-pytorch_model.bin",
'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-pytorch_model.bin",
'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-pytorch_model.bin",
'bert-base-german-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-pytorch_model.bin",
'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-pytorch_model.bin",
'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-pytorch_model.bin",
}
PRETRAINED_CONFIG_ARCHIVE_MAP = {
'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json",
'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-config.json",
'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json",
'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-config.json",
'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-config.json",
'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-config.json",
'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-config.json",
'bert-base-german-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-config.json",
'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-config.json",
'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-config.json",
}
BERT_CONFIG_NAME = 'bert_config.json'
TF_WEIGHTS_NAME = 'model.ckpt'
......@@ -642,11 +651,14 @@ class BertPreTrainedModel(nn.Module):
if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]
else:
archive_file = pretrained_model_name_or_path
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
# redirect to the cache, if necessary
try:
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
except EnvironmentError:
if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
logger.error(
......@@ -661,22 +673,26 @@ class BertPreTrainedModel(nn.Module):
', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()),
archive_file))
return None
if resolved_archive_file == archive_file:
logger.info("loading archive file {}".format(archive_file))
if resolved_archive_file == archive_file and resolved_config_file == config_file:
logger.info("loading weights file {}".format(archive_file))
logger.info("loading configuration file {}".format(config_file))
else:
logger.info("loading archive file {} from cache at {}".format(
logger.info("loading weights file {} from cache at {}".format(
archive_file, resolved_archive_file))
tempdir = None
if os.path.isdir(resolved_archive_file) or from_tf:
serialization_dir = resolved_archive_file
else:
# Extract archive to temp dir
tempdir = tempfile.mkdtemp()
logger.info("extracting archive file {} to temp dir {}".format(
resolved_archive_file, tempdir))
with tarfile.open(resolved_archive_file, 'r:gz') as archive:
archive.extractall(tempdir)
serialization_dir = tempdir
logger.info("loading configuration file {} from cache at {}".format(
config_file, resolved_config_file))
### Switching to split config/weight files configuration
# tempdir = None
# if os.path.isdir(resolved_archive_file) or from_tf:
# serialization_dir = resolved_archive_file
# else:
# # Extract archive to temp dir
# tempdir = tempfile.mkdtemp()
# logger.info("extracting archive file {} to temp dir {}".format(
# resolved_archive_file, tempdir))
# with tarfile.open(resolved_archive_file, 'r:gz') as archive:
# archive.extractall(tempdir)
# serialization_dir = tempdir
# Load config
config_file = os.path.join(serialization_dir, CONFIG_NAME)
if not os.path.exists(config_file):
......@@ -689,9 +705,9 @@ class BertPreTrainedModel(nn.Module):
if state_dict is None and not from_tf:
weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
state_dict = torch.load(weights_path, map_location='cpu')
if tempdir:
# Clean up temp dir
shutil.rmtree(tempdir)
# if tempdir:
# # Clean up temp dir
# shutil.rmtree(tempdir)
if from_tf:
# Directly load from a TensorFlow checkpoint
weights_path = os.path.join(serialization_dir, TF_WEIGHTS_NAME)
......
......@@ -23,9 +23,6 @@ import json
import logging
import math
import os
import shutil
import tarfile
import tempfile
import sys
from io import open
......
......@@ -23,9 +23,6 @@ import json
import logging
import math
import os
import shutil
import tarfile
import tempfile
import sys
from io import open
......
......@@ -25,9 +25,6 @@ import copy
import json
import math
import logging
import tarfile
import tempfile
import shutil
import collections
import sys
from io import open
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment