Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
382e2d1e
Commit
382e2d1e
authored
Jun 18, 2019
by
thomwolf
Browse files
splitting config and weight files for bert also
parent
a6f25118
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
158 additions
and
40 deletions
+158
-40
README.md
README.md
+19
-0
examples/bertology.py
examples/bertology.py
+92
-0
pytorch_pretrained_bert/modeling.py
pytorch_pretrained_bert/modeling.py
+47
-31
pytorch_pretrained_bert/modeling_gpt2.py
pytorch_pretrained_bert/modeling_gpt2.py
+0
-3
pytorch_pretrained_bert/modeling_openai.py
pytorch_pretrained_bert/modeling_openai.py
+0
-3
pytorch_pretrained_bert/modeling_transfo_xl.py
pytorch_pretrained_bert/modeling_transfo_xl.py
+0
-3
No files found.
README.md
View file @
382e2d1e
...
...
@@ -1432,6 +1432,25 @@ The results were similar to the above FP32 results (actually slightly higher):
{
"exact_match"
: 84.65468306527909,
"f1"
: 91.238669287002
}
```
Here is an example with the recent
`bert-large-uncased-whole-word-masking`
:
```bash
python -m torch.distributed.launch --nproc_per_node=8 run_squad.py \
  --bert_model bert-large-uncased-whole-word-masking \
  --do_train \
  --do_predict \
  --do_lower_case \
  --train_file $SQUAD_DIR/train-v1.1.json \
  --predict_file $SQUAD_DIR/dev-v1.1.json \
  --train_batch_size 12 \
  --learning_rate 3e-5 \
  --num_train_epochs 2.0 \
  --max_seq_length 384 \
  --doc_stride 128 \
  --output_dir /tmp/debug_squad/
```
## Notebooks
We include
[
three Jupyter Notebooks
](
https://github.com/huggingface/pytorch-pretrained-BERT/tree/master/notebooks
)
that can be used to check that the predictions of the PyTorch model are identical to the predictions of the original TensorFlow model.
...
...
examples/bertology.py
0 → 100644
View file @
382e2d1e
#!/usr/bin/env python3
"""Bertology probing script (initial skeleton).

NOTE(review): the body of this script is copied from run_gpt2.py and still
drives a GPT-2 model.  The original header imported only BertModel and
BertTokenizer (both unused below) while the body references GPT2Tokenizer
and GPT2LMHeadModel, which raised NameError at runtime.  The GPT-2 names
are now imported as well; the Bert imports are kept for the upcoming
Bert-specific probes this file is named for.
"""
import argparse
import logging

import numpy as np
import torch
import torch.nn.functional as F
from tqdm import trange

from pytorch_pretrained_bert import (BertModel, BertTokenizer,
                                     GPT2LMHeadModel, GPT2Tokenizer)

logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                    datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.INFO)
logger = logging.getLogger(__name__)
def run_model():
    """Sample text from a GPT-2 model, conditionally or unconditionally.

    Parses command-line arguments, loads the tokenizer and LM-head model,
    then loops forever: either prompts the user for a context string and
    samples continuations of it, or (with --unconditional) samples from the
    '<|endoftext|>' start token.  Samples are printed to stdout.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name_or_path', type=str, default='bert-base-uncased',
                        help='pretrained model name or path to local checkpoint')
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--batch_size", type=int, default=-1)
    # Fix: the body below reads args.nsamples / args.length / args.temperature
    # / args.top_k, but the original parser never declared them, so the script
    # died with AttributeError on its first use of args.nsamples.
    parser.add_argument("--nsamples", type=int, default=1,
                        help='number of samples to generate')
    parser.add_argument("--length", type=int, default=-1,
                        help='number of tokens per sample; -1 means half the model window')
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--top_k", type=int, default=0)
    parser.add_argument('--unconditional', action='store_true',
                        help='If true, unconditional generation.')
    args = parser.parse_args()
    print(args)

    if args.batch_size == -1:
        args.batch_size = 1
    assert args.nsamples % args.batch_size == 0

    np.random.seed(args.seed)
    torch.random.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    enc = GPT2Tokenizer.from_pretrained(args.model_name_or_path)
    model = GPT2LMHeadModel.from_pretrained(args.model_name_or_path)
    model.to(device)
    model.eval()

    if args.length == -1:
        args.length = model.config.n_ctx // 2
    elif args.length > model.config.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" % model.config.n_ctx)

    while True:
        context_tokens = []
        if not args.unconditional:
            raw_text = input("Model prompt >>> ")
            while not raw_text:
                print('Prompt should not be empty!')
                raw_text = input("Model prompt >>> ")
            context_tokens = enc.encode(raw_text)
            generated = 0
            for _ in range(args.nsamples // args.batch_size):
                # NOTE(review): sample_sequence is not defined or imported in
                # this file — it must be brought over from run_gpt2.py too.
                out = sample_sequence(
                    model=model, length=args.length,
                    context=context_tokens,
                    start_token=None,
                    batch_size=args.batch_size,
                    temperature=args.temperature, top_k=args.top_k, device=device
                )
                # Strip the prompt tokens; keep only the generated continuation.
                out = out[:, len(context_tokens):].tolist()
                for i in range(args.batch_size):
                    generated += 1
                    text = enc.decode(out[i])
                    print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
                    print(text)
            print("=" * 80)
        else:
            generated = 0
            for _ in range(args.nsamples // args.batch_size):
                out = sample_sequence(
                    model=model, length=args.length,
                    context=None,
                    start_token=enc.encoder['<|endoftext|>'],
                    batch_size=args.batch_size,
                    temperature=args.temperature, top_k=args.top_k, device=device
                )
                # Drop the start token before decoding.
                out = out[:, 1:].tolist()
                for i in range(args.batch_size):
                    generated += 1
                    text = enc.decode(out[i])
                    print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
                    print(text)
            print("=" * 80)

if __name__ == '__main__':
    run_model()
pytorch_pretrained_bert/modeling.py
View file @
382e2d1e
...
...
@@ -22,9 +22,6 @@ import json
import
logging
import
math
import
os
import
shutil
import
tarfile
import
tempfile
import
sys
from
io
import
open
...
...
@@ -37,16 +34,28 @@ from .file_utils import cached_path, WEIGHTS_NAME, CONFIG_NAME
logger
=
logging
.
getLogger
(
__name__
)
# Shortcut model name -> URL of the standalone PyTorch weights file.
# Weights and configuration used to be bundled in one .tar.gz archive; they
# are now distributed as two separate files (see PRETRAINED_CONFIG_ARCHIVE_MAP
# for the matching config URLs).  The stale .tar.gz entries were duplicate
# dict keys that the later -pytorch_model.bin entries silently overwrote, so
# removing them leaves the dict's content unchanged.
PRETRAINED_MODEL_ARCHIVE_MAP = {
    'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-pytorch_model.bin",
    'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-pytorch_model.bin",
    'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-pytorch_model.bin",
    'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-pytorch_model.bin",
    'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-pytorch_model.bin",
    'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-pytorch_model.bin",
    'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-pytorch_model.bin",
    'bert-base-german-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-pytorch_model.bin",
    'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-pytorch_model.bin",
    'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-pytorch_model.bin",
}
# Shortcut model name -> URL of the standalone configuration (JSON) file.
# Each key pairs with the same key in PRETRAINED_MODEL_ARCHIVE_MAP: the
# architecture hyper-parameters and the weights are downloaded separately.
PRETRAINED_CONFIG_ARCHIVE_MAP = {
    'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json",
    'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-config.json",
    'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json",
    'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-config.json",
    'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-config.json",
    'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-config.json",
    'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-config.json",
    'bert-base-german-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-config.json",
    'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-config.json",
    'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-config.json",
}
# Config file name inside legacy Google/TF-style checkpoint directories.
BERT_CONFIG_NAME = 'bert_config.json'
# TensorFlow checkpoint file prefix (joined to serialization_dir when from_tf is set).
TF_WEIGHTS_NAME = 'model.ckpt'
...
...
@@ -642,11 +651,14 @@ class BertPreTrainedModel(nn.Module):
if
pretrained_model_name_or_path
in
PRETRAINED_MODEL_ARCHIVE_MAP
:
archive_file
=
PRETRAINED_MODEL_ARCHIVE_MAP
[
pretrained_model_name_or_path
]
config_file
=
PRETRAINED_CONFIG_ARCHIVE_MAP
[
pretrained_model_name_or_path
]
else
:
archive_file
=
pretrained_model_name_or_path
archive_file
=
os
.
path
.
join
(
pretrained_model_name_or_path
,
WEIGHTS_NAME
)
config_file
=
os
.
path
.
join
(
pretrained_model_name_or_path
,
CONFIG_NAME
)
# redirect to the cache, if necessary
try
:
resolved_archive_file
=
cached_path
(
archive_file
,
cache_dir
=
cache_dir
)
resolved_config_file
=
cached_path
(
config_file
,
cache_dir
=
cache_dir
)
except
EnvironmentError
:
if
pretrained_model_name_or_path
in
PRETRAINED_MODEL_ARCHIVE_MAP
:
logger
.
error
(
...
...
@@ -661,22 +673,26 @@ class BertPreTrainedModel(nn.Module):
', '
.
join
(
PRETRAINED_MODEL_ARCHIVE_MAP
.
keys
()),
archive_file
))
return
None
if
resolved_archive_file
==
archive_file
:
logger
.
info
(
"loading archive file {}"
.
format
(
archive_file
))
if
resolved_archive_file
==
archive_file
and
resolved_config_file
==
config_file
:
logger
.
info
(
"loading weights file {}"
.
format
(
archive_file
))
logger
.
info
(
"loading configuration file {}"
.
format
(
config_file
))
else
:
logger
.
info
(
"loading
archive
file {} from cache at {}"
.
format
(
logger
.
info
(
"loading
weights
file {} from cache at {}"
.
format
(
archive_file
,
resolved_archive_file
))
tempdir
=
None
if
os
.
path
.
isdir
(
resolved_archive_file
)
or
from_tf
:
serialization_dir
=
resolved_archive_file
else
:
# Extract archive to temp dir
tempdir
=
tempfile
.
mkdtemp
()
logger
.
info
(
"extracting archive file {} to temp dir {}"
.
format
(
resolved_archive_file
,
tempdir
))
with
tarfile
.
open
(
resolved_archive_file
,
'r:gz'
)
as
archive
:
archive
.
extractall
(
tempdir
)
serialization_dir
=
tempdir
logger
.
info
(
"loading configuration file {} from cache at {}"
.
format
(
config_file
,
resolved_config_file
))
### Switching to split config/weight files configuration
# tempdir = None
# if os.path.isdir(resolved_archive_file) or from_tf:
# serialization_dir = resolved_archive_file
# else:
# # Extract archive to temp dir
# tempdir = tempfile.mkdtemp()
# logger.info("extracting archive file {} to temp dir {}".format(
# resolved_archive_file, tempdir))
# with tarfile.open(resolved_archive_file, 'r:gz') as archive:
# archive.extractall(tempdir)
# serialization_dir = tempdir
# Load config
config_file
=
os
.
path
.
join
(
serialization_dir
,
CONFIG_NAME
)
if
not
os
.
path
.
exists
(
config_file
):
...
...
@@ -689,9 +705,9 @@ class BertPreTrainedModel(nn.Module):
if
state_dict
is
None
and
not
from_tf
:
weights_path
=
os
.
path
.
join
(
serialization_dir
,
WEIGHTS_NAME
)
state_dict
=
torch
.
load
(
weights_path
,
map_location
=
'cpu'
)
if
tempdir
:
# Clean up temp dir
shutil
.
rmtree
(
tempdir
)
#
if tempdir:
#
# Clean up temp dir
#
shutil.rmtree(tempdir)
if
from_tf
:
# Directly load from a TensorFlow checkpoint
weights_path
=
os
.
path
.
join
(
serialization_dir
,
TF_WEIGHTS_NAME
)
...
...
pytorch_pretrained_bert/modeling_gpt2.py
View file @
382e2d1e
...
...
@@ -23,9 +23,6 @@ import json
import
logging
import
math
import
os
import
shutil
import
tarfile
import
tempfile
import
sys
from
io
import
open
...
...
pytorch_pretrained_bert/modeling_openai.py
View file @
382e2d1e
...
...
@@ -23,9 +23,6 @@ import json
import
logging
import
math
import
os
import
shutil
import
tarfile
import
tempfile
import
sys
from
io
import
open
...
...
pytorch_pretrained_bert/modeling_transfo_xl.py
View file @
382e2d1e
...
...
@@ -25,9 +25,6 @@ import copy
import
json
import
math
import
logging
import
tarfile
import
tempfile
import
shutil
import
collections
import
sys
from
io
import
open
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment