Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
e6cff60b
Unverified
Commit
e6cff60b
authored
Dec 10, 2019
by
Thomas Wolf
Committed by
GitHub
Dec 10, 2019
Browse files
Merge pull request #2069 from huggingface/cleaner-pt-tf-conversion
clean up PT <=> TF conversion
parents
4b82c485
1d87b37d
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
15 additions
and
7 deletions
+15
-7
transformers/convert_pytorch_checkpoint_to_tf2.py
transformers/convert_pytorch_checkpoint_to_tf2.py
+13
-6
transformers/modeling_utils.py
transformers/modeling_utils.py
+2
-1
No files found.
transformers/convert_pytorch_checkpoint_to_tf2.py
View file @
e6cff60b
...
@@ -119,10 +119,11 @@ def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file
...
@@ -119,10 +119,11 @@ def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file
tf_inputs
=
tf
.
constant
(
inputs_list
)
tf_inputs
=
tf
.
constant
(
inputs_list
)
tfo
=
tf_model
(
tf_inputs
,
training
=
False
)
# build the network
tfo
=
tf_model
(
tf_inputs
,
training
=
False
)
# build the network
pt_model
=
pt_model_class
.
from_pretrained
(
None
,
state_dict
=
torch
.
load
(
pytorch_checkpoint_path
,
map_location
=
'cpu'
)
pt_model
=
pt_model_class
.
from_pretrained
(
pretrained_model_name_or_path
=
None
,
config
=
config
,
config
=
config
,
state_dict
=
torch
.
load
(
pytorch_checkpoint_path
,
state_dict
=
state_dict
)
map_location
=
'cpu'
))
pt_inputs
=
torch
.
tensor
(
inputs_list
)
pt_inputs
=
torch
.
tensor
(
inputs_list
)
with
torch
.
no_grad
():
with
torch
.
no_grad
():
pto
=
pt_model
(
pt_inputs
)
pto
=
pt_model
(
pt_inputs
)
...
@@ -139,7 +140,7 @@ def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file
...
@@ -139,7 +140,7 @@ def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file
def
convert_all_pt_checkpoints_to_tf
(
args_model_type
,
tf_dump_path
,
model_shortcut_names_or_path
=
None
,
config_shortcut_names_or_path
=
None
,
def
convert_all_pt_checkpoints_to_tf
(
args_model_type
,
tf_dump_path
,
model_shortcut_names_or_path
=
None
,
config_shortcut_names_or_path
=
None
,
compare_with_pt_model
=
False
,
use_cached_models
=
False
,
only_convert_finetuned_models
=
False
):
compare_with_pt_model
=
False
,
use_cached_models
=
False
,
remove_cached_files
=
False
,
only_convert_finetuned_models
=
False
):
assert
os
.
path
.
isdir
(
args
.
tf_dump_path
),
"--tf_dump_path should be a directory"
assert
os
.
path
.
isdir
(
args
.
tf_dump_path
),
"--tf_dump_path should be a directory"
if
args_model_type
is
None
:
if
args_model_type
is
None
:
...
@@ -187,13 +188,15 @@ def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, model_shortc
...
@@ -187,13 +188,15 @@ def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, model_shortc
if
os
.
path
.
isfile
(
model_shortcut_name
):
if
os
.
path
.
isfile
(
model_shortcut_name
):
model_shortcut_name
=
'converted_model'
model_shortcut_name
=
'converted_model'
convert_pt_checkpoint_to_tf
(
model_type
=
model_type
,
convert_pt_checkpoint_to_tf
(
model_type
=
model_type
,
pytorch_checkpoint_path
=
model_file
,
pytorch_checkpoint_path
=
model_file
,
config_file
=
config_file
,
config_file
=
config_file
,
tf_dump_path
=
os
.
path
.
join
(
tf_dump_path
,
model_shortcut_name
+
'-tf_model.h5'
),
tf_dump_path
=
os
.
path
.
join
(
tf_dump_path
,
model_shortcut_name
+
'-tf_model.h5'
),
compare_with_pt_model
=
compare_with_pt_model
)
compare_with_pt_model
=
compare_with_pt_model
)
os
.
remove
(
config_file
)
if
remove_cached_files
:
os
.
remove
(
model_file
)
os
.
remove
(
config_file
)
os
.
remove
(
model_file
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
@@ -226,6 +229,9 @@ if __name__ == "__main__":
...
@@ -226,6 +229,9 @@ if __name__ == "__main__":
parser
.
add_argument
(
"--use_cached_models"
,
parser
.
add_argument
(
"--use_cached_models"
,
action
=
'store_true'
,
action
=
'store_true'
,
help
=
"Use cached models if possible instead of updating to latest checkpoint versions."
)
help
=
"Use cached models if possible instead of updating to latest checkpoint versions."
)
parser
.
add_argument
(
"--remove_cached_files"
,
action
=
'store_true'
,
help
=
"Remove pytorch models after conversion (save memory when converting in batches)."
)
parser
.
add_argument
(
"--only_convert_finetuned_models"
,
parser
.
add_argument
(
"--only_convert_finetuned_models"
,
action
=
'store_true'
,
action
=
'store_true'
,
help
=
"Only convert finetuned models."
)
help
=
"Only convert finetuned models."
)
...
@@ -245,4 +251,5 @@ if __name__ == "__main__":
...
@@ -245,4 +251,5 @@ if __name__ == "__main__":
config_shortcut_names_or_path
=
[
args
.
config_file
]
if
args
.
config_file
is
not
None
else
None
,
config_shortcut_names_or_path
=
[
args
.
config_file
]
if
args
.
config_file
is
not
None
else
None
,
compare_with_pt_model
=
args
.
compare_with_pt_model
,
compare_with_pt_model
=
args
.
compare_with_pt_model
,
use_cached_models
=
args
.
use_cached_models
,
use_cached_models
=
args
.
use_cached_models
,
remove_cached_files
=
args
.
remove_cached_files
,
only_convert_finetuned_models
=
args
.
only_convert_finetuned_models
)
only_convert_finetuned_models
=
args
.
only_convert_finetuned_models
)
transformers/modeling_utils.py
View file @
e6cff60b
...
@@ -318,7 +318,8 @@ class PreTrainedModel(nn.Module):
...
@@ -318,7 +318,8 @@ class PreTrainedModel(nn.Module):
model = BertModel.from_pretrained('./tf_model/my_tf_checkpoint.ckpt.index', from_tf=True, config=config)
model = BertModel.from_pretrained('./tf_model/my_tf_checkpoint.ckpt.index', from_tf=True, config=config)
"""
"""
if
"albert"
in
pretrained_model_name_or_path
and
"v2"
in
pretrained_model_name_or_path
:
if
pretrained_model_name_or_path
is
not
None
and
(
"albert"
in
pretrained_model_name_or_path
and
"v2"
in
pretrained_model_name_or_path
):
logger
.
warning
(
"There is currently an upstream reproducibility issue with ALBERT v2 models. Please see "
+
logger
.
warning
(
"There is currently an upstream reproducibility issue with ALBERT v2 models. Please see "
+
"https://github.com/google-research/google-research/issues/119 for more information."
)
"https://github.com/google-research/google-research/issues/119 for more information."
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment