Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
3cb51299
Commit
3cb51299
authored
Dec 16, 2019
by
thomwolf
Committed by
Lysandre Debut
Dec 16, 2019
Browse files
Fix #2109
parent
18a879f4
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
12 additions
and
3 deletions
+12
-3
transformers/modeling_tf_pytorch_utils.py
transformers/modeling_tf_pytorch_utils.py
+11
-2
transformers/modeling_tf_utils.py
transformers/modeling_tf_utils.py
+1
-1
No files found.
transformers/modeling_tf_pytorch_utils.py
View file @
3cb51299
...
@@ -143,7 +143,11 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
...
@@ -143,7 +143,11 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
name
,
transpose
=
convert_tf_weight_name_to_pt_weight_name
(
sw_name
,
start_prefix_to_remove
=
start_prefix_to_remove
)
name
,
transpose
=
convert_tf_weight_name_to_pt_weight_name
(
sw_name
,
start_prefix_to_remove
=
start_prefix_to_remove
)
# Find associated numpy array in pytorch model state dict
# Find associated numpy array in pytorch model state dict
assert
name
in
pt_state_dict
,
"{} not found in PyTorch model"
.
format
(
name
)
if
name
not
in
pt_state_dict
:
if
allow_missing_keys
:
continue
raise
AttributeError
(
"{} not found in PyTorch model"
.
format
(
name
))
array
=
pt_state_dict
[
name
].
numpy
()
array
=
pt_state_dict
[
name
].
numpy
()
if
transpose
:
if
transpose
:
...
@@ -250,6 +254,7 @@ def load_tf2_weights_in_pytorch_model(pt_model, tf_weights, allow_missing_keys=F
...
@@ -250,6 +254,7 @@ def load_tf2_weights_in_pytorch_model(pt_model, tf_weights, allow_missing_keys=F
all_tf_weights
=
set
(
list
(
tf_weights_map
.
keys
()))
all_tf_weights
=
set
(
list
(
tf_weights_map
.
keys
()))
loaded_pt_weights_data_ptr
=
{}
loaded_pt_weights_data_ptr
=
{}
missing_keys_pt
=
[]
for
pt_weight_name
,
pt_weight
in
current_pt_params_dict
.
items
():
for
pt_weight_name
,
pt_weight
in
current_pt_params_dict
.
items
():
# Handle PyTorch shared weight ()not duplicated in TF 2.0
# Handle PyTorch shared weight ()not duplicated in TF 2.0
if
pt_weight
.
data_ptr
()
in
loaded_pt_weights_data_ptr
:
if
pt_weight
.
data_ptr
()
in
loaded_pt_weights_data_ptr
:
...
@@ -258,7 +263,10 @@ def load_tf2_weights_in_pytorch_model(pt_model, tf_weights, allow_missing_keys=F
...
@@ -258,7 +263,10 @@ def load_tf2_weights_in_pytorch_model(pt_model, tf_weights, allow_missing_keys=F
# Find associated numpy array in pytorch model state dict
# Find associated numpy array in pytorch model state dict
if
pt_weight_name
not
in
tf_weights_map
:
if
pt_weight_name
not
in
tf_weights_map
:
raise
ValueError
(
"{} not found in TF 2.0 model"
.
format
(
pt_weight_name
))
if
allow_missing_keys
:
missing_keys_pt
.
append
(
pt_weight_name
)
continue
raise
AttributeError
(
"{} not found in TF 2.0 model"
.
format
(
pt_weight_name
))
array
,
transpose
=
tf_weights_map
[
pt_weight_name
]
array
,
transpose
=
tf_weights_map
[
pt_weight_name
]
...
@@ -283,6 +291,7 @@ def load_tf2_weights_in_pytorch_model(pt_model, tf_weights, allow_missing_keys=F
...
@@ -283,6 +291,7 @@ def load_tf2_weights_in_pytorch_model(pt_model, tf_weights, allow_missing_keys=F
all_tf_weights
.
discard
(
pt_weight_name
)
all_tf_weights
.
discard
(
pt_weight_name
)
missing_keys
,
unexpected_keys
=
pt_model
.
load_state_dict
(
new_pt_params_dict
,
strict
=
False
)
missing_keys
,
unexpected_keys
=
pt_model
.
load_state_dict
(
new_pt_params_dict
,
strict
=
False
)
missing_keys
+=
missing_keys_pt
if
len
(
missing_keys
)
>
0
:
if
len
(
missing_keys
)
>
0
:
logger
.
info
(
"Weights of {} not initialized from TF 2.0 model: {}"
.
format
(
logger
.
info
(
"Weights of {} not initialized from TF 2.0 model: {}"
.
format
(
...
...
transformers/modeling_tf_utils.py
View file @
3cb51299
...
@@ -297,7 +297,7 @@ class TFPreTrainedModel(tf.keras.Model):
...
@@ -297,7 +297,7 @@ class TFPreTrainedModel(tf.keras.Model):
if
from_pt
:
if
from_pt
:
# Load from a PyTorch checkpoint
# Load from a PyTorch checkpoint
return
load_pytorch_checkpoint_in_tf2_model
(
model
,
resolved_archive_file
)
return
load_pytorch_checkpoint_in_tf2_model
(
model
,
resolved_archive_file
,
allow_missing_keys
=
True
)
ret
=
model
(
model
.
dummy_inputs
,
training
=
False
)
# build the network with dummy inputs
ret
=
model
(
model
.
dummy_inputs
,
training
=
False
)
# build the network with dummy inputs
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment