Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
a84adddd
Commit
a84adddd
authored
Sep 12, 2019
by
thomwolf
Browse files
convert all models
parent
969d3ae9
Changes
2
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
1184 additions
and
27 deletions
+1184
-27
pytorch_transformers/convert_pytorch_checkpoint_to_tf2.py
pytorch_transformers/convert_pytorch_checkpoint_to_tf2.py
+76
-27
pytorch_transformers/modeling_tf_transfo_xl.py
pytorch_transformers/modeling_tf_transfo_xl.py
+1108
-0
No files found.
pytorch_transformers/convert_pytorch_checkpoint_to_tf2.py
View file @
a84adddd
...
...
@@ -18,10 +18,11 @@ from __future__ import absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
os
import
argparse
import
tensorflow
as
tf
from
pytorch_transformers
import
is_torch_available
from
pytorch_transformers
import
is_torch_available
,
cached_path
from
pytorch_transformers
import
(
BertConfig
,
TFBertForPreTraining
,
load_bert_pt_weights_in_tf2
,
GPT2Config
,
TFGPT2LMHeadModel
,
load_gpt2_pt_weights_in_tf2
,
...
...
@@ -31,26 +32,36 @@ from pytorch_transformers import (BertConfig, TFBertForPreTraining, load_bert_pt
if
is_torch_available
():
import
torch
import
numpy
as
np
from
pytorch_transformers
import
BertForPreTraining
,
GPT2LMHeadModel
,
XLNetLMHeadModel
,
XLMWithLMHeadModel
from
pytorch_transformers
import
(
BertForPreTraining
,
BERT_PRETRAINED_MODEL_ARCHIVE_MAP
,
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP
,
GPT2LMHeadModel
,
GPT2_PRETRAINED_MODEL_ARCHIVE_MAP
,
GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP
,
XLNetLMHeadModel
,
XLNET_PRETRAINED_MODEL_ARCHIVE_MAP
,
XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP
,
XLMWithLMHeadModel
,
XLM_PRETRAINED_MODEL_ARCHIVE_MAP
,
XLM_PRETRAINED_CONFIG_ARCHIVE_MAP
,)
else
:
BertForPreTraining
,
GPT2LMHeadModel
=
None
,
None
(
BertForPreTraining
,
BERT_PRETRAINED_MODEL_ARCHIVE_MAP
,
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP
,
GPT2LMHeadModel
,
GPT2_PRETRAINED_MODEL_ARCHIVE_MAP
,
GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP
,
XLNetLMHeadModel
,
XLNET_PRETRAINED_MODEL_ARCHIVE_MAP
,
XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP
,
XLMWithLMHeadModel
,
XLM_PRETRAINED_MODEL_ARCHIVE_MAP
,
XLM_PRETRAINED_CONFIG_ARCHIVE_MAP
,)
=
(
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,)
import
logging
logging
.
basicConfig
(
level
=
logging
.
INFO
)
MODEL_CLASSES
=
{
'bert'
:
(
BertConfig
,
TFBertForPreTraining
,
load_bert_pt_weights_in_tf2
,
BertForPreTraining
),
'gpt2'
:
(
GPT2Config
,
TFGPT2LMHeadModel
,
load_gpt2_pt_weights_in_tf2
,
GPT2LMHeadModel
),
'xlnet'
:
(
XLNetConfig
,
TFXLNetLMHeadModel
,
load_xlnet_pt_weights_in_tf2
,
XLNetLMHeadModel
),
'xlm'
:
(
XLMConfig
,
TFXLMWithLMHeadModel
,
load_xlm_pt_weights_in_tf2
,
XLMWithLMHeadModel
),
'bert'
:
(
BertConfig
,
TFBertForPreTraining
,
load_bert_pt_weights_in_tf2
,
BertForPreTraining
,
BERT_PRETRAINED_MODEL_ARCHIVE_MAP
,
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP
),
'gpt2'
:
(
GPT2Config
,
TFGPT2LMHeadModel
,
load_gpt2_pt_weights_in_tf2
,
GPT2LMHeadModel
,
GPT2_PRETRAINED_MODEL_ARCHIVE_MAP
,
GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP
),
'xlnet'
:
(
XLNetConfig
,
TFXLNetLMHeadModel
,
load_xlnet_pt_weights_in_tf2
,
XLNetLMHeadModel
,
XLNET_PRETRAINED_MODEL_ARCHIVE_MAP
,
XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP
),
'xlm'
:
(
XLMConfig
,
TFXLMWithLMHeadModel
,
load_xlm_pt_weights_in_tf2
,
XLMWithLMHeadModel
,
XLM_PRETRAINED_MODEL_ARCHIVE_MAP
,
XLM_PRETRAINED_CONFIG_ARCHIVE_MAP
),
}
def
convert_pt_checkpoint_to_tf
(
model_type
,
pytorch_checkpoint_path
,
config_file
,
tf_dump_path
,
compare_with_pt_model
=
False
):
if
model_type
not
in
MODEL_CLASSES
:
raise
ValueError
(
"Unrecognized model type, should be one of {}."
.
format
(
list
(
MODEL_CLASSES
.
keys
())))
config_class
,
model_class
,
loading_fct
,
pt_model_class
=
MODEL_CLASSES
[
model_type
]
config_class
,
model_class
,
loading_fct
,
pt_model_class
,
aws_model_maps
,
aws_config_map
=
MODEL_CLASSES
[
model_type
]
# Initialise TF model
config
=
config_class
.
from_json_file
(
config_file
)
...
...
@@ -79,42 +90,80 @@ def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file
np_tf
=
tfo
[
0
].
numpy
()
diff
=
np
.
amax
(
np
.
abs
(
np_pt
-
np_tf
))
print
(
"Max absolute difference between models outputs {}"
.
format
(
diff
))
assert
diff
<=
1e-3
,
"Error, model absolute difference is >1e-3"
# Save pytorch-model
print
(
"Save TensorFlow model to {}"
.
format
(
tf_dump_path
))
tf_model
.
save_weights
(
tf_dump_path
)
tf_model
.
save_weights
(
tf_dump_path
,
save_format
=
'h5'
)
def
convert_all_pt_checkpoints_to_tf
(
args_model_type
,
tf_dump_path
,
compare_with_pt_model
=
False
):
assert
os
.
path
.
isdir
(
args
.
tf_dump_path
),
"--tf_dump_path should be a directory"
if
args_model_type
is
None
:
model_types
=
list
(
MODEL_CLASSES
.
keys
())
else
:
model_types
=
[
args_model_type
]
for
j
,
model_type
in
enumerate
(
model_types
,
start
=
1
):
print
(
"="
*
100
)
print
(
" Converting model type {}/{}: {}"
.
format
(
j
,
len
(
model_types
),
model_type
))
print
(
"="
*
100
)
if
model_type
not
in
MODEL_CLASSES
:
raise
ValueError
(
"Unrecognized model type {}, should be one of {}."
.
format
(
model_type
,
list
(
MODEL_CLASSES
.
keys
())))
config_class
,
model_class
,
loading_fct
,
pt_model_class
,
aws_model_maps
,
aws_config_map
=
MODEL_CLASSES
[
model_type
]
for
i
,
shortcut_name
in
enumerate
(
aws_config_map
.
keys
(),
start
=
1
):
print
(
"-"
*
100
)
print
(
" Converting checkpoint {}/{}: {}"
.
format
(
i
,
len
(
aws_config_map
),
shortcut_name
))
print
(
"-"
*
100
)
config_file
=
cached_path
(
aws_config_map
[
shortcut_name
],
force_download
=
True
)
model_file
=
cached_path
(
aws_model_maps
[
shortcut_name
],
force_download
=
True
)
convert_pt_checkpoint_to_tf
(
model_type
,
model_file
,
config_file
,
os
.
path
.
join
(
tf_dump_path
,
shortcut_name
+
'-tf_model.h5'
),
compare_with_pt_model
=
compare_with_pt_model
)
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
## Required parameters
parser
.
add_argument
(
"--
model_type
"
,
parser
.
add_argument
(
"--
tf_dump_path
"
,
default
=
None
,
type
=
str
,
required
=
True
,
help
=
"Model type selcted in the list of {}."
.
format
(
list
(
MODEL_CLASSES
.
keys
())))
help
=
"Path to the output Tensorflow dump file."
)
parser
.
add_argument
(
"--model_type"
,
default
=
None
,
type
=
str
,
help
=
"Model type selected in the list of {}. If not given, will download and convert all the models from AWS."
.
format
(
list
(
MODEL_CLASSES
.
keys
())))
parser
.
add_argument
(
"--pytorch_checkpoint_path"
,
default
=
None
,
type
=
str
,
required
=
True
,
help
=
"Path to the PyTorch checkpoint path
."
)
help
=
"Path to the PyTorch checkpoint path or shortcut name to download from AWS. "
"If not given, will download and convert all the checkpoints from AWS
."
)
parser
.
add_argument
(
"--config_file"
,
default
=
None
,
type
=
str
,
required
=
True
,
help
=
"The config json file corresponding to the pre-trained model.
\n
"
"This specifies the model architecture."
)
parser
.
add_argument
(
"--tf_dump_path"
,
default
=
None
,
type
=
str
,
required
=
True
,
help
=
"Path to the output Tensorflow dump file."
)
"This specifies the model architecture. If not given and "
"--pytorch_checkpoint_path is not given or is a shortcut name"
"use the configuration associated to teh shortcut name on the AWS"
)
parser
.
add_argument
(
"--compare_with_pt_model"
,
action
=
'store_true'
,
help
=
"Compare Tensorflow and PyTorch model predictions."
)
args
=
parser
.
parse_args
()
if
args
.
pytorch_checkpoint_path
is
not
None
:
convert_pt_checkpoint_to_tf
(
args
.
model_type
.
lower
(),
args
.
pytorch_checkpoint_path
,
args
.
config_file
,
args
.
tf_dump_path
,
compare_with_pt_model
=
args
.
compare_with_pt_model
)
else
:
convert_all_pt_checkpoints_to_tf
(
args
.
model_type
.
lower
()
if
args
.
model_type
is
not
None
else
None
,
args
.
tf_dump_path
,
compare_with_pt_model
=
args
.
compare_with_pt_model
)
pytorch_transformers/modeling_tf_transfo_xl.py
0 → 100644
View file @
a84adddd
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment