Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
271f2136
Commit
271f2136
authored
Sep 24, 2019
by
thomwolf
Browse files
updating to load tf model in pt - fixing headmasking test
parent
cf9c1cbb
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
89 additions
and
69 deletions
+89
-69
pytorch_transformers/modeling_tf_pytorch_utils.py
pytorch_transformers/modeling_tf_pytorch_utils.py
+16
-10
pytorch_transformers/modeling_tf_utils.py
pytorch_transformers/modeling_tf_utils.py
+2
-2
pytorch_transformers/modeling_utils.py
pytorch_transformers/modeling_utils.py
+68
-57
pytorch_transformers/tests/modeling_common_test.py
pytorch_transformers/tests/modeling_common_test.py
+3
-0
No files found.
pytorch_transformers/modeling_tf_pytorch_utils.py
View file @
271f2136
...
@@ -61,7 +61,10 @@ def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove='')
...
@@ -61,7 +61,10 @@ def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove='')
return
tf_name
,
transpose
return
tf_name
,
transpose
def
load_pytorch_checkpoint_in_tf2_model
(
tf_model
,
pytorch_checkpoint_path
,
tf_inputs
=
None
):
#####################
### PyTorch => TF 2.0
def
load_pytorch_checkpoint_in_tf2_model
(
tf_model
,
pytorch_checkpoint_path
,
tf_inputs
=
None
,
allow_missing_keys
=
False
):
""" Load pytorch checkpoints in a TF 2.0 model
""" Load pytorch checkpoints in a TF 2.0 model
"""
"""
try
:
try
:
...
@@ -77,18 +80,18 @@ def load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_i
...
@@ -77,18 +80,18 @@ def load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_i
pt_state_dict
=
torch
.
load
(
pt_path
,
map_location
=
'cpu'
)
pt_state_dict
=
torch
.
load
(
pt_path
,
map_location
=
'cpu'
)
return
load_pytorch_weights_in_tf2_model
(
tf_model
,
pt_state_dict
,
tf_inputs
=
tf_inputs
)
return
load_pytorch_weights_in_tf2_model
(
tf_model
,
pt_state_dict
,
tf_inputs
=
tf_inputs
,
allow_missing_keys
=
allow_missing_keys
)
def
load_pytorch_model_in_tf2_model
(
tf_model
,
pt_model
,
tf_inputs
=
None
):
def
load_pytorch_model_in_tf2_model
(
tf_model
,
pt_model
,
tf_inputs
=
None
,
allow_missing_keys
=
False
):
""" Load pytorch checkpoints in a TF 2.0 model
""" Load pytorch checkpoints in a TF 2.0 model
"""
"""
pt_state_dict
=
pt_model
.
state_dict
()
pt_state_dict
=
pt_model
.
state_dict
()
return
load_pytorch_weights_in_tf2_model
(
tf_model
,
pt_state_dict
,
tf_inputs
=
tf_inputs
)
return
load_pytorch_weights_in_tf2_model
(
tf_model
,
pt_state_dict
,
tf_inputs
=
tf_inputs
,
allow_missing_keys
=
allow_missing_keys
)
def
load_pytorch_weights_in_tf2_model
(
tf_model
,
pt_state_dict
,
tf_inputs
=
None
):
def
load_pytorch_weights_in_tf2_model
(
tf_model
,
pt_state_dict
,
tf_inputs
=
None
,
allow_missing_keys
=
False
):
""" Load pytorch state_dict in a TF 2.0 model.
""" Load pytorch state_dict in a TF 2.0 model.
"""
"""
try
:
try
:
...
@@ -165,7 +168,10 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None):
...
@@ -165,7 +168,10 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None):
return
tf_model
return
tf_model
def
load_tf2_checkpoint_in_pytorch_model
(
pt_model
,
tf_checkpoint_path
,
tf_inputs
=
None
):
#####################
### TF 2.0 => PyTorch
def
load_tf2_checkpoint_in_pytorch_model
(
pt_model
,
tf_checkpoint_path
,
tf_inputs
=
None
,
allow_missing_keys
=
False
):
""" Load TF 2.0 HDF5 checkpoint in a PyTorch model
""" Load TF 2.0 HDF5 checkpoint in a PyTorch model
We use HDF5 to easily do transfer learning
We use HDF5 to easily do transfer learning
(see https://github.com/tensorflow/tensorflow/blob/ee16fcac960ae660e0e4496658a366e2f745e1f0/tensorflow/python/keras/engine/network.py#L1352-L1357).
(see https://github.com/tensorflow/tensorflow/blob/ee16fcac960ae660e0e4496658a366e2f745e1f0/tensorflow/python/keras/engine/network.py#L1352-L1357).
...
@@ -191,17 +197,17 @@ def load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path, tf_inputs
...
@@ -191,17 +197,17 @@ def load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path, tf_inputs
tf_model
.
load_weights
(
tf_checkpoint_path
,
by_name
=
True
)
tf_model
.
load_weights
(
tf_checkpoint_path
,
by_name
=
True
)
return
load_tf2_model_in_pytorch_model
(
pt_model
,
tf_model
)
return
load_tf2_model_in_pytorch_model
(
pt_model
,
tf_model
,
allow_missing_keys
=
allow_missing_keys
)
def
load_tf2_model_in_pytorch_model
(
pt_model
,
tf_model
):
def
load_tf2_model_in_pytorch_model
(
pt_model
,
tf_model
,
allow_missing_keys
=
False
):
""" Load TF 2.0 model in a pytorch model
""" Load TF 2.0 model in a pytorch model
"""
"""
weights
=
tf_model
.
weights
weights
=
tf_model
.
weights
return
load_tf2_weights_in_pytorch_model
(
pt_model
,
weights
)
return
load_tf2_weights_in_pytorch_model
(
pt_model
,
weights
,
allow_missing_keys
=
allow_missing_keys
)
def
load_tf2_weights_in_pytorch_model
(
pt_model
,
tf_weights
):
def
load_tf2_weights_in_pytorch_model
(
pt_model
,
tf_weights
,
allow_missing_keys
=
False
):
""" Load TF2.0 symbolic weights in a PyTorch model
""" Load TF2.0 symbolic weights in a PyTorch model
"""
"""
try
:
try
:
...
...
pytorch_transformers/modeling_tf_utils.py
View file @
271f2136
...
@@ -129,7 +129,7 @@ class TFPreTrainedModel(tf.keras.Model):
...
@@ -129,7 +129,7 @@ class TFPreTrainedModel(tf.keras.Model):
@
classmethod
@
classmethod
def
from_pretrained
(
cls
,
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
):
def
from_pretrained
(
cls
,
pretrained_model_name_or_path
,
*
model_args
,
**
kwargs
):
r
"""Instantiate a pretrained
pytorch
model from a pre-trained model configuration.
r
"""Instantiate a pretrained
TF 2.0
model from a pre-trained model configuration.
The model is set in evaluation mode by default using ``model.eval()`` (Dropout modules are deactivated)
The model is set in evaluation mode by default using ``model.eval()`` (Dropout modules are deactivated)
To train the model, you should first set it back in training mode with ``model.train()``
To train the model, you should first set it back in training mode with ``model.train()``
...
@@ -243,7 +243,7 @@ class TFPreTrainedModel(tf.keras.Model):
...
@@ -243,7 +243,7 @@ class TFPreTrainedModel(tf.keras.Model):
if
from_pt
:
if
from_pt
:
# Load from a PyTorch checkpoint
# Load from a PyTorch checkpoint
return
cls
.
load_pt_weights
(
model
,
config
,
resolved_archive_file
)
return
cls
.
load_pt_weights
(
model
,
resolved_archive_file
)
inputs
=
tf
.
constant
([[
7
,
6
,
0
,
0
,
1
],
[
1
,
2
,
3
,
0
,
0
],
[
0
,
0
,
0
,
4
,
5
]])
inputs
=
tf
.
constant
([[
7
,
6
,
0
,
0
,
1
],
[
1
,
2
,
3
,
0
,
0
],
[
0
,
0
,
0
,
4
,
5
]])
ret
=
model
(
inputs
,
training
=
False
)
# build the network with dummy inputs
ret
=
model
(
inputs
,
training
=
False
)
# build the network with dummy inputs
...
...
pytorch_transformers/modeling_utils.py
View file @
271f2136
...
@@ -299,12 +299,12 @@ class PreTrainedModel(nn.Module):
...
@@ -299,12 +299,12 @@ class PreTrainedModel(nn.Module):
archive_file
=
os
.
path
.
join
(
pretrained_model_name_or_path
,
TF_WEIGHTS_NAME
+
".index"
)
archive_file
=
os
.
path
.
join
(
pretrained_model_name_or_path
,
TF_WEIGHTS_NAME
+
".index"
)
else
:
else
:
archive_file
=
os
.
path
.
join
(
pretrained_model_name_or_path
,
WEIGHTS_NAME
)
archive_file
=
os
.
path
.
join
(
pretrained_model_name_or_path
,
WEIGHTS_NAME
)
elif
os
.
path
.
isfile
(
pretrained_model_name_or_path
):
archive_file
=
pretrained_model_name_or_path
else
:
else
:
if
from_tf
:
assert
from_tf
,
"Error finding file {}, no file or TF 1.X checkpoint found"
.
format
(
pretrained_model_name_or_path
)
# Directly load from a TensorFlow checkpoint
archive_file
=
pretrained_model_name_or_path
+
".index"
archive_file
=
pretrained_model_name_or_path
+
".index"
else
:
archive_file
=
pretrained_model_name_or_path
# redirect to the cache, if necessary
# redirect to the cache, if necessary
try
:
try
:
resolved_archive_file
=
cached_path
(
archive_file
,
cache_dir
=
cache_dir
,
force_download
=
force_download
,
proxies
=
proxies
)
resolved_archive_file
=
cached_path
(
archive_file
,
cache_dir
=
cache_dir
,
force_download
=
force_download
,
proxies
=
proxies
)
...
@@ -335,61 +335,72 @@ class PreTrainedModel(nn.Module):
...
@@ -335,61 +335,72 @@ class PreTrainedModel(nn.Module):
if
state_dict
is
None
and
not
from_tf
:
if
state_dict
is
None
and
not
from_tf
:
state_dict
=
torch
.
load
(
resolved_archive_file
,
map_location
=
'cpu'
)
state_dict
=
torch
.
load
(
resolved_archive_file
,
map_location
=
'cpu'
)
if
from_tf
:
# Directly load from a TensorFlow checkpoint
return
cls
.
load_tf_weights
(
model
,
config
,
resolved_archive_file
[:
-
6
])
# Remove the '.index'
# Convert old format to new format if needed from a PyTorch state_dict
old_keys
=
[]
new_keys
=
[]
for
key
in
state_dict
.
keys
():
new_key
=
None
if
'gamma'
in
key
:
new_key
=
key
.
replace
(
'gamma'
,
'weight'
)
if
'beta'
in
key
:
new_key
=
key
.
replace
(
'beta'
,
'bias'
)
if
new_key
:
old_keys
.
append
(
key
)
new_keys
.
append
(
new_key
)
for
old_key
,
new_key
in
zip
(
old_keys
,
new_keys
):
state_dict
[
new_key
]
=
state_dict
.
pop
(
old_key
)
# Load from a PyTorch state_dict
missing_keys
=
[]
missing_keys
=
[]
unexpected_keys
=
[]
unexpected_keys
=
[]
error_msgs
=
[]
error_msgs
=
[]
# copy state_dict so _load_from_state_dict can modify it
metadata
=
getattr
(
state_dict
,
'_metadata'
,
None
)
if
from_tf
:
state_dict
=
state_dict
.
copy
()
if
resolved_archive_file
.
endswith
(
'.index'
):
if
metadata
is
not
None
:
# Load from a TensorFlow 1.X checkpoint - provided by original authors
state_dict
.
_metadata
=
metadata
model
=
cls
.
load_tf_weights
(
model
,
config
,
resolved_archive_file
[:
-
6
])
# Remove the '.index'
else
:
def
load
(
module
,
prefix
=
''
):
# Load from our TensorFlow 2.0 checkpoints
local_metadata
=
{}
if
metadata
is
None
else
metadata
.
get
(
prefix
[:
-
1
],
{})
try
:
module
.
_load_from_state_dict
(
from
pytorch_transformers
import
load_tf2_checkpoint_in_pytorch_model
state_dict
,
prefix
,
local_metadata
,
True
,
missing_keys
,
unexpected_keys
,
error_msgs
)
model
=
load_tf2_checkpoint_in_pytorch_model
(
model
,
resolved_archive_file
,
allow_missing_keys
=
True
)
for
name
,
child
in
module
.
_modules
.
items
():
except
ImportError
as
e
:
if
child
is
not
None
:
logger
.
error
(
"Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed. Please see "
load
(
child
,
prefix
+
name
+
'.'
)
"https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions."
)
raise
e
# Make sure we are able to load base models as well as derived models (with heads)
else
:
start_prefix
=
''
# Convert old format to new format if needed from a PyTorch state_dict
model_to_load
=
model
old_keys
=
[]
if
not
hasattr
(
model
,
cls
.
base_model_prefix
)
and
any
(
s
.
startswith
(
cls
.
base_model_prefix
)
for
s
in
state_dict
.
keys
()):
new_keys
=
[]
start_prefix
=
cls
.
base_model_prefix
+
'.'
for
key
in
state_dict
.
keys
():
if
hasattr
(
model
,
cls
.
base_model_prefix
)
and
not
any
(
s
.
startswith
(
cls
.
base_model_prefix
)
for
s
in
state_dict
.
keys
()):
new_key
=
None
model_to_load
=
getattr
(
model
,
cls
.
base_model_prefix
)
if
'gamma'
in
key
:
new_key
=
key
.
replace
(
'gamma'
,
'weight'
)
load
(
model_to_load
,
prefix
=
start_prefix
)
if
'beta'
in
key
:
if
len
(
missing_keys
)
>
0
:
new_key
=
key
.
replace
(
'beta'
,
'bias'
)
logger
.
info
(
"Weights of {} not initialized from pretrained model: {}"
.
format
(
if
new_key
:
model
.
__class__
.
__name__
,
missing_keys
))
old_keys
.
append
(
key
)
if
len
(
unexpected_keys
)
>
0
:
new_keys
.
append
(
new_key
)
logger
.
info
(
"Weights from pretrained model not used in {}: {}"
.
format
(
for
old_key
,
new_key
in
zip
(
old_keys
,
new_keys
):
model
.
__class__
.
__name__
,
unexpected_keys
))
state_dict
[
new_key
]
=
state_dict
.
pop
(
old_key
)
if
len
(
error_msgs
)
>
0
:
raise
RuntimeError
(
'Error(s) in loading state_dict for {}:
\n\t
{}'
.
format
(
# copy state_dict so _load_from_state_dict can modify it
model
.
__class__
.
__name__
,
"
\n\t
"
.
join
(
error_msgs
)))
metadata
=
getattr
(
state_dict
,
'_metadata'
,
None
)
state_dict
=
state_dict
.
copy
()
if
metadata
is
not
None
:
state_dict
.
_metadata
=
metadata
def
load
(
module
,
prefix
=
''
):
local_metadata
=
{}
if
metadata
is
None
else
metadata
.
get
(
prefix
[:
-
1
],
{})
module
.
_load_from_state_dict
(
state_dict
,
prefix
,
local_metadata
,
True
,
missing_keys
,
unexpected_keys
,
error_msgs
)
for
name
,
child
in
module
.
_modules
.
items
():
if
child
is
not
None
:
load
(
child
,
prefix
+
name
+
'.'
)
# Make sure we are able to load base models as well as derived models (with heads)
start_prefix
=
''
model_to_load
=
model
if
not
hasattr
(
model
,
cls
.
base_model_prefix
)
and
any
(
s
.
startswith
(
cls
.
base_model_prefix
)
for
s
in
state_dict
.
keys
()):
start_prefix
=
cls
.
base_model_prefix
+
'.'
if
hasattr
(
model
,
cls
.
base_model_prefix
)
and
not
any
(
s
.
startswith
(
cls
.
base_model_prefix
)
for
s
in
state_dict
.
keys
()):
model_to_load
=
getattr
(
model
,
cls
.
base_model_prefix
)
load
(
model_to_load
,
prefix
=
start_prefix
)
if
len
(
missing_keys
)
>
0
:
logger
.
info
(
"Weights of {} not initialized from pretrained model: {}"
.
format
(
model
.
__class__
.
__name__
,
missing_keys
))
if
len
(
unexpected_keys
)
>
0
:
logger
.
info
(
"Weights from pretrained model not used in {}: {}"
.
format
(
model
.
__class__
.
__name__
,
unexpected_keys
))
if
len
(
error_msgs
)
>
0
:
raise
RuntimeError
(
'Error(s) in loading state_dict for {}:
\n\t
{}'
.
format
(
model
.
__class__
.
__name__
,
"
\n\t
"
.
join
(
error_msgs
)))
if
hasattr
(
model
,
'tie_weights'
):
if
hasattr
(
model
,
'tie_weights'
):
model
.
tie_weights
()
# make sure word embedding weights are still tied
model
.
tie_weights
()
# make sure word embedding weights are still tied
...
...
pytorch_transformers/tests/modeling_common_test.py
View file @
271f2136
...
@@ -200,6 +200,9 @@ class CommonTestCases:
...
@@ -200,6 +200,9 @@ class CommonTestCases:
hidden_states
=
outputs
[
-
2
]
hidden_states
=
outputs
[
-
2
]
# Remove Nan
# Remove Nan
for
t
in
attentions
:
self
.
assertLess
(
torch
.
sum
(
torch
.
isnan
(
t
)),
t
.
numel
()
/
4
)
# Check we don't have more than 25% nans (arbitrary)
attentions
=
[
t
.
masked_fill
(
torch
.
isnan
(
t
),
0.0
)
for
t
in
attentions
]
# remove them (the test is less complete)
self
.
assertIsNotNone
(
multihead_outputs
)
self
.
assertIsNotNone
(
multihead_outputs
)
self
.
assertEqual
(
len
(
multihead_outputs
),
self
.
model_tester
.
num_hidden_layers
)
self
.
assertEqual
(
len
(
multihead_outputs
),
self
.
model_tester
.
num_hidden_layers
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment