Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
6a083fd4
"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "dbd041243d3c6827ae3dfad5e53619ee7838853c"
Commit
6a083fd4
authored
Sep 18, 2019
by
thomwolf
Browse files
update pt-tf conversion script
parent
f6969cc1
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
24 additions
and
4 deletions
+24
-4
pytorch_transformers/convert_pytorch_checkpoint_to_tf2.py
pytorch_transformers/convert_pytorch_checkpoint_to_tf2.py
+8
-4
pytorch_transformers/modeling_tf_pytorch_utils.py
pytorch_transformers/modeling_tf_pytorch_utils.py
+16
-0
No files found.
pytorch_transformers/convert_pytorch_checkpoint_to_tf2.py
View file @
6a083fd4
...
...
@@ -102,7 +102,7 @@ def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file
tf_model
.
save_weights
(
tf_dump_path
,
save_format
=
'h5'
)
def
convert_all_pt_checkpoints_to_tf
(
args_model_type
,
tf_dump_path
,
compare_with_pt_model
=
False
):
def
convert_all_pt_checkpoints_to_tf
(
args_model_type
,
tf_dump_path
,
compare_with_pt_model
=
False
,
use_cached_models
=
False
):
assert
os
.
path
.
isdir
(
args
.
tf_dump_path
),
"--tf_dump_path should be a directory"
if
args_model_type
is
None
:
...
...
@@ -126,8 +126,8 @@ def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, compare_with
if
'finetuned'
in
shortcut_name
:
print
(
" Skipping finetuned checkpoint "
)
continue
config_file
=
cached_path
(
aws_config_map
[
shortcut_name
],
force_download
=
True
)
model_file
=
cached_path
(
aws_model_maps
[
shortcut_name
],
force_download
=
True
)
config_file
=
cached_path
(
aws_config_map
[
shortcut_name
],
force_download
=
not
use_cached_models
)
model_file
=
cached_path
(
aws_model_maps
[
shortcut_name
],
force_download
=
not
use_cached_models
)
convert_pt_checkpoint_to_tf
(
model_type
,
model_file
,
...
...
@@ -165,6 +165,9 @@ if __name__ == "__main__":
parser
.
add_argument
(
"--compare_with_pt_model"
,
action
=
'store_true'
,
help
=
"Compare Tensorflow and PyTorch model predictions."
)
parser
.
add_argument
(
"--use_cached_models"
,
action
=
'store_true'
,
help
=
"Use cached models if possible instead of updating to latest checkpoint versions."
)
args
=
parser
.
parse_args
()
if
args
.
pytorch_checkpoint_path
is
not
None
:
...
...
@@ -176,4 +179,5 @@ if __name__ == "__main__":
else
:
convert_all_pt_checkpoints_to_tf
(
args
.
model_type
.
lower
()
if
args
.
model_type
is
not
None
else
None
,
args
.
tf_dump_path
,
compare_with_pt_model
=
args
.
compare_with_pt_model
)
compare_with_pt_model
=
args
.
compare_with_pt_model
,
use_cached_models
=
args
.
use_cached_models
)
pytorch_transformers/modeling_tf_pytorch_utils.py
View file @
6a083fd4
...
...
@@ -78,6 +78,12 @@ def load_pytorch_state_dict_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None
for
old_key
,
new_key
in
zip
(
old_keys
,
new_keys
):
pt_state_dict
[
new_key
]
=
pt_state_dict
.
pop
(
old_key
)
# Make sure we are able to load PyTorch base models as well as derived models (with heads)
# TF models always have a prefix; some PyTorch models (the base ones) don't
start_prefix_to_remove
=
''
if
not
any
(
s
.
startswith
(
tf_model
.
base_model_prefix
)
for
s
in
pt_state_dict
.
keys
()):
start_prefix_to_remove
=
tf_model
.
base_model_prefix
+
'.'
symbolic_weights
=
tf_model
.
trainable_weights
+
tf_model
.
non_trainable_weights
weight_value_tuples
=
[]
...
...
@@ -100,13 +106,23 @@ def load_pytorch_state_dict_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None
if
name
[
-
1
]
==
'beta'
:
name
[
-
1
]
=
'bias'
# Remove prefix if needed
name
=
'.'
.
join
(
name
)
if
start_prefix_to_remove
:
name
=
name
.
replace
(
start_prefix_to_remove
,
''
,
1
)
# Find associated numpy array in pytorch model state dict
assert
name
in
pt_state_dict
,
"{} not found in PyTorch model"
.
format
(
name
)
array
=
pt_state_dict
[
name
].
numpy
()
if
transpose
:
array
=
numpy
.
transpose
(
array
)
if
len
(
symbolic_weight
.
shape
)
<
len
(
array
.
shape
):
array
=
numpy
.
squeeze
(
array
)
elif
len
(
symbolic_weight
.
shape
)
>
len
(
array
.
shape
):
array
=
numpy
.
expand_dims
(
array
,
axis
=
0
)
try
:
assert
list
(
symbolic_weight
.
shape
)
==
list
(
array
.
shape
)
except
AssertionError
as
e
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment