chenpangpang / transformers · Commits

Commit a84adddd, authored Sep 12, 2019 by thomwolf

convert all models

Parent: 969d3ae9
Showing 2 changed files with 1184 additions and 27 deletions (+1184 −27):

pytorch_transformers/convert_pytorch_checkpoint_to_tf2.py   +76 −27
pytorch_transformers/modeling_tf_transfo_xl.py              +1108 −0
pytorch_transformers/convert_pytorch_checkpoint_to_tf2.py @ a84adddd
@@ -18,10 +18,11 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import os
 import argparse
 import tensorflow as tf
 
-from pytorch_transformers import is_torch_available
+from pytorch_transformers import is_torch_available, cached_path
 
 from pytorch_transformers import (BertConfig, TFBertForPreTraining, load_bert_pt_weights_in_tf2,
                                   GPT2Config, TFGPT2LMHeadModel, load_gpt2_pt_weights_in_tf2,
@@ -31,26 +32,36 @@ from pytorch_transformers import (BertConfig, TFBertForPreTraining, load_bert_pt
 if is_torch_available():
     import torch
     import numpy as np
-    from pytorch_transformers import BertForPreTraining, GPT2LMHeadModel, XLNetLMHeadModel, XLMWithLMHeadModel
+    from pytorch_transformers import (BertForPreTraining, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
+                                      GPT2LMHeadModel, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,
+                                      XLNetLMHeadModel, XLNET_PRETRAINED_MODEL_ARCHIVE_MAP, XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
+                                      XLMWithLMHeadModel, XLM_PRETRAINED_MODEL_ARCHIVE_MAP, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,)
 else:
-    BertForPreTraining, GPT2LMHeadModel = None, None
+    (BertForPreTraining, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
+     GPT2LMHeadModel, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,
+     XLNetLMHeadModel, XLNET_PRETRAINED_MODEL_ARCHIVE_MAP, XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
+     XLMWithLMHeadModel, XLM_PRETRAINED_MODEL_ARCHIVE_MAP, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,) = (None, None, None, None, None, None, None, None, None, None, None, None,)
 
 import logging
 logging.basicConfig(level=logging.INFO)
 
 MODEL_CLASSES = {
-    'bert': (BertConfig, TFBertForPreTraining, load_bert_pt_weights_in_tf2, BertForPreTraining),
-    'gpt2': (GPT2Config, TFGPT2LMHeadModel, load_gpt2_pt_weights_in_tf2, GPT2LMHeadModel),
-    'xlnet': (XLNetConfig, TFXLNetLMHeadModel, load_xlnet_pt_weights_in_tf2, XLNetLMHeadModel),
-    'xlm': (XLMConfig, TFXLMWithLMHeadModel, load_xlm_pt_weights_in_tf2, XLMWithLMHeadModel),
+    'bert': (BertConfig, TFBertForPreTraining, load_bert_pt_weights_in_tf2, BertForPreTraining,
+             BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
+    'gpt2': (GPT2Config, TFGPT2LMHeadModel, load_gpt2_pt_weights_in_tf2, GPT2LMHeadModel,
+             GPT2_PRETRAINED_MODEL_ARCHIVE_MAP, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP),
+    'xlnet': (XLNetConfig, TFXLNetLMHeadModel, load_xlnet_pt_weights_in_tf2, XLNetLMHeadModel,
+              XLNET_PRETRAINED_MODEL_ARCHIVE_MAP, XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP),
+    'xlm': (XLMConfig, TFXLMWithLMHeadModel, load_xlm_pt_weights_in_tf2, XLMWithLMHeadModel,
+            XLM_PRETRAINED_MODEL_ARCHIVE_MAP, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP),
 }
 
 
 def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file, tf_dump_path, compare_with_pt_model=False):
     if model_type not in MODEL_CLASSES:
         raise ValueError("Unrecognized model type, should be one of {}.".format(list(MODEL_CLASSES.keys())))
 
-    config_class, model_class, loading_fct, pt_model_class = MODEL_CLASSES[model_type]
+    config_class, model_class, loading_fct, pt_model_class, aws_model_maps, aws_config_map = MODEL_CLASSES[model_type]
 
     # Initialise TF model
     config = config_class.from_json_file(config_file)
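
For context, the *_PRETRAINED_*_ARCHIVE_MAP dicts added to each MODEL_CLASSES entry map shortcut names to download URLs, and cached_path resolves such a URL to a local file. A minimal sketch of that mechanism, outside the commit itself ('bert-base-uncased' is a real BERT shortcut name, but treat the exact keys and cache behaviour as assumptions of this sketch):

    # Sketch only, not part of the diff: resolve one published BERT config
    # through the same archive-map + cached_path route the converter uses.
    from pytorch_transformers import cached_path, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP

    config_url = BERT_PRETRAINED_CONFIG_ARCHIVE_MAP['bert-base-uncased']  # shortcut name -> URL
    local_config = cached_path(config_url)  # downloads on first use, then reuses the cache
    print(local_config)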
@@ -68,8 +79,8 @@ def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file
     tfo = tf_model(tf_inputs, training=False)  # build the network
 
     pt_model = pt_model_class.from_pretrained(None,
                                               config=config,
                                               state_dict=torch.load(pytorch_checkpoint_path,
                                                                     map_location='cpu'))
     pt_inputs = torch.tensor(inputs_list)
     with torch.no_grad():
@@ -79,42 +90,80 @@ def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file
         np_tf = tfo[0].numpy()
         diff = np.amax(np.abs(np_pt - np_tf))
         print("Max absolute difference between models outputs {}".format(diff))
+        assert diff <= 1e-3, "Error, model absolute difference is >1e-3"
 
     # Save pytorch-model
     print("Save TensorFlow model to {}".format(tf_dump_path))
-    tf_model.save_weights(tf_dump_path)
+    tf_model.save_weights(tf_dump_path, save_format='h5')
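
The new assert turns the previously informational comparison into a hard check. A self-contained sketch of exactly that check, with made-up arrays standing in for the real PyTorch (np_pt) and TensorFlow (np_tf) outputs:

    # Sketch only: the max-absolute-difference test the converter now enforces.
    import numpy as np

    np_pt = np.array([0.10, 0.20, 0.30])    # pretend PyTorch output
    np_tf = np.array([0.10, 0.20, 0.3004])  # pretend TensorFlow output
    diff = np.amax(np.abs(np_pt - np_tf))   # largest element-wise deviation
    assert diff <= 1e-3, "Error, model absolute difference is >1e-3"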
+
+def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, compare_with_pt_model=False):
+    assert os.path.isdir(tf_dump_path), "--tf_dump_path should be a directory"
+
+    if args_model_type is None:
+        model_types = list(MODEL_CLASSES.keys())
+    else:
+        model_types = [args_model_type]
+
+    for j, model_type in enumerate(model_types, start=1):
+        print("=" * 100)
+        print(" Converting model type {}/{}: {}".format(j, len(model_types), model_type))
+        print("=" * 100)
+        if model_type not in MODEL_CLASSES:
+            raise ValueError("Unrecognized model type {}, should be one of {}.".format(
+                model_type, list(MODEL_CLASSES.keys())))
+        config_class, model_class, loading_fct, pt_model_class, aws_model_maps, aws_config_map = MODEL_CLASSES[model_type]
+
+        for i, shortcut_name in enumerate(aws_config_map.keys(), start=1):
+            print("-" * 100)
+            print("    Converting checkpoint {}/{}: {}".format(i, len(aws_config_map), shortcut_name))
+            print("-" * 100)
+            config_file = cached_path(aws_config_map[shortcut_name], force_download=True)
+            model_file = cached_path(aws_model_maps[shortcut_name], force_download=True)
+
+            convert_pt_checkpoint_to_tf(model_type, model_file, config_file,
+                                        os.path.join(tf_dump_path, shortcut_name + '-tf_model.h5'),
+                                        compare_with_pt_model=compare_with_pt_model)
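
Each converted checkpoint lands at <tf_dump_path>/<shortcut_name>-tf_model.h5. A one-liner showing the path the inner loop builds, with an assumed dump directory and a real BERT shortcut name:

    # Sketch only: the output filename scheme used by the loop above.
    import os

    tf_dump_path = './tf_dump'            # assumed example directory
    shortcut_name = 'bert-base-uncased'   # real BERT shortcut name
    print(os.path.join(tf_dump_path, shortcut_name + '-tf_model.h5'))
    # -> ./tf_dump/bert-base-uncased-tf_model.h5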
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     ## Required parameters
-    parser.add_argument("--model_type",
+    parser.add_argument("--tf_dump_path",
                         default=None,
                         type=str,
                         required=True,
-                        help="Model type selcted in the list of {}.".format(list(MODEL_CLASSES.keys())))
+                        help="Path to the output Tensorflow dump file.")
+    parser.add_argument("--model_type",
+                        default=None,
+                        type=str,
+                        help="Model type selected in the list of {}. If not given, will download and convert all the models from AWS.".format(list(MODEL_CLASSES.keys())))
     parser.add_argument("--pytorch_checkpoint_path",
                         default=None,
                         type=str,
-                        required=True,
-                        help="Path to the PyTorch checkpoint path.")
+                        help="Path to the PyTorch checkpoint path or shortcut name to download from AWS. "
+                             "If not given, will download and convert all the checkpoints from AWS.")
     parser.add_argument("--config_file",
                         default=None,
                         type=str,
-                        required=True,
-                        help="The config json file corresponding to the pre-trained model. \n"
-                             "This specifies the model architecture.")
-    parser.add_argument("--tf_dump_path",
-                        default=None,
-                        type=str,
-                        required=True,
-                        help="Path to the output Tensorflow dump file.")
+                        help="The config json file corresponding to the pre-trained model. \n"
+                             "This specifies the model architecture. If not given and "
+                             "--pytorch_checkpoint_path is not given or is a shortcut name, "
+                             "use the configuration associated to the shortcut name on the AWS.")
     parser.add_argument("--compare_with_pt_model",
                         action='store_true',
                         help="Compare Tensorflow and PyTorch model predictions.")
     args = parser.parse_args()
 
-    convert_pt_checkpoint_to_tf(args.model_type.lower(),
-                                args.pytorch_checkpoint_path,
-                                args.config_file,
-                                args.tf_dump_path,
-                                compare_with_pt_model=args.compare_with_pt_model)
+    if args.pytorch_checkpoint_path is not None:
+        convert_pt_checkpoint_to_tf(args.model_type.lower(),
+                                    args.pytorch_checkpoint_path,
+                                    args.config_file,
+                                    args.tf_dump_path,
+                                    compare_with_pt_model=args.compare_with_pt_model)
+    else:
+        convert_all_pt_checkpoints_to_tf(args.model_type.lower() if args.model_type is not None else None,
+                                         args.tf_dump_path,
+                                         compare_with_pt_model=args.compare_with_pt_model)
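
After this change the script supports two entry points; hypothetical invocations (the paths are made up, the flag names come straight from the argparse definitions above):

    # Convert a single local checkpoint, checking TF vs. PyTorch outputs:
    python pytorch_transformers/convert_pytorch_checkpoint_to_tf2.py \
        --model_type bert \
        --pytorch_checkpoint_path ./bert/pytorch_model.bin \
        --config_file ./bert/config.json \
        --tf_dump_path ./bert/tf_model.h5 \
        --compare_with_pt_model

    # Download and convert every published checkpoint (the new code path);
    # note that --tf_dump_path must be an existing directory here:
    python pytorch_transformers/convert_pytorch_checkpoint_to_tf2.py \
        --tf_dump_path ./all_tf_models/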
pytorch_transformers/modeling_tf_transfo_xl.py @ a84adddd (new file, 0 → 100644)

This diff is collapsed (+1108 lines).