Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
29bb3e4e
Commit
29bb3e4e
authored
Sep 24, 2019
by
thomwolf
Browse files
double loading ok
parent
f5397ffc
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
14 additions
and
5 deletions
+14
-5
pytorch_transformers/modeling_tf_pytorch_utils.py
pytorch_transformers/modeling_tf_pytorch_utils.py
+14
-5
No files found.
pytorch_transformers/modeling_tf_pytorch_utils.py
View file @
29bb3e4e
...
...
@@ -25,6 +25,7 @@ import numpy
logger
=
logging
.
getLogger
(
__name__
)
DUMMY_INPUTS
=
[[
7
,
6
,
0
,
0
,
1
],
[
1
,
2
,
3
,
0
,
0
],
[
0
,
0
,
0
,
4
,
5
]]
def
convert_tf_weight_name_to_pt_weight_name
(
tf_name
,
start_prefix_to_remove
=
''
):
""" Convert a TF 2.0 model variable name in a pytorch model weight name.
...
...
@@ -64,7 +65,7 @@ def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove='')
#####################
### PyTorch => TF 2.0
def
load_pytorch_checkpoint_in_tf2_model
(
tf_model
,
pytorch_checkpoint_path
,
tf_inputs
=
None
,
allow_missing_keys
=
False
):
def
load_pytorch_checkpoint_in_tf2_model
(
tf_model
,
pytorch_checkpoint_path
,
tf_inputs
=
DUMMY_INPUTS
,
allow_missing_keys
=
False
):
""" Load pytorch checkpoints in a TF 2.0 model
"""
try
:
...
...
@@ -83,7 +84,7 @@ def load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_i
return
load_pytorch_weights_in_tf2_model
(
tf_model
,
pt_state_dict
,
tf_inputs
=
tf_inputs
,
allow_missing_keys
=
allow_missing_keys
)
def
load_pytorch_model_in_tf2_model
(
tf_model
,
pt_model
,
tf_inputs
=
None
,
allow_missing_keys
=
False
):
def
load_pytorch_model_in_tf2_model
(
tf_model
,
pt_model
,
tf_inputs
=
DUMMY_INPUTS
,
allow_missing_keys
=
False
):
""" Load pytorch checkpoints in a TF 2.0 model
"""
pt_state_dict
=
pt_model
.
state_dict
()
...
...
@@ -91,17 +92,21 @@ def load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=None, allow_mi
return
load_pytorch_weights_in_tf2_model
(
tf_model
,
pt_state_dict
,
tf_inputs
=
tf_inputs
,
allow_missing_keys
=
allow_missing_keys
)
def
load_pytorch_weights_in_tf2_model
(
tf_model
,
pt_state_dict
,
tf_inputs
=
None
,
allow_missing_keys
=
False
):
def
load_pytorch_weights_in_tf2_model
(
tf_model
,
pt_state_dict
,
tf_inputs
=
DUMMY_INPUTS
,
allow_missing_keys
=
False
):
""" Load pytorch state_dict in a TF 2.0 model.
"""
try
:
import
torch
import
tensorflow
as
tf
from
tensorflow.python.keras
import
backend
as
K
except
ImportError
as
e
:
logger
.
error
(
"Loading a PyTorch model in TensorFlow, requires both PyTorch and TensorFlow to be installed. Please see "
"https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions."
)
raise
e
if
tf_inputs
is
not
None
and
not
isinstance
(
tf_inputs
,
tf
.
Tensor
):
tf_inputs
=
tf
.
constant
(
tf_inputs
)
if
tf_inputs
is
not
None
:
tfo
=
tf_model
(
tf_inputs
,
training
=
False
)
# Make sure model is built
...
...
@@ -171,7 +176,7 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
#####################
### TF 2.0 => PyTorch
def
load_tf2_checkpoint_in_pytorch_model
(
pt_model
,
tf_checkpoint_path
,
tf_inputs
=
None
,
allow_missing_keys
=
False
):
def
load_tf2_checkpoint_in_pytorch_model
(
pt_model
,
tf_checkpoint_path
,
tf_inputs
=
DUMMY_INPUTS
,
allow_missing_keys
=
False
):
""" Load TF 2.0 HDF5 checkpoint in a PyTorch model
We use HDF5 to easily do transfer learning
(see https://github.com/tensorflow/tensorflow/blob/ee16fcac960ae660e0e4496658a366e2f745e1f0/tensorflow/python/keras/engine/network.py#L1352-L1357).
...
...
@@ -184,15 +189,19 @@ def load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path, tf_inputs
"https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions."
)
raise
e
import
pytorch_transformers
tf_path
=
os
.
path
.
abspath
(
tf_checkpoint_path
)
logger
.
info
(
"Loading TensorFlow weights from {}"
.
format
(
tf_checkpoint_path
))
# Instantiate and load the associated TF 2.0 model
tf_model_class_name
=
"TF"
+
model_class
.
__name__
# Add "TF" at the beggining
tf_model_class_name
=
"TF"
+
pt_
model
.
_
_class
__
.
__name__
# Add "TF" at the beggining
tf_model_class
=
getattr
(
pytorch_transformers
,
tf_model_class_name
)
tf_model
=
tf_model_class
(
pt_model
.
config
)
if
tf_inputs
is
not
None
:
if
tf_inputs
is
not
None
and
not
isinstance
(
tf_inputs
,
tf
.
Tensor
):
tf_inputs
=
tf
.
constant
(
tf_inputs
)
tfo
=
tf_model
(
tf_inputs
,
training
=
False
)
# Make sure model is built
tf_model
.
load_weights
(
tf_checkpoint_path
,
by_name
=
True
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment