Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
128bdd4c
Commit
128bdd4c
authored
Sep 24, 2019
by
thomwolf
Browse files
fix tests pt/tf
parent
28a30af6
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
9 additions
and
8 deletions
+9
-8
pytorch_transformers/modeling_tf_pytorch_utils.py
pytorch_transformers/modeling_tf_pytorch_utils.py
+9
-8
No files found.
pytorch_transformers/modeling_tf_pytorch_utils.py
View file @
128bdd4c
...
...
@@ -65,7 +65,7 @@ def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove='')
#####################
### PyTorch => TF 2.0
def
load_pytorch_checkpoint_in_tf2_model
(
tf_model
,
pytorch_checkpoint_path
,
tf_inputs
=
DUMMY_INPUTS
,
allow_missing_keys
=
False
):
def
load_pytorch_checkpoint_in_tf2_model
(
tf_model
,
pytorch_checkpoint_path
,
tf_inputs
=
None
,
allow_missing_keys
=
False
):
""" Load pytorch checkpoints in a TF 2.0 model
"""
try
:
...
...
@@ -84,7 +84,7 @@ def load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_i
return
load_pytorch_weights_in_tf2_model
(
tf_model
,
pt_state_dict
,
tf_inputs
=
tf_inputs
,
allow_missing_keys
=
allow_missing_keys
)
def
load_pytorch_model_in_tf2_model
(
tf_model
,
pt_model
,
tf_inputs
=
DUMMY_INPUTS
,
allow_missing_keys
=
False
):
def
load_pytorch_model_in_tf2_model
(
tf_model
,
pt_model
,
tf_inputs
=
None
,
allow_missing_keys
=
False
):
""" Load pytorch checkpoints in a TF 2.0 model
"""
pt_state_dict
=
pt_model
.
state_dict
()
...
...
@@ -92,7 +92,7 @@ def load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=DUMMY_INPUTS,
return
load_pytorch_weights_in_tf2_model
(
tf_model
,
pt_state_dict
,
tf_inputs
=
tf_inputs
,
allow_missing_keys
=
allow_missing_keys
)
def
load_pytorch_weights_in_tf2_model
(
tf_model
,
pt_state_dict
,
tf_inputs
=
DUMMY_INPUTS
,
allow_missing_keys
=
False
):
def
load_pytorch_weights_in_tf2_model
(
tf_model
,
pt_state_dict
,
tf_inputs
=
None
,
allow_missing_keys
=
False
):
""" Load pytorch state_dict in a TF 2.0 model.
"""
try
:
...
...
@@ -104,8 +104,8 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=DUMMY_I
"https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions."
)
raise
e
if
tf_inputs
is
not
None
and
not
isinstance
(
tf_inputs
,
tf
.
Tensor
)
:
tf_inputs
=
tf
.
constant
(
tf_inputs
)
if
tf_inputs
is
None
:
tf_inputs
=
tf
.
constant
(
DUMMY_INPUTS
)
if
tf_inputs
is
not
None
:
tfo
=
tf_model
(
tf_inputs
,
training
=
False
)
# Make sure model is built
...
...
@@ -176,7 +176,7 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=DUMMY_I
#####################
### TF 2.0 => PyTorch
def
load_tf2_checkpoint_in_pytorch_model
(
pt_model
,
tf_checkpoint_path
,
tf_inputs
=
DUMMY_INPUTS
,
allow_missing_keys
=
False
):
def
load_tf2_checkpoint_in_pytorch_model
(
pt_model
,
tf_checkpoint_path
,
tf_inputs
=
None
,
allow_missing_keys
=
False
):
""" Load TF 2.0 HDF5 checkpoint in a PyTorch model
We use HDF5 to easily do transfer learning
(see https://github.com/tensorflow/tensorflow/blob/ee16fcac960ae660e0e4496658a366e2f745e1f0/tensorflow/python/keras/engine/network.py#L1352-L1357).
...
...
@@ -199,9 +199,10 @@ def load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path, tf_inputs
tf_model_class
=
getattr
(
pytorch_transformers
,
tf_model_class_name
)
tf_model
=
tf_model_class
(
pt_model
.
config
)
if
tf_inputs
is
None
:
tf_inputs
=
tf
.
constant
(
DUMMY_INPUTS
)
if
tf_inputs
is
not
None
:
if
tf_inputs
is
not
None
and
not
isinstance
(
tf_inputs
,
tf
.
Tensor
):
tf_inputs
=
tf
.
constant
(
tf_inputs
)
tfo
=
tf_model
(
tf_inputs
,
training
=
False
)
# Make sure model is built
tf_model
.
load_weights
(
tf_checkpoint_path
,
by_name
=
True
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment