Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
1d87b37d
Commit
1d87b37d
authored
Dec 06, 2019
by
thomwolf
Browse files
updating
parent
f8fb4335
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
13 additions
and
7 deletions
+13
-7
transformers/convert_pytorch_checkpoint_to_tf2.py
transformers/convert_pytorch_checkpoint_to_tf2.py
+13
-7
No files found.
transformers/convert_pytorch_checkpoint_to_tf2.py
View file @
1d87b37d
...
@@ -119,10 +119,10 @@ def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file
...
@@ -119,10 +119,10 @@ def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file
tf_inputs
=
tf
.
constant
(
inputs_list
)
tf_inputs
=
tf
.
constant
(
inputs_list
)
tfo
=
tf_model
(
tf_inputs
,
training
=
False
)
# build the network
tfo
=
tf_model
(
tf_inputs
,
training
=
False
)
# build the network
pt_model
=
pt_model_class
(
config
)
state_dict
=
torch
.
load
(
pytorch_checkpoint_path
,
map_location
=
'cpu'
)
pt_model
.
load_state_dict
(
torch
.
load
(
pytorch_checkpoint_path
,
map_location
=
'cpu'
)
,
pt_model
=
pt_model_class
.
from_pretrained
(
pretrained_model_name_or_path
=
None
,
strict
-
False
)
config
=
config
,
pt_model
.
eval
(
)
state_dict
=
state_dict
)
pt_inputs
=
torch
.
tensor
(
inputs_list
)
pt_inputs
=
torch
.
tensor
(
inputs_list
)
with
torch
.
no_grad
():
with
torch
.
no_grad
():
...
@@ -140,7 +140,7 @@ def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file
...
@@ -140,7 +140,7 @@ def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file
def
convert_all_pt_checkpoints_to_tf
(
args_model_type
,
tf_dump_path
,
model_shortcut_names_or_path
=
None
,
config_shortcut_names_or_path
=
None
,
def
convert_all_pt_checkpoints_to_tf
(
args_model_type
,
tf_dump_path
,
model_shortcut_names_or_path
=
None
,
config_shortcut_names_or_path
=
None
,
compare_with_pt_model
=
False
,
use_cached_models
=
False
,
only_convert_finetuned_models
=
False
):
compare_with_pt_model
=
False
,
use_cached_models
=
False
,
remove_cached_files
=
False
,
only_convert_finetuned_models
=
False
):
assert
os
.
path
.
isdir
(
args
.
tf_dump_path
),
"--tf_dump_path should be a directory"
assert
os
.
path
.
isdir
(
args
.
tf_dump_path
),
"--tf_dump_path should be a directory"
if
args_model_type
is
None
:
if
args_model_type
is
None
:
...
@@ -188,11 +188,13 @@ def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, model_shortc
...
@@ -188,11 +188,13 @@ def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, model_shortc
if
os
.
path
.
isfile
(
model_shortcut_name
):
if
os
.
path
.
isfile
(
model_shortcut_name
):
model_shortcut_name
=
'converted_model'
model_shortcut_name
=
'converted_model'
convert_pt_checkpoint_to_tf
(
model_type
=
model_type
,
convert_pt_checkpoint_to_tf
(
model_type
=
model_type
,
pytorch_checkpoint_path
=
model_file
,
pytorch_checkpoint_path
=
model_file
,
config_file
=
config_file
,
config_file
=
config_file
,
tf_dump_path
=
os
.
path
.
join
(
tf_dump_path
,
model_shortcut_name
+
'-tf_model.h5'
),
tf_dump_path
=
os
.
path
.
join
(
tf_dump_path
,
model_shortcut_name
+
'-tf_model.h5'
),
compare_with_pt_model
=
compare_with_pt_model
)
compare_with_pt_model
=
compare_with_pt_model
)
if
remove_cached_files
:
os
.
remove
(
config_file
)
os
.
remove
(
config_file
)
os
.
remove
(
model_file
)
os
.
remove
(
model_file
)
...
@@ -227,6 +229,9 @@ if __name__ == "__main__":
...
@@ -227,6 +229,9 @@ if __name__ == "__main__":
parser
.
add_argument
(
"--use_cached_models"
,
parser
.
add_argument
(
"--use_cached_models"
,
action
=
'store_true'
,
action
=
'store_true'
,
help
=
"Use cached models if possible instead of updating to latest checkpoint versions."
)
help
=
"Use cached models if possible instead of updating to latest checkpoint versions."
)
parser
.
add_argument
(
"--remove_cached_files"
,
action
=
'store_true'
,
help
=
"Remove pytorch models after conversion (save memory when converting in batches)."
)
parser
.
add_argument
(
"--only_convert_finetuned_models"
,
parser
.
add_argument
(
"--only_convert_finetuned_models"
,
action
=
'store_true'
,
action
=
'store_true'
,
help
=
"Only convert finetuned models."
)
help
=
"Only convert finetuned models."
)
...
@@ -246,4 +251,5 @@ if __name__ == "__main__":
...
@@ -246,4 +251,5 @@ if __name__ == "__main__":
config_shortcut_names_or_path
=
[
args
.
config_file
]
if
args
.
config_file
is
not
None
else
None
,
config_shortcut_names_or_path
=
[
args
.
config_file
]
if
args
.
config_file
is
not
None
else
None
,
compare_with_pt_model
=
args
.
compare_with_pt_model
,
compare_with_pt_model
=
args
.
compare_with_pt_model
,
use_cached_models
=
args
.
use_cached_models
,
use_cached_models
=
args
.
use_cached_models
,
remove_cached_files
=
args
.
remove_cached_files
,
only_convert_finetuned_models
=
args
.
only_convert_finetuned_models
)
only_convert_finetuned_models
=
args
.
only_convert_finetuned_models
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment