Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
29bb3e4e
Commit
29bb3e4e
authored
Sep 24, 2019
by
thomwolf
Browse files
double loading ok
parent
f5397ffc
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
14 additions
and
5 deletions
+14
-5
pytorch_transformers/modeling_tf_pytorch_utils.py
pytorch_transformers/modeling_tf_pytorch_utils.py
+14
-5
No files found.
pytorch_transformers/modeling_tf_pytorch_utils.py
View file @
29bb3e4e
...
@@ -25,6 +25,7 @@ import numpy
...
@@ -25,6 +25,7 @@ import numpy
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
DUMMY_INPUTS
=
[[
7
,
6
,
0
,
0
,
1
],
[
1
,
2
,
3
,
0
,
0
],
[
0
,
0
,
0
,
4
,
5
]]
def
convert_tf_weight_name_to_pt_weight_name
(
tf_name
,
start_prefix_to_remove
=
''
):
def
convert_tf_weight_name_to_pt_weight_name
(
tf_name
,
start_prefix_to_remove
=
''
):
""" Convert a TF 2.0 model variable name in a pytorch model weight name.
""" Convert a TF 2.0 model variable name in a pytorch model weight name.
...
@@ -64,7 +65,7 @@ def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove='')
...
@@ -64,7 +65,7 @@ def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove='')
#####################
#####################
### PyTorch => TF 2.0
### PyTorch => TF 2.0
def
load_pytorch_checkpoint_in_tf2_model
(
tf_model
,
pytorch_checkpoint_path
,
tf_inputs
=
None
,
allow_missing_keys
=
False
):
def
load_pytorch_checkpoint_in_tf2_model
(
tf_model
,
pytorch_checkpoint_path
,
tf_inputs
=
DUMMY_INPUTS
,
allow_missing_keys
=
False
):
""" Load pytorch checkpoints in a TF 2.0 model
""" Load pytorch checkpoints in a TF 2.0 model
"""
"""
try
:
try
:
...
@@ -83,7 +84,7 @@ def load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_i
...
@@ -83,7 +84,7 @@ def load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_i
return
load_pytorch_weights_in_tf2_model
(
tf_model
,
pt_state_dict
,
tf_inputs
=
tf_inputs
,
allow_missing_keys
=
allow_missing_keys
)
return
load_pytorch_weights_in_tf2_model
(
tf_model
,
pt_state_dict
,
tf_inputs
=
tf_inputs
,
allow_missing_keys
=
allow_missing_keys
)
def
load_pytorch_model_in_tf2_model
(
tf_model
,
pt_model
,
tf_inputs
=
None
,
allow_missing_keys
=
False
):
def
load_pytorch_model_in_tf2_model
(
tf_model
,
pt_model
,
tf_inputs
=
DUMMY_INPUTS
,
allow_missing_keys
=
False
):
""" Load pytorch checkpoints in a TF 2.0 model
""" Load pytorch checkpoints in a TF 2.0 model
"""
"""
pt_state_dict
=
pt_model
.
state_dict
()
pt_state_dict
=
pt_model
.
state_dict
()
...
@@ -91,17 +92,21 @@ def load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=None, allow_mi
...
@@ -91,17 +92,21 @@ def load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=None, allow_mi
return
load_pytorch_weights_in_tf2_model
(
tf_model
,
pt_state_dict
,
tf_inputs
=
tf_inputs
,
allow_missing_keys
=
allow_missing_keys
)
return
load_pytorch_weights_in_tf2_model
(
tf_model
,
pt_state_dict
,
tf_inputs
=
tf_inputs
,
allow_missing_keys
=
allow_missing_keys
)
def
load_pytorch_weights_in_tf2_model
(
tf_model
,
pt_state_dict
,
tf_inputs
=
None
,
allow_missing_keys
=
False
):
def
load_pytorch_weights_in_tf2_model
(
tf_model
,
pt_state_dict
,
tf_inputs
=
DUMMY_INPUTS
,
allow_missing_keys
=
False
):
""" Load pytorch state_dict in a TF 2.0 model.
""" Load pytorch state_dict in a TF 2.0 model.
"""
"""
try
:
try
:
import
torch
import
torch
import
tensorflow
as
tf
from
tensorflow.python.keras
import
backend
as
K
from
tensorflow.python.keras
import
backend
as
K
except
ImportError
as
e
:
except
ImportError
as
e
:
logger
.
error
(
"Loading a PyTorch model in TensorFlow, requires both PyTorch and TensorFlow to be installed. Please see "
logger
.
error
(
"Loading a PyTorch model in TensorFlow, requires both PyTorch and TensorFlow to be installed. Please see "
"https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions."
)
"https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions."
)
raise
e
raise
e
if
tf_inputs
is
not
None
and
not
isinstance
(
tf_inputs
,
tf
.
Tensor
):
tf_inputs
=
tf
.
constant
(
tf_inputs
)
if
tf_inputs
is
not
None
:
if
tf_inputs
is
not
None
:
tfo
=
tf_model
(
tf_inputs
,
training
=
False
)
# Make sure model is built
tfo
=
tf_model
(
tf_inputs
,
training
=
False
)
# Make sure model is built
...
@@ -171,7 +176,7 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
...
@@ -171,7 +176,7 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
#####################
#####################
### TF 2.0 => PyTorch
### TF 2.0 => PyTorch
def
load_tf2_checkpoint_in_pytorch_model
(
pt_model
,
tf_checkpoint_path
,
tf_inputs
=
None
,
allow_missing_keys
=
False
):
def
load_tf2_checkpoint_in_pytorch_model
(
pt_model
,
tf_checkpoint_path
,
tf_inputs
=
DUMMY_INPUTS
,
allow_missing_keys
=
False
):
""" Load TF 2.0 HDF5 checkpoint in a PyTorch model
""" Load TF 2.0 HDF5 checkpoint in a PyTorch model
We use HDF5 to easily do transfer learning
We use HDF5 to easily do transfer learning
(see https://github.com/tensorflow/tensorflow/blob/ee16fcac960ae660e0e4496658a366e2f745e1f0/tensorflow/python/keras/engine/network.py#L1352-L1357).
(see https://github.com/tensorflow/tensorflow/blob/ee16fcac960ae660e0e4496658a366e2f745e1f0/tensorflow/python/keras/engine/network.py#L1352-L1357).
...
@@ -184,15 +189,19 @@ def load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path, tf_inputs
...
@@ -184,15 +189,19 @@ def load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path, tf_inputs
"https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions."
)
"https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions."
)
raise
e
raise
e
import
pytorch_transformers
tf_path
=
os
.
path
.
abspath
(
tf_checkpoint_path
)
tf_path
=
os
.
path
.
abspath
(
tf_checkpoint_path
)
logger
.
info
(
"Loading TensorFlow weights from {}"
.
format
(
tf_checkpoint_path
))
logger
.
info
(
"Loading TensorFlow weights from {}"
.
format
(
tf_checkpoint_path
))
# Instantiate and load the associated TF 2.0 model
# Instantiate and load the associated TF 2.0 model
tf_model_class_name
=
"TF"
+
model_class
.
__name__
# Add "TF" at the beggining
tf_model_class_name
=
"TF"
+
pt_
model
.
_
_class
__
.
__name__
# Add "TF" at the beggining
tf_model_class
=
getattr
(
pytorch_transformers
,
tf_model_class_name
)
tf_model_class
=
getattr
(
pytorch_transformers
,
tf_model_class_name
)
tf_model
=
tf_model_class
(
pt_model
.
config
)
tf_model
=
tf_model_class
(
pt_model
.
config
)
if
tf_inputs
is
not
None
:
if
tf_inputs
is
not
None
:
if
tf_inputs
is
not
None
and
not
isinstance
(
tf_inputs
,
tf
.
Tensor
):
tf_inputs
=
tf
.
constant
(
tf_inputs
)
tfo
=
tf_model
(
tf_inputs
,
training
=
False
)
# Make sure model is built
tfo
=
tf_model
(
tf_inputs
,
training
=
False
)
# Make sure model is built
tf_model
.
load_weights
(
tf_checkpoint_path
,
by_name
=
True
)
tf_model
.
load_weights
(
tf_checkpoint_path
,
by_name
=
True
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment