Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
382e2d1e
"docker/vscode:/vscode.git/clone" did not exist on "7829c890db958279ca49519cc009e4f2def3fccb"
Commit
382e2d1e
authored
Jun 18, 2019
by
thomwolf
Browse files
spliting config and weight files for bert also
parent
a6f25118
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
158 additions
and
40 deletions
+158
-40
README.md
README.md
+19
-0
examples/bertology.py
examples/bertology.py
+92
-0
pytorch_pretrained_bert/modeling.py
pytorch_pretrained_bert/modeling.py
+47
-31
pytorch_pretrained_bert/modeling_gpt2.py
pytorch_pretrained_bert/modeling_gpt2.py
+0
-3
pytorch_pretrained_bert/modeling_openai.py
pytorch_pretrained_bert/modeling_openai.py
+0
-3
pytorch_pretrained_bert/modeling_transfo_xl.py
pytorch_pretrained_bert/modeling_transfo_xl.py
+0
-3
No files found.
README.md
View file @
382e2d1e
...
...
@@ -1432,6 +1432,25 @@ The results were similar to the above FP32 results (actually slightly higher):
{
"exact_match"
: 84.65468306527909,
"f1"
: 91.238669287002
}
```
Here is an example with the recent
`bert-large-uncased-whole-word-masking`
:
```
bash
python
-m
torch.distributed.launch
--nproc_per_node
=
8
\
run_squad.py
\
--bert_model
bert-large-uncased-whole-word-masking
\
--do_train
\
--do_predict
\
--do_lower_case
\
--train_file
$SQUAD_DIR
/train-v1.1.json
\
--predict_file
$SQUAD_DIR
/dev-v1.1.json
\
--train_batch_size
12
\
--learning_rate
3e-5
\
--num_train_epochs
2.0
\
--max_seq_length
384
\
--doc_stride
128
\
--output_dir
/tmp/debug_squad/
```
## Notebooks
We include
[
three Jupyter Notebooks
](
https://github.com/huggingface/pytorch-pretrained-BERT/tree/master/notebooks
)
that can be used to check that the predictions of the PyTorch model are identical to the predictions of the original TensorFlow model.
...
...
examples/bertology.py
0 → 100644
View file @
382e2d1e
#!/usr/bin/env python3
import
argparse
import
logging
from
tqdm
import
trange
import
torch
import
torch.nn.functional
as
F
import
numpy
as
np
from
pytorch_pretrained_bert
import
BertModel
,
BertTokenizer
logging
.
basicConfig
(
format
=
'%(asctime)s - %(levelname)s - %(name)s - %(message)s'
,
datefmt
=
'%m/%d/%Y %H:%M:%S'
,
level
=
logging
.
INFO
)
logger
=
logging
.
getLogger
(
__name__
)
def
run_model
():
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--model_name_or_path'
,
type
=
str
,
default
=
'bert-base-uncased'
,
help
=
'pretrained model name or path to local checkpoint'
)
parser
.
add_argument
(
"--seed"
,
type
=
int
,
default
=
42
)
parser
.
add_argument
(
"--batch_size"
,
type
=
int
,
default
=-
1
)
parser
.
add_argument
(
'--unconditional'
,
action
=
'store_true'
,
help
=
'If true, unconditional generation.'
)
args
=
parser
.
parse_args
()
print
(
args
)
if
args
.
batch_size
==
-
1
:
args
.
batch_size
=
1
assert
args
.
nsamples
%
args
.
batch_size
==
0
np
.
random
.
seed
(
args
.
seed
)
torch
.
random
.
manual_seed
(
args
.
seed
)
torch
.
cuda
.
manual_seed
(
args
.
seed
)
device
=
torch
.
device
(
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
)
enc
=
GPT2Tokenizer
.
from_pretrained
(
args
.
model_name_or_path
)
model
=
GPT2LMHeadModel
.
from_pretrained
(
args
.
model_name_or_path
)
model
.
to
(
device
)
model
.
eval
()
if
args
.
length
==
-
1
:
args
.
length
=
model
.
config
.
n_ctx
//
2
elif
args
.
length
>
model
.
config
.
n_ctx
:
raise
ValueError
(
"Can't get samples longer than window size: %s"
%
model
.
config
.
n_ctx
)
while
True
:
context_tokens
=
[]
if
not
args
.
unconditional
:
raw_text
=
input
(
"Model prompt >>> "
)
while
not
raw_text
:
print
(
'Prompt should not be empty!'
)
raw_text
=
input
(
"Model prompt >>> "
)
context_tokens
=
enc
.
encode
(
raw_text
)
generated
=
0
for
_
in
range
(
args
.
nsamples
//
args
.
batch_size
):
out
=
sample_sequence
(
model
=
model
,
length
=
args
.
length
,
context
=
context_tokens
,
start_token
=
None
,
batch_size
=
args
.
batch_size
,
temperature
=
args
.
temperature
,
top_k
=
args
.
top_k
,
device
=
device
)
out
=
out
[:,
len
(
context_tokens
):].
tolist
()
for
i
in
range
(
args
.
batch_size
):
generated
+=
1
text
=
enc
.
decode
(
out
[
i
])
print
(
"="
*
40
+
" SAMPLE "
+
str
(
generated
)
+
" "
+
"="
*
40
)
print
(
text
)
print
(
"="
*
80
)
else
:
generated
=
0
for
_
in
range
(
args
.
nsamples
//
args
.
batch_size
):
out
=
sample_sequence
(
model
=
model
,
length
=
args
.
length
,
context
=
None
,
start_token
=
enc
.
encoder
[
'<|endoftext|>'
],
batch_size
=
args
.
batch_size
,
temperature
=
args
.
temperature
,
top_k
=
args
.
top_k
,
device
=
device
)
out
=
out
[:,
1
:].
tolist
()
for
i
in
range
(
args
.
batch_size
):
generated
+=
1
text
=
enc
.
decode
(
out
[
i
])
print
(
"="
*
40
+
" SAMPLE "
+
str
(
generated
)
+
" "
+
"="
*
40
)
print
(
text
)
print
(
"="
*
80
)
if
__name__
==
'__main__'
:
run_model
()
pytorch_pretrained_bert/modeling.py
View file @
382e2d1e
...
...
@@ -22,9 +22,6 @@ import json
import
logging
import
math
import
os
import
shutil
import
tarfile
import
tempfile
import
sys
from
io
import
open
...
...
@@ -37,16 +34,28 @@ from .file_utils import cached_path, WEIGHTS_NAME, CONFIG_NAME
logger
=
logging
.
getLogger
(
__name__
)
PRETRAINED_MODEL_ARCHIVE_MAP
=
{
'bert-base-uncased'
:
"https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz"
,
'bert-large-uncased'
:
"https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased.tar.gz"
,
'bert-base-cased'
:
"https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased.tar.gz"
,
'bert-large-cased'
:
"https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased.tar.gz"
,
'bert-base-multilingual-uncased'
:
"https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased.tar.gz"
,
'bert-base-multilingual-cased'
:
"https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz"
,
'bert-base-chinese'
:
"https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz"
,
'bert-base-german-cased'
:
"https://int-deepset-models-bert.s3.eu-central-1.amazonaws.com/pytorch/bert-base-german-cased.tar.gz"
,
'bert-large-uncased-whole-word-masking'
:
"https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking.tar.gz"
,
'bert-large-cased-whole-word-masking'
:
"https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking.tar.gz"
,
'bert-base-uncased'
:
"https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-pytorch_model.bin"
,
'bert-large-uncased'
:
"https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-pytorch_model.bin"
,
'bert-base-cased'
:
"https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-pytorch_model.bin"
,
'bert-large-cased'
:
"https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-pytorch_model.bin"
,
'bert-base-multilingual-uncased'
:
"https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-pytorch_model.bin"
,
'bert-base-multilingual-cased'
:
"https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-pytorch_model.bin"
,
'bert-base-chinese'
:
"https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-pytorch_model.bin"
,
'bert-base-german-cased'
:
"https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-pytorch_model.bin"
,
'bert-large-uncased-whole-word-masking'
:
"https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-pytorch_model.bin"
,
'bert-large-cased-whole-word-masking'
:
"https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-pytorch_model.bin"
,
}
PRETRAINED_CONFIG_ARCHIVE_MAP
=
{
'bert-base-uncased'
:
"https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json"
,
'bert-large-uncased'
:
"https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-config.json"
,
'bert-base-cased'
:
"https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json"
,
'bert-large-cased'
:
"https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-config.json"
,
'bert-base-multilingual-uncased'
:
"https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-config.json"
,
'bert-base-multilingual-cased'
:
"https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-config.json"
,
'bert-base-chinese'
:
"https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-config.json"
,
'bert-base-german-cased'
:
"https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-config.json"
,
'bert-large-uncased-whole-word-masking'
:
"https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-config.json"
,
'bert-large-cased-whole-word-masking'
:
"https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-config.json"
,
}
BERT_CONFIG_NAME
=
'bert_config.json'
TF_WEIGHTS_NAME
=
'model.ckpt'
...
...
@@ -642,11 +651,14 @@ class BertPreTrainedModel(nn.Module):
if
pretrained_model_name_or_path
in
PRETRAINED_MODEL_ARCHIVE_MAP
:
archive_file
=
PRETRAINED_MODEL_ARCHIVE_MAP
[
pretrained_model_name_or_path
]
config_file
=
PRETRAINED_CONFIG_ARCHIVE_MAP
[
pretrained_model_name_or_path
]
else
:
archive_file
=
pretrained_model_name_or_path
archive_file
=
os
.
path
.
join
(
pretrained_model_name_or_path
,
WEIGHTS_NAME
)
config_file
=
os
.
path
.
join
(
pretrained_model_name_or_path
,
CONFIG_NAME
)
# redirect to the cache, if necessary
try
:
resolved_archive_file
=
cached_path
(
archive_file
,
cache_dir
=
cache_dir
)
resolved_config_file
=
cached_path
(
config_file
,
cache_dir
=
cache_dir
)
except
EnvironmentError
:
if
pretrained_model_name_or_path
in
PRETRAINED_MODEL_ARCHIVE_MAP
:
logger
.
error
(
...
...
@@ -661,22 +673,26 @@ class BertPreTrainedModel(nn.Module):
', '
.
join
(
PRETRAINED_MODEL_ARCHIVE_MAP
.
keys
()),
archive_file
))
return
None
if
resolved_archive_file
==
archive_file
:
logger
.
info
(
"loading archive file {}"
.
format
(
archive_file
))
if
resolved_archive_file
==
archive_file
and
resolved_config_file
==
config_file
:
logger
.
info
(
"loading weights file {}"
.
format
(
archive_file
))
logger
.
info
(
"loading configuration file {}"
.
format
(
config_file
))
else
:
logger
.
info
(
"loading
archive
file {} from cache at {}"
.
format
(
logger
.
info
(
"loading
weights
file {} from cache at {}"
.
format
(
archive_file
,
resolved_archive_file
))
tempdir
=
None
if
os
.
path
.
isdir
(
resolved_archive_file
)
or
from_tf
:
serialization_dir
=
resolved_archive_file
else
:
# Extract archive to temp dir
tempdir
=
tempfile
.
mkdtemp
()
logger
.
info
(
"extracting archive file {} to temp dir {}"
.
format
(
resolved_archive_file
,
tempdir
))
with
tarfile
.
open
(
resolved_archive_file
,
'r:gz'
)
as
archive
:
archive
.
extractall
(
tempdir
)
serialization_dir
=
tempdir
logger
.
info
(
"loading configuration file {} from cache at {}"
.
format
(
config_file
,
resolved_config_file
))
### Switching to split config/weight files configuration
# tempdir = None
# if os.path.isdir(resolved_archive_file) or from_tf:
# serialization_dir = resolved_archive_file
# else:
# # Extract archive to temp dir
# tempdir = tempfile.mkdtemp()
# logger.info("extracting archive file {} to temp dir {}".format(
# resolved_archive_file, tempdir))
# with tarfile.open(resolved_archive_file, 'r:gz') as archive:
# archive.extractall(tempdir)
# serialization_dir = tempdir
# Load config
config_file
=
os
.
path
.
join
(
serialization_dir
,
CONFIG_NAME
)
if
not
os
.
path
.
exists
(
config_file
):
...
...
@@ -689,9 +705,9 @@ class BertPreTrainedModel(nn.Module):
if
state_dict
is
None
and
not
from_tf
:
weights_path
=
os
.
path
.
join
(
serialization_dir
,
WEIGHTS_NAME
)
state_dict
=
torch
.
load
(
weights_path
,
map_location
=
'cpu'
)
if
tempdir
:
# Clean up temp dir
shutil
.
rmtree
(
tempdir
)
#
if tempdir:
#
# Clean up temp dir
#
shutil.rmtree(tempdir)
if
from_tf
:
# Directly load from a TensorFlow checkpoint
weights_path
=
os
.
path
.
join
(
serialization_dir
,
TF_WEIGHTS_NAME
)
...
...
pytorch_pretrained_bert/modeling_gpt2.py
View file @
382e2d1e
...
...
@@ -23,9 +23,6 @@ import json
import
logging
import
math
import
os
import
shutil
import
tarfile
import
tempfile
import
sys
from
io
import
open
...
...
pytorch_pretrained_bert/modeling_openai.py
View file @
382e2d1e
...
...
@@ -23,9 +23,6 @@ import json
import
logging
import
math
import
os
import
shutil
import
tarfile
import
tempfile
import
sys
from
io
import
open
...
...
pytorch_pretrained_bert/modeling_transfo_xl.py
View file @
382e2d1e
...
...
@@ -25,9 +25,6 @@ import copy
import
json
import
math
import
logging
import
tarfile
import
tempfile
import
shutil
import
collections
import
sys
from
io
import
open
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment