Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
969d3ae9
Commit
969d3ae9
authored
Sep 11, 2019
by
thomwolf
Browse files
XLMWithLMHead fixed - standardize conversion
parent
646711e1
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
20 additions
and
21 deletions
+20
-21
pytorch_transformers/modeling_tf_bert.py
pytorch_transformers/modeling_tf_bert.py
+1
-1
pytorch_transformers/modeling_tf_gpt2.py
pytorch_transformers/modeling_tf_gpt2.py
+1
-1
pytorch_transformers/modeling_tf_pytorch_utils.py
pytorch_transformers/modeling_tf_pytorch_utils.py
+11
-10
pytorch_transformers/modeling_tf_xlm.py
pytorch_transformers/modeling_tf_xlm.py
+4
-6
pytorch_transformers/modeling_tf_xlnet.py
pytorch_transformers/modeling_tf_xlnet.py
+1
-1
pytorch_transformers/modeling_xlm.py
pytorch_transformers/modeling_xlm.py
+2
-2
No files found.
pytorch_transformers/modeling_tf_bert.py
View file @
969d3ae9
...
@@ -57,7 +57,7 @@ def load_bert_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
...
@@ -57,7 +57,7 @@ def load_bert_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
inputs_list
=
[[
7
,
6
,
0
,
0
,
1
],
[
1
,
2
,
3
,
0
,
0
],
[
0
,
0
,
0
,
4
,
5
]]
inputs_list
=
[[
7
,
6
,
0
,
0
,
1
],
[
1
,
2
,
3
,
0
,
0
],
[
0
,
0
,
0
,
4
,
5
]]
tf_inputs
=
tf
.
constant
(
inputs_list
)
tf_inputs
=
tf
.
constant
(
inputs_list
)
tfo
=
tf_model
(
tf_inputs
,
training
=
False
)
tfo
=
tf_model
(
tf_inputs
,
training
=
False
)
return
load_pytorch_checkpoint_in_tf2_model
(
tf_model
,
pytorch_checkpoint_path
)
return
load_pytorch_checkpoint_in_tf2_model
(
tf_model
,
pytorch_checkpoint_path
,
tf_inputs
=
tf_inputs
)
def
gelu
(
x
):
def
gelu
(
x
):
...
...
pytorch_transformers/modeling_tf_gpt2.py
View file @
969d3ae9
...
@@ -46,7 +46,7 @@ def load_gpt2_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
...
@@ -46,7 +46,7 @@ def load_gpt2_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
inputs_list
=
[[
7
,
6
,
0
,
0
,
1
],
[
1
,
2
,
3
,
0
,
0
],
[
0
,
0
,
0
,
4
,
5
]]
inputs_list
=
[[
7
,
6
,
0
,
0
,
1
],
[
1
,
2
,
3
,
0
,
0
],
[
0
,
0
,
0
,
4
,
5
]]
tf_inputs
=
tf
.
constant
(
inputs_list
)
tf_inputs
=
tf
.
constant
(
inputs_list
)
tfo
=
tf_model
(
tf_inputs
,
training
=
False
)
tfo
=
tf_model
(
tf_inputs
,
training
=
False
)
return
load_pytorch_checkpoint_in_tf2_model
(
tf_model
,
pytorch_checkpoint_path
)
return
load_pytorch_checkpoint_in_tf2_model
(
tf_model
,
pytorch_checkpoint_path
,
tf_inputs
=
tf_inputs
)
def
gelu
(
x
):
def
gelu
(
x
):
...
...
pytorch_transformers/modeling_tf_pytorch_utils.py
View file @
969d3ae9
...
@@ -19,34 +19,34 @@ from __future__ import (absolute_import, division, print_function,
...
@@ -19,34 +19,34 @@ from __future__ import (absolute_import, division, print_function,
unicode_literals
)
unicode_literals
)
import
logging
import
logging
import
os
from
pytorch_transformers
import
is_tf_available
,
is_torch_available
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
def
load_pytorch_checkpoint_in_tf2_model
(
tf_model
,
pytorch_checkpoint_path
):
def
load_pytorch_checkpoint_in_tf2_model
(
tf_model
,
pytorch_checkpoint_path
,
tf_inputs
=
None
):
""" Load pytorch checkpoints in a TF 2.0 model
""" Load pytorch checkpoints in a TF 2.0 model
Conventions for TF2.0 scopes -> PyTorch attribute names conversions:
Conventions for TF2.0 scopes -> PyTorch attribute names conversions:
- '$1___$2' is replaced by $2 (can be used to duplicate or remove layers in TF2.0 vs PyTorch)
- '$1___$2' is replaced by $2 (can be used to duplicate or remove layers in TF2.0 vs PyTorch)
- '_._' is replaced by a new level separation (can be used to convert TF2.0 lists in PyTorch nn.ModulesList)
- '_._' is replaced by a new level separation (can be used to convert TF2.0 lists in PyTorch nn.ModulesList)
"""
"""
if
not
is_tf_available
()
or
not
is_torch_available
():
try
:
import
tensorflow
as
tf
import
torch
except
ImportError
as
e
:
logger
.
error
(
"Loading a PyTorch model in TensorFlow, requires both PyTorch and TensorFlow to be installed. Please see "
logger
.
error
(
"Loading a PyTorch model in TensorFlow, requires both PyTorch and TensorFlow to be installed. Please see "
"https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions."
)
"https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions."
)
raise
ImportError
raise
e
import
torch
pt_path
=
os
.
path
.
abspath
(
pytorch_checkpoint_path
)
pt_path
=
os
.
path
.
abspath
(
pytorch_checkpoint_path
)
logger
.
info
(
"Loading PyTorch weights from {}"
.
format
(
pt_path
))
logger
.
info
(
"Loading PyTorch weights from {}"
.
format
(
pt_path
))
pt_state_dict
=
torch
.
load
(
pt_path
,
map_location
=
'cpu'
)
pt_state_dict
=
torch
.
load
(
pt_path
,
map_location
=
'cpu'
)
return
load_pytorch_state_dict_in_tf2_model
(
tf_model
,
pt_state_dict
)
return
load_pytorch_state_dict_in_tf2_model
(
tf_model
,
pt_state_dict
,
tf_inputs
=
tf_inputs
)
def
load_pytorch_state_dict_in_tf2_model
(
tf_model
,
pt_state_dict
):
def
load_pytorch_state_dict_in_tf2_model
(
tf_model
,
pt_state_dict
,
tf_inputs
=
None
):
""" Load pytorch state_dict in a TF 2.0 model.
""" Load pytorch state_dict in a TF 2.0 model.
Conventions for TF2.0 scopes -> PyTorch attribute names conversions:
Conventions for TF2.0 scopes -> PyTorch attribute names conversions:
- '$1___$2' is replaced by $2 (can be used to duplicate or remove layers in TF2.0 vs PyTorch)
- '$1___$2' is replaced by $2 (can be used to duplicate or remove layers in TF2.0 vs PyTorch)
...
@@ -102,7 +102,8 @@ def load_pytorch_state_dict_in_tf2_model(tf_model, pt_state_dict):
...
@@ -102,7 +102,8 @@ def load_pytorch_state_dict_in_tf2_model(tf_model, pt_state_dict):
K
.
batch_set_value
(
weight_value_tuples
)
K
.
batch_set_value
(
weight_value_tuples
)
tfo
=
tf_model
(
tf_inputs
,
training
=
False
)
# Make sure restore ops are run
if
tf_inputs
is
not
None
:
tfo
=
tf_model
(
tf_inputs
,
training
=
False
)
# Make sure restore ops are run
logger
.
info
(
"Weights or buffers not loaded from PyTorch model: {}"
.
format
(
all_pytorch_weights
))
logger
.
info
(
"Weights or buffers not loaded from PyTorch model: {}"
.
format
(
all_pytorch_weights
))
...
...
pytorch_transformers/modeling_tf_xlm.py
View file @
969d3ae9
...
@@ -50,11 +50,9 @@ def load_xlm_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
...
@@ -50,11 +50,9 @@ def load_xlm_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
inputs_list
=
[[
7
,
6
,
0
,
0
,
1
],
[
1
,
2
,
3
,
0
,
0
],
[
0
,
0
,
0
,
4
,
5
]]
inputs_list
=
[[
7
,
6
,
0
,
0
,
1
],
[
1
,
2
,
3
,
0
,
0
],
[
0
,
0
,
0
,
4
,
5
]]
attns_list
=
[[
1
,
1
,
0
,
0
,
1
],
[
1
,
1
,
1
,
0
,
0
],
[
1
,
0
,
0
,
1
,
1
]]
attns_list
=
[[
1
,
1
,
0
,
0
,
1
],
[
1
,
1
,
1
,
0
,
0
],
[
1
,
0
,
0
,
1
,
1
]]
langs_list
=
[[
1
,
1
,
0
,
0
,
1
],
[
1
,
1
,
1
,
0
,
0
],
[
1
,
0
,
0
,
1
,
1
]]
langs_list
=
[[
1
,
1
,
0
,
0
,
1
],
[
1
,
1
,
1
,
0
,
0
],
[
1
,
0
,
0
,
1
,
1
]]
tf_inputs
=
tf
.
constant
(
inputs_list
)
tf_inputs
=
[
tf
.
constant
(
inputs_list
),
tf
.
constant
(
attns_list
),
tf
.
constant
(
langs_list
)]
tf_attns
=
tf
.
constant
(
attns_list
)
tfo
=
tf_model
(
tf_inputs
,
training
=
False
)
tf_langs
=
tf
.
constant
(
langs_list
)
return
load_pytorch_checkpoint_in_tf2_model
(
tf_model
,
pytorch_checkpoint_path
,
tf_inputs
=
tf_inputs
)
tfo
=
tf_model
([
tf_inputs
,
tf_attns
,
tf_langs
],
training
=
False
)
return
load_pytorch_checkpoint_in_tf2_model
(
tf_model
,
pytorch_checkpoint_path
)
def
create_sinusoidal_embeddings
(
n_pos
,
dim
,
out
):
def
create_sinusoidal_embeddings
(
n_pos
,
dim
,
out
):
...
@@ -614,7 +612,7 @@ class TFXLMWithLMHeadModel(TFXLMPreTrainedModel):
...
@@ -614,7 +612,7 @@ class TFXLMWithLMHeadModel(TFXLMPreTrainedModel):
"""
"""
def
__init__
(
self
,
config
,
*
inputs
,
**
kwargs
):
def
__init__
(
self
,
config
,
*
inputs
,
**
kwargs
):
super
(
TFXLMWithLMHeadModel
,
self
).
__init__
(
config
,
*
inputs
,
**
kwargs
)
super
(
TFXLMWithLMHeadModel
,
self
).
__init__
(
config
,
*
inputs
,
**
kwargs
)
self
.
transformer
=
TFXLMMainLayer
(
config
,
name
=
'transformer
___
'
)
self
.
transformer
=
TFXLMMainLayer
(
config
,
name
=
'transformer'
)
self
.
pred_layer
=
TFXLMPredLayer
(
config
,
self
.
transformer
.
embeddings
,
name
=
'pred_layer_._proj'
)
self
.
pred_layer
=
TFXLMPredLayer
(
config
,
self
.
transformer
.
embeddings
,
name
=
'pred_layer_._proj'
)
...
...
pytorch_transformers/modeling_tf_xlnet.py
View file @
969d3ae9
...
@@ -45,7 +45,7 @@ def load_xlnet_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
...
@@ -45,7 +45,7 @@ def load_xlnet_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
inputs_list
=
[[
7
,
6
,
0
,
0
,
1
],
[
1
,
2
,
3
,
0
,
0
],
[
0
,
0
,
0
,
4
,
5
]]
inputs_list
=
[[
7
,
6
,
0
,
0
,
1
],
[
1
,
2
,
3
,
0
,
0
],
[
0
,
0
,
0
,
4
,
5
]]
tf_inputs
=
tf
.
constant
(
inputs_list
)
tf_inputs
=
tf
.
constant
(
inputs_list
)
tfo
=
tf_model
(
tf_inputs
,
training
=
False
)
# build the network
tfo
=
tf_model
(
tf_inputs
,
training
=
False
)
# build the network
return
load_pytorch_checkpoint_in_tf2_model
(
tf_model
,
pytorch_checkpoint_path
)
return
load_pytorch_checkpoint_in_tf2_model
(
tf_model
,
pytorch_checkpoint_path
,
tf_inputs
=
tf_inputs
)
def
gelu
(
x
):
def
gelu
(
x
):
...
...
pytorch_transformers/modeling_xlm.py
View file @
969d3ae9
...
@@ -563,10 +563,10 @@ class XLMPredLayer(nn.Module):
...
@@ -563,10 +563,10 @@ class XLMPredLayer(nn.Module):
"""
"""
outputs
=
()
outputs
=
()
if
self
.
asm
is
False
:
if
self
.
asm
is
False
:
scores
=
self
.
proj
(
x
)
.
view
(
-
1
,
self
.
n_words
)
scores
=
self
.
proj
(
x
)
outputs
=
(
scores
,)
+
outputs
outputs
=
(
scores
,)
+
outputs
if
y
is
not
None
:
if
y
is
not
None
:
loss
=
F
.
cross_entropy
(
scores
,
y
,
reduction
=
'elementwise_mean'
)
loss
=
F
.
cross_entropy
(
scores
.
view
(
-
1
,
self
.
n_words
)
,
y
,
reduction
=
'elementwise_mean'
)
outputs
=
(
loss
,)
+
outputs
outputs
=
(
loss
,)
+
outputs
else
:
else
:
scores
=
self
.
proj
.
log_prob
(
x
)
scores
=
self
.
proj
.
log_prob
(
x
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment