Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
271f2136
Commit
271f2136
authored
Sep 24, 2019
by
thomwolf
Browse files
updating to load tf model in pt - fixing headmasking test
parent
cf9c1cbb
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
89 additions
and
69 deletions
+89
-69
pytorch_transformers/modeling_tf_pytorch_utils.py
pytorch_transformers/modeling_tf_pytorch_utils.py
+16
-10
pytorch_transformers/modeling_tf_utils.py
pytorch_transformers/modeling_tf_utils.py
+2
-2
pytorch_transformers/modeling_utils.py
pytorch_transformers/modeling_utils.py
+68
-57
pytorch_transformers/tests/modeling_common_test.py
pytorch_transformers/tests/modeling_common_test.py
+3
-0
No files found.
pytorch_transformers/modeling_tf_pytorch_utils.py
View file @
271f2136
...
@@ -61,7 +61,10 @@ def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove='')
...
@@ -61,7 +61,10 @@ def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove='')
return
tf_name
,
transpose
return
tf_name
,
transpose
def
load_pytorch_checkpoint_in_tf2_model
(
tf_model
,
pytorch_checkpoint_path
,
tf_inputs
=
None
):
#####################
### PyTorch => TF 2.0
def
load_pytorch_checkpoint_in_tf2_model
(
tf_model
,
pytorch_checkpoint_path
,
tf_inputs
=
None
,
allow_missing_keys
=
False
):
""" Load pytorch checkpoints in a TF 2.0 model
""" Load pytorch checkpoints in a TF 2.0 model
"""
"""
try
:
try
:
...
@@ -77,18 +80,18 @@ def load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_i
...
@@ -77,18 +80,18 @@ def load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_i
pt_state_dict
=
torch
.
load
(
pt_path
,
map_location
=
'cpu'
)
pt_state_dict
=
torch
.
load
(
pt_path
,
map_location
=
'cpu'
)
return
load_pytorch_weights_in_tf2_model
(
tf_model
,
pt_state_dict
,
tf_inputs
=
tf_inputs
)
return
load_pytorch_weights_in_tf2_model
(
tf_model
,
pt_state_dict
,
tf_inputs
=
tf_inputs
,
allow_missing_keys
=
allow_missing_keys
)
def
load_pytorch_model_in_tf2_model
(
tf_model
,
pt_model
,
tf_inputs
=
None
):
def
load_pytorch_model_in_tf2_model
(
tf_model
,
pt_model
,
tf_inputs
=
None
,
allow_missing_keys
=
False
):
""" Load pytorch checkpoints in a TF 2.0 model
""" Load pytorch checkpoints in a TF 2.0 model
"""
"""
pt_state_dict
=
pt_model
.
state_dict
()
pt_state_dict
=
pt_model
.
state_dict
()
return
load_pytorch_weights_in_tf2_model
(
tf_model
,
pt_state_dict
,
tf_inputs
=
tf_inputs
)
return
load_pytorch_weights_in_tf2_model
(
tf_model
,
pt_state_dict
,
tf_inputs
=
tf_inputs
,
allow_missing_keys
=
allow_missing_keys
)
def
load_pytorch_weights_in_tf2_model
(
tf_model
,
pt_state_dict
,
tf_inputs
=
None
):
def
load_pytorch_weights_in_tf2_model
(
tf_model
,
pt_state_dict
,
tf_inputs
=
None
,
allow_missing_keys
=
False
):
""" Load pytorch state_dict in a TF 2.0 model.
""" Load pytorch state_dict in a TF 2.0 model.
"""
"""
try
:
try
:
...
@@ -165,7 +168,10 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None):
...
@@ -165,7 +168,10 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None):
return
tf_model
return
tf_model
def
load_tf2_checkpoint_in_pytorch_model
(
pt_model
,
tf_checkpoint_path
,
tf_inputs
=
None
):
#####################
### TF 2.0 => PyTorch
def
load_tf2_checkpoint_in_pytorch_model
(
pt_model
,
tf_checkpoint_path
,
tf_inputs
=
None
,
allow_missing_keys
=
False
):
""" Load TF 2.0 HDF5 checkpoint in a PyTorch model
""" Load TF 2.0 HDF5 checkpoint in a PyTorch model
We use HDF5 to easily do transfer learning
We use HDF5 to easily do transfer learning
(see https://github.com/tensorflow/tensorflow/blob/ee16fcac960ae660e0e4496658a366e2f745e1f0/tensorflow/python/keras/engine/network.py#L1352-L1357).
(see https://github.com/tensorflow/tensorflow/blob/ee16fcac960ae660e0e4496658a366e2f745e1f0/tensorflow/python/keras/engine/network.py#L1352-L1357).
...
@@ -191,17 +197,17 @@ def load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path, tf_inputs
...
@@ -191,17 +197,17 @@ def load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path, tf_inputs
tf_model
.
load_weights
(
tf_checkpoint_path
,
by_name
=
True
)
tf_model
.
load_weights
(
tf_checkpoint_path
,
by_name
=
True
)
return
load_tf2_model_in_pytorch_model
(
pt_model
,
tf_model
)
return
load_tf2_model_in_pytorch_model
(
pt_model
,
tf_model
,
allow_missing_keys
=
allow_missing_keys
)
def
load_tf2_model_in_pytorch_model
(
pt_model
,
tf_model
):
def
load_tf2_model_in_pytorch_model
(
pt_model
,
tf_model
,
allow_missing_keys
=
False
):
""" Load TF 2.0 model in a pytorch model
""" Load TF 2.0 model in a pytorch model
"""
"""
weights
=
tf_model
.
weights
weights
=
tf_model
.
weights
return
load_tf2_weights_in_pytorch_model
(
pt_model
,
weights
)
return
load_tf2_weights_in_pytorch_model
(
pt_model
,
weights
,
allow_missing_keys
=
allow_missing_keys
)
def
load_tf2_weights_in_pytorch_model
(
pt_model
,
tf_weights
):
def
load_tf2_weights_in_pytorch_model
(
pt_model
,
tf_weights
,
allow_missing_keys
=
False
):
""" Load TF2.0 symbolic weights in a PyTorch model
""" Load TF2.0 symbolic weights in a PyTorch model
"""
"""
try
:
try
:
...
...
pytorch_transformers/modeling_tf_utils.py
View file @
271f2136
...
@@ -129,7 +129,7 @@ class TFPreTrainedModel(tf.keras.Model):
...
@@ -129,7 +129,7 @@ class TFPreTrainedModel(tf.keras.Model):
@
classmethod
@
classmethod
def
from_pretrained
(
cls
,
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
):
def
from_pretrained
(
cls
,
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
):
r
"""Instantiate a pretrained
pytorch
model from a pre-trained model configuration.
r
"""Instantiate a pretrained
TF 2.0
model from a pre-trained model configuration.
The model is set in evaluation mode by default using ``model.eval()`` (Dropout modules are deactivated)
The model is set in evaluation mode by default using ``model.eval()`` (Dropout modules are deactivated)
To train the model, you should first set it back in training mode with ``model.train()``
To train the model, you should first set it back in training mode with ``model.train()``
...
@@ -243,7 +243,7 @@ class TFPreTrainedModel(tf.keras.Model):
...
@@ -243,7 +243,7 @@ class TFPreTrainedModel(tf.keras.Model):
if
from_pt
:
if
from_pt
:
# Load from a PyTorch checkpoint
# Load from a PyTorch checkpoint
return
cls
.
load_pt_weights
(
model
,
config
,
resolved_archive_file
)
return
cls
.
load_pt_weights
(
model
,
resolved_archive_file
)
inputs
=
tf
.
constant
([[
7
,
6
,
0
,
0
,
1
],
[
1
,
2
,
3
,
0
,
0
],
[
0
,
0
,
0
,
4
,
5
]])
inputs
=
tf
.
constant
([[
7
,
6
,
0
,
0
,
1
],
[
1
,
2
,
3
,
0
,
0
],
[
0
,
0
,
0
,
4
,
5
]])
ret
=
model
(
inputs
,
training
=
False
)
# build the network with dummy inputs
ret
=
model
(
inputs
,
training
=
False
)
# build the network with dummy inputs
...
...
pytorch_transformers/modeling_utils.py
View file @
271f2136
...
@@ -299,12 +299,12 @@ class PreTrainedModel(nn.Module):
...
@@ -299,12 +299,12 @@ class PreTrainedModel(nn.Module):
archive_file
=
os
.
path
.
join
(
pretrained_model_name_or_path
,
TF_WEIGHTS_NAME
+
".index"
)
archive_file
=
os
.
path
.
join
(
pretrained_model_name_or_path
,
TF_WEIGHTS_NAME
+
".index"
)
else
:
else
:
archive_file
=
os
.
path
.
join
(
pretrained_model_name_or_path
,
WEIGHTS_NAME
)
archive_file
=
os
.
path
.
join
(
pretrained_model_name_or_path
,
WEIGHTS_NAME
)
elif
os
.
path
.
isfile
(
pretrained_model_name_or_path
):
archive_file
=
pretrained_model_name_or_path
else
:
else
:
if
from_tf
:
assert
from_tf
,
"Error finding file {}, no file or TF 1.X checkpoint found"
.
format
(
pretrained_model_name_or_path
)
# Directly load from a TensorFlow checkpoint
archive_file
=
pretrained_model_name_or_path
+
".index"
archive_file
=
pretrained_model_name_or_path
+
".index"
else
:
archive_file
=
pretrained_model_name_or_path
# redirect to the cache, if necessary
# redirect to the cache, if necessary
try
:
try
:
resolved_archive_file
=
cached_path
(
archive_file
,
cache_dir
=
cache_dir
,
force_download
=
force_download
,
proxies
=
proxies
)
resolved_archive_file
=
cached_path
(
archive_file
,
cache_dir
=
cache_dir
,
force_download
=
force_download
,
proxies
=
proxies
)
...
@@ -335,10 +335,25 @@ class PreTrainedModel(nn.Module):
...
@@ -335,10 +335,25 @@ class PreTrainedModel(nn.Module):
if
state_dict
is
None
and
not
from_tf
:
if
state_dict
is
None
and
not
from_tf
:
state_dict
=
torch
.
load
(
resolved_archive_file
,
map_location
=
'cpu'
)
state_dict
=
torch
.
load
(
resolved_archive_file
,
map_location
=
'cpu'
)
if
from_tf
:
# Directly load from a TensorFlow checkpoint
return
cls
.
load_tf_weights
(
model
,
config
,
resolved_archive_file
[:
-
6
])
# Remove the '.index'
missing_keys
=
[]
unexpected_keys
=
[]
error_msgs
=
[]
if
from_tf
:
if
resolved_archive_file
.
endswith
(
'.index'
):
# Load from a TensorFlow 1.X checkpoint - provided by original authors
model
=
cls
.
load_tf_weights
(
model
,
config
,
resolved_archive_file
[:
-
6
])
# Remove the '.index'
else
:
# Load from our TensorFlow 2.0 checkpoints
try
:
from
pytorch_transformers
import
load_tf2_checkpoint_in_pytorch_model
model
=
load_tf2_checkpoint_in_pytorch_model
(
model
,
resolved_archive_file
,
allow_missing_keys
=
True
)
except
ImportError
as
e
:
logger
.
error
(
"Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed. Please see "
"https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions."
)
raise
e
else
:
# Convert old format to new format if needed from a PyTorch state_dict
# Convert old format to new format if needed from a PyTorch state_dict
old_keys
=
[]
old_keys
=
[]
new_keys
=
[]
new_keys
=
[]
...
@@ -354,10 +369,6 @@ class PreTrainedModel(nn.Module):
...
@@ -354,10 +369,6 @@ class PreTrainedModel(nn.Module):
for
old_key
,
new_key
in
zip
(
old_keys
,
new_keys
):
for
old_key
,
new_key
in
zip
(
old_keys
,
new_keys
):
state_dict
[
new_key
]
=
state_dict
.
pop
(
old_key
)
state_dict
[
new_key
]
=
state_dict
.
pop
(
old_key
)
# Load from a PyTorch state_dict
missing_keys
=
[]
unexpected_keys
=
[]
error_msgs
=
[]
# copy state_dict so _load_from_state_dict can modify it
# copy state_dict so _load_from_state_dict can modify it
metadata
=
getattr
(
state_dict
,
'_metadata'
,
None
)
metadata
=
getattr
(
state_dict
,
'_metadata'
,
None
)
state_dict
=
state_dict
.
copy
()
state_dict
=
state_dict
.
copy
()
...
...
pytorch_transformers/tests/modeling_common_test.py
View file @
271f2136
...
@@ -200,6 +200,9 @@ class CommonTestCases:
...
@@ -200,6 +200,9 @@ class CommonTestCases:
hidden_states
=
outputs
[
-
2
]
hidden_states
=
outputs
[
-
2
]
# Remove Nan
# Remove Nan
for
t
in
attentions
:
self
.
assertLess
(
torch
.
sum
(
torch
.
isnan
(
t
)),
t
.
numel
()
/
4
)
# Check we don't have more than 25% nans (arbitrary)
attentions
=
[
t
.
masked_fill
(
torch
.
isnan
(
t
),
0.0
)
for
t
in
attentions
]
# remove them (the test is less complete)
self
.
assertIsNotNone
(
multihead_outputs
)
self
.
assertIsNotNone
(
multihead_outputs
)
self
.
assertEqual
(
len
(
multihead_outputs
),
self
.
model_tester
.
num_hidden_layers
)
self
.
assertEqual
(
len
(
multihead_outputs
),
self
.
model_tester
.
num_hidden_layers
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment