Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
3cb51299
Commit
3cb51299
authored
Dec 16, 2019
by
thomwolf
Committed by
Lysandre Debut
Dec 16, 2019
Browse files
Fix #2109
parent
18a879f4
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
12 additions
and
3 deletions
+12
-3
transformers/modeling_tf_pytorch_utils.py
transformers/modeling_tf_pytorch_utils.py
+11
-2
transformers/modeling_tf_utils.py
transformers/modeling_tf_utils.py
+1
-1
No files found.
transformers/modeling_tf_pytorch_utils.py
View file @
3cb51299
...
...
@@ -143,7 +143,11 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
name
,
transpose
=
convert_tf_weight_name_to_pt_weight_name
(
sw_name
,
start_prefix_to_remove
=
start_prefix_to_remove
)
# Find associated numpy array in pytorch model state dict
assert
name
in
pt_state_dict
,
"{} not found in PyTorch model"
.
format
(
name
)
if
name
not
in
pt_state_dict
:
if
allow_missing_keys
:
continue
raise
AttributeError
(
"{} not found in PyTorch model"
.
format
(
name
))
array
=
pt_state_dict
[
name
].
numpy
()
if
transpose
:
...
...
@@ -250,6 +254,7 @@ def load_tf2_weights_in_pytorch_model(pt_model, tf_weights, allow_missing_keys=F
all_tf_weights
=
set
(
list
(
tf_weights_map
.
keys
()))
loaded_pt_weights_data_ptr
=
{}
missing_keys_pt
=
[]
for
pt_weight_name
,
pt_weight
in
current_pt_params_dict
.
items
():
# Handle PyTorch shared weights (not duplicated in TF 2.0)
if
pt_weight
.
data_ptr
()
in
loaded_pt_weights_data_ptr
:
...
...
@@ -258,7 +263,10 @@ def load_tf2_weights_in_pytorch_model(pt_model, tf_weights, allow_missing_keys=F
# Find associated numpy array in the TF 2.0 weights map
if
pt_weight_name
not
in
tf_weights_map
:
raise
ValueError
(
"{} not found in TF 2.0 model"
.
format
(
pt_weight_name
))
if
allow_missing_keys
:
missing_keys_pt
.
append
(
pt_weight_name
)
continue
raise
AttributeError
(
"{} not found in TF 2.0 model"
.
format
(
pt_weight_name
))
array
,
transpose
=
tf_weights_map
[
pt_weight_name
]
...
...
@@ -283,6 +291,7 @@ def load_tf2_weights_in_pytorch_model(pt_model, tf_weights, allow_missing_keys=F
all_tf_weights
.
discard
(
pt_weight_name
)
missing_keys
,
unexpected_keys
=
pt_model
.
load_state_dict
(
new_pt_params_dict
,
strict
=
False
)
missing_keys
+=
missing_keys_pt
if
len
(
missing_keys
)
>
0
:
logger
.
info
(
"Weights of {} not initialized from TF 2.0 model: {}"
.
format
(
...
...
transformers/modeling_tf_utils.py
View file @
3cb51299
...
...
@@ -297,7 +297,7 @@ class TFPreTrainedModel(tf.keras.Model):
if
from_pt
:
# Load from a PyTorch checkpoint
return
load_pytorch_checkpoint_in_tf2_model
(
model
,
resolved_archive_file
)
return
load_pytorch_checkpoint_in_tf2_model
(
model
,
resolved_archive_file
,
allow_missing_keys
=
True
)
ret
=
model
(
model
.
dummy_inputs
,
training
=
False
)
# build the network with dummy inputs
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment