Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
271f2136
Commit
271f2136
authored
Sep 24, 2019
by
thomwolf
Browse files
updating to load tf model in pt - fixing headmasking test
parent
cf9c1cbb
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
89 additions
and
69 deletions
+89
-69
pytorch_transformers/modeling_tf_pytorch_utils.py
pytorch_transformers/modeling_tf_pytorch_utils.py
+16
-10
pytorch_transformers/modeling_tf_utils.py
pytorch_transformers/modeling_tf_utils.py
+2
-2
pytorch_transformers/modeling_utils.py
pytorch_transformers/modeling_utils.py
+68
-57
pytorch_transformers/tests/modeling_common_test.py
pytorch_transformers/tests/modeling_common_test.py
+3
-0
No files found.
pytorch_transformers/modeling_tf_pytorch_utils.py
View file @
271f2136
...
@@ -61,7 +61,10 @@ def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove='')
...
@@ -61,7 +61,10 @@ def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove='')
return
tf_name
,
transpose
return
tf_name
,
transpose
def
load_pytorch_checkpoint_in_tf2_model
(
tf_model
,
pytorch_checkpoint_path
,
tf_inputs
=
None
):
#####################
### PyTorch => TF 2.0
def
load_pytorch_checkpoint_in_tf2_model
(
tf_model
,
pytorch_checkpoint_path
,
tf_inputs
=
None
,
allow_missing_keys
=
False
):
""" Load pytorch checkpoints in a TF 2.0 model
""" Load pytorch checkpoints in a TF 2.0 model
"""
"""
try
:
try
:
...
@@ -77,18 +80,18 @@ def load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_i
...
@@ -77,18 +80,18 @@ def load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_i
pt_state_dict
=
torch
.
load
(
pt_path
,
map_location
=
'cpu'
)
pt_state_dict
=
torch
.
load
(
pt_path
,
map_location
=
'cpu'
)
return
load_pytorch_weights_in_tf2_model
(
tf_model
,
pt_state_dict
,
tf_inputs
=
tf_inputs
)
return
load_pytorch_weights_in_tf2_model
(
tf_model
,
pt_state_dict
,
tf_inputs
=
tf_inputs
,
allow_missing_keys
=
allow_missing_keys
)
def
load_pytorch_model_in_tf2_model
(
tf_model
,
pt_model
,
tf_inputs
=
None
):
def
load_pytorch_model_in_tf2_model
(
tf_model
,
pt_model
,
tf_inputs
=
None
,
allow_missing_keys
=
False
):
""" Load pytorch checkpoints in a TF 2.0 model
""" Load pytorch checkpoints in a TF 2.0 model
"""
"""
pt_state_dict
=
pt_model
.
state_dict
()
pt_state_dict
=
pt_model
.
state_dict
()
return
load_pytorch_weights_in_tf2_model
(
tf_model
,
pt_state_dict
,
tf_inputs
=
tf_inputs
)
return
load_pytorch_weights_in_tf2_model
(
tf_model
,
pt_state_dict
,
tf_inputs
=
tf_inputs
,
allow_missing_keys
=
allow_missing_keys
)
def
load_pytorch_weights_in_tf2_model
(
tf_model
,
pt_state_dict
,
tf_inputs
=
None
):
def
load_pytorch_weights_in_tf2_model
(
tf_model
,
pt_state_dict
,
tf_inputs
=
None
,
allow_missing_keys
=
False
):
""" Load pytorch state_dict in a TF 2.0 model.
""" Load pytorch state_dict in a TF 2.0 model.
"""
"""
try
:
try
:
...
@@ -165,7 +168,10 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None):
...
@@ -165,7 +168,10 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None):
return
tf_model
return
tf_model
def
load_tf2_checkpoint_in_pytorch_model
(
pt_model
,
tf_checkpoint_path
,
tf_inputs
=
None
):
#####################
### TF 2.0 => PyTorch
def
load_tf2_checkpoint_in_pytorch_model
(
pt_model
,
tf_checkpoint_path
,
tf_inputs
=
None
,
allow_missing_keys
=
False
):
""" Load TF 2.0 HDF5 checkpoint in a PyTorch model
""" Load TF 2.0 HDF5 checkpoint in a PyTorch model
We use HDF5 to easily do transfer learning
We use HDF5 to easily do transfer learning
(see https://github.com/tensorflow/tensorflow/blob/ee16fcac960ae660e0e4496658a366e2f745e1f0/tensorflow/python/keras/engine/network.py#L1352-L1357).
(see https://github.com/tensorflow/tensorflow/blob/ee16fcac960ae660e0e4496658a366e2f745e1f0/tensorflow/python/keras/engine/network.py#L1352-L1357).
...
@@ -191,17 +197,17 @@ def load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path, tf_inputs
...
@@ -191,17 +197,17 @@ def load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path, tf_inputs
tf_model
.
load_weights
(
tf_checkpoint_path
,
by_name
=
True
)
tf_model
.
load_weights
(
tf_checkpoint_path
,
by_name
=
True
)
return
load_tf2_model_in_pytorch_model
(
pt_model
,
tf_model
)
return
load_tf2_model_in_pytorch_model
(
pt_model
,
tf_model
,
allow_missing_keys
=
allow_missing_keys
)
def
load_tf2_model_in_pytorch_model
(
pt_model
,
tf_model
):
def
load_tf2_model_in_pytorch_model
(
pt_model
,
tf_model
,
allow_missing_keys
=
False
):
""" Load TF 2.0 model in a pytorch model
""" Load TF 2.0 model in a pytorch model
"""
"""
weights
=
tf_model
.
weights
weights
=
tf_model
.
weights
return
load_tf2_weights_in_pytorch_model
(
pt_model
,
weights
)
return
load_tf2_weights_in_pytorch_model
(
pt_model
,
weights
,
allow_missing_keys
=
allow_missing_keys
)
def
load_tf2_weights_in_pytorch_model
(
pt_model
,
tf_weights
):
def
load_tf2_weights_in_pytorch_model
(
pt_model
,
tf_weights
,
allow_missing_keys
=
False
):
""" Load TF2.0 symbolic weights in a PyTorch model
""" Load TF2.0 symbolic weights in a PyTorch model
"""
"""
try
:
try
:
...
...
pytorch_transformers/modeling_tf_utils.py
View file @
271f2136
...
@@ -129,7 +129,7 @@ class TFPreTrainedModel(tf.keras.Model):
...
@@ -129,7 +129,7 @@ class TFPreTrainedModel(tf.keras.Model):
@
classmethod
@
classmethod
def
from_pretrained
(
cls
,
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
):
def
from_pretrained
(
cls
,
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
):
r
"""Instantiate a pretrained
pytorch
model from a pre-trained model configuration.
r
"""Instantiate a pretrained
TF 2.0
model from a pre-trained model configuration.
The model is set in evaluation mode by default using ``model.eval()`` (Dropout modules are deactivated)
The model is set in evaluation mode by default using ``model.eval()`` (Dropout modules are deactivated)
To train the model, you should first set it back in training mode with ``model.train()``
To train the model, you should first set it back in training mode with ``model.train()``
...
@@ -243,7 +243,7 @@ class TFPreTrainedModel(tf.keras.Model):
...
@@ -243,7 +243,7 @@ class TFPreTrainedModel(tf.keras.Model):
if
from_pt
:
if
from_pt
:
# Load from a PyTorch checkpoint
# Load from a PyTorch checkpoint
return
cls
.
load_pt_weights
(
model
,
config
,
resolved_archive_file
)
return
cls
.
load_pt_weights
(
model
,
resolved_archive_file
)
inputs
=
tf
.
constant
([[
7
,
6
,
0
,
0
,
1
],
[
1
,
2
,
3
,
0
,
0
],
[
0
,
0
,
0
,
4
,
5
]])
inputs
=
tf
.
constant
([[
7
,
6
,
0
,
0
,
1
],
[
1
,
2
,
3
,
0
,
0
],
[
0
,
0
,
0
,
4
,
5
]])
ret
=
model
(
inputs
,
training
=
False
)
# build the network with dummy inputs
ret
=
model
(
inputs
,
training
=
False
)
# build the network with dummy inputs
...
...
pytorch_transformers/modeling_utils.py
View file @
271f2136
...
@@ -299,12 +299,12 @@ class PreTrainedModel(nn.Module):
...
@@ -299,12 +299,12 @@ class PreTrainedModel(nn.Module):
archive_file
=
os
.
path
.
join
(
pretrained_model_name_or_path
,
TF_WEIGHTS_NAME
+
".index"
)
archive_file
=
os
.
path
.
join
(
pretrained_model_name_or_path
,
TF_WEIGHTS_NAME
+
".index"
)
else
:
else
:
archive_file
=
os
.
path
.
join
(
pretrained_model_name_or_path
,
WEIGHTS_NAME
)
archive_file
=
os
.
path
.
join
(
pretrained_model_name_or_path
,
WEIGHTS_NAME
)
elif
os
.
path
.
isfile
(
pretrained_model_name_or_path
):
archive_file
=
pretrained_model_name_or_path
else
:
else
:
if
from_tf
:
assert
from_tf
,
"Error finding file {}, no file or TF 1.X checkpoint found"
.
format
(
pretrained_model_name_or_path
)
# Directly load from a TensorFlow checkpoint
archive_file
=
pretrained_model_name_or_path
+
".index"
archive_file
=
pretrained_model_name_or_path
+
".index"
else
:
archive_file
=
pretrained_model_name_or_path
# redirect to the cache, if necessary
# redirect to the cache, if necessary
try
:
try
:
resolved_archive_file
=
cached_path
(
archive_file
,
cache_dir
=
cache_dir
,
force_download
=
force_download
,
proxies
=
proxies
)
resolved_archive_file
=
cached_path
(
archive_file
,
cache_dir
=
cache_dir
,
force_download
=
force_download
,
proxies
=
proxies
)
...
@@ -335,10 +335,25 @@ class PreTrainedModel(nn.Module):
...
@@ -335,10 +335,25 @@ class PreTrainedModel(nn.Module):
if
state_dict
is
None
and
not
from_tf
:
if
state_dict
is
None
and
not
from_tf
:
state_dict
=
torch
.
load
(
resolved_archive_file
,
map_location
=
'cpu'
)
state_dict
=
torch
.
load
(
resolved_archive_file
,
map_location
=
'cpu'
)
if
from_tf
:
# Directly load from a TensorFlow checkpoint
return
cls
.
load_tf_weights
(
model
,
config
,
resolved_archive_file
[:
-
6
])
# Remove the '.index'
missing_keys
=
[]
unexpected_keys
=
[]
error_msgs
=
[]
if
from_tf
:
if
resolved_archive_file
.
endswith
(
'.index'
):
# Load from a TensorFlow 1.X checkpoint - provided by original authors
model
=
cls
.
load_tf_weights
(
model
,
config
,
resolved_archive_file
[:
-
6
])
# Remove the '.index'
else
:
# Load from our TensorFlow 2.0 checkpoints
try
:
from
pytorch_transformers
import
load_tf2_checkpoint_in_pytorch_model
model
=
load_tf2_checkpoint_in_pytorch_model
(
model
,
resolved_archive_file
,
allow_missing_keys
=
True
)
except
ImportError
as
e
:
logger
.
error
(
"Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed. Please see "
"https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions."
)
raise
e
else
:
# Convert old format to new format if needed from a PyTorch state_dict
# Convert old format to new format if needed from a PyTorch state_dict
old_keys
=
[]
old_keys
=
[]
new_keys
=
[]
new_keys
=
[]
...
@@ -354,10 +369,6 @@ class PreTrainedModel(nn.Module):
...
@@ -354,10 +369,6 @@ class PreTrainedModel(nn.Module):
for
old_key
,
new_key
in
zip
(
old_keys
,
new_keys
):
for
old_key
,
new_key
in
zip
(
old_keys
,
new_keys
):
state_dict
[
new_key
]
=
state_dict
.
pop
(
old_key
)
state_dict
[
new_key
]
=
state_dict
.
pop
(
old_key
)
# Load from a PyTorch state_dict
missing_keys
=
[]
unexpected_keys
=
[]
error_msgs
=
[]
# copy state_dict so _load_from_state_dict can modify it
# copy state_dict so _load_from_state_dict can modify it
metadata
=
getattr
(
state_dict
,
'_metadata'
,
None
)
metadata
=
getattr
(
state_dict
,
'_metadata'
,
None
)
state_dict
=
state_dict
.
copy
()
state_dict
=
state_dict
.
copy
()
...
...
pytorch_transformers/tests/modeling_common_test.py
View file @
271f2136
...
@@ -200,6 +200,9 @@ class CommonTestCases:
...
@@ -200,6 +200,9 @@ class CommonTestCases:
hidden_states
=
outputs
[
-
2
]
hidden_states
=
outputs
[
-
2
]
# Remove Nan
# Remove Nan
for
t
in
attentions
:
self
.
assertLess
(
torch
.
sum
(
torch
.
isnan
(
t
)),
t
.
numel
()
/
4
)
# Check we don't have more than 25% nans (arbitrary)
attentions
=
[
t
.
masked_fill
(
torch
.
isnan
(
t
),
0.0
)
for
t
in
attentions
]
# remove them (the test is less complete)
self
.
assertIsNotNone
(
multihead_outputs
)
self
.
assertIsNotNone
(
multihead_outputs
)
self
.
assertEqual
(
len
(
multihead_outputs
),
self
.
model_tester
.
num_hidden_layers
)
self
.
assertEqual
(
len
(
multihead_outputs
),
self
.
model_tester
.
num_hidden_layers
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment