Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
f19dad61
Commit
f19dad61
authored
Dec 12, 2019
by
thomwolf
Browse files
fixing XLM conversion tests with dummy input
parent
fafd4c86
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
17 additions
and
3 deletions
+17
-3
transformers/modeling_tf_pytorch_utils.py
transformers/modeling_tf_pytorch_utils.py
+5
-1
transformers/modeling_tf_xlm.py
transformers/modeling_tf_xlm.py
+1
-1
transformers/modeling_xlm.py
transformers/modeling_xlm.py
+11
-1
No files found.
transformers/modeling_tf_pytorch_utils.py
View file @
f19dad61
...
@@ -78,6 +78,7 @@ def load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_i
...
@@ -78,6 +78,7 @@ def load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_i
logger
.
info
(
"Loading PyTorch weights from {}"
.
format
(
pt_path
))
logger
.
info
(
"Loading PyTorch weights from {}"
.
format
(
pt_path
))
pt_state_dict
=
torch
.
load
(
pt_path
,
map_location
=
'cpu'
)
pt_state_dict
=
torch
.
load
(
pt_path
,
map_location
=
'cpu'
)
logger
.
info
(
"PyTorch checkpoint contains {:,} parameters"
.
format
(
sum
(
t
.
numel
()
for
t
in
pt_state_dict
.
values
())))
return
load_pytorch_weights_in_tf2_model
(
tf_model
,
pt_state_dict
,
tf_inputs
=
tf_inputs
,
allow_missing_keys
=
allow_missing_keys
)
return
load_pytorch_weights_in_tf2_model
(
tf_model
,
pt_state_dict
,
tf_inputs
=
tf_inputs
,
allow_missing_keys
=
allow_missing_keys
)
...
@@ -134,7 +135,7 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
...
@@ -134,7 +135,7 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
start_prefix_to_remove
=
tf_model
.
base_model_prefix
+
'.'
start_prefix_to_remove
=
tf_model
.
base_model_prefix
+
'.'
symbolic_weights
=
tf_model
.
trainable_weights
+
tf_model
.
non_trainable_weights
symbolic_weights
=
tf_model
.
trainable_weights
+
tf_model
.
non_trainable_weights
tf_loaded_numel
=
0
weight_value_tuples
=
[]
weight_value_tuples
=
[]
all_pytorch_weights
=
set
(
list
(
pt_state_dict
.
keys
()))
all_pytorch_weights
=
set
(
list
(
pt_state_dict
.
keys
()))
for
symbolic_weight
in
symbolic_weights
:
for
symbolic_weight
in
symbolic_weights
:
...
@@ -159,6 +160,7 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
...
@@ -159,6 +160,7 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
e
.
args
+=
(
symbolic_weight
.
shape
,
array
.
shape
)
e
.
args
+=
(
symbolic_weight
.
shape
,
array
.
shape
)
raise
e
raise
e
tf_loaded_numel
+=
array
.
size
# logger.warning("Initialize TF weight {}".format(symbolic_weight.name))
# logger.warning("Initialize TF weight {}".format(symbolic_weight.name))
weight_value_tuples
.
append
((
symbolic_weight
,
array
))
weight_value_tuples
.
append
((
symbolic_weight
,
array
))
...
@@ -169,6 +171,8 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
...
@@ -169,6 +171,8 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
if
tf_inputs
is
not
None
:
if
tf_inputs
is
not
None
:
tfo
=
tf_model
(
tf_inputs
,
training
=
False
)
# Make sure restore ops are run
tfo
=
tf_model
(
tf_inputs
,
training
=
False
)
# Make sure restore ops are run
logger
.
info
(
"Loaded {:,} parameters in the TF 2.0 model."
.
format
(
tf_loaded_numel
))
logger
.
info
(
"Weights or buffers not loaded from PyTorch model: {}"
.
format
(
all_pytorch_weights
))
logger
.
info
(
"Weights or buffers not loaded from PyTorch model: {}"
.
format
(
all_pytorch_weights
))
return
tf_model
return
tf_model
...
...
transformers/modeling_tf_xlm.py
View file @
f19dad61
...
@@ -460,7 +460,7 @@ class TFXLMPreTrainedModel(TFPreTrainedModel):
...
@@ -460,7 +460,7 @@ class TFXLMPreTrainedModel(TFPreTrainedModel):
langs_list
=
tf
.
constant
([[
1
,
1
,
0
,
0
,
1
],
[
1
,
1
,
1
,
0
,
0
],
[
1
,
0
,
0
,
1
,
1
]])
langs_list
=
tf
.
constant
([[
1
,
1
,
0
,
0
,
1
],
[
1
,
1
,
1
,
0
,
0
],
[
1
,
0
,
0
,
1
,
1
]])
else
:
else
:
langs_list
=
None
langs_list
=
None
return
[
inputs_list
,
attns_list
,
langs_list
]
return
{
'input_ids'
:
inputs_list
,
'attention_mask'
:
attns_list
,
'langs'
:
langs_list
}
XLM_START_DOCSTRING
=
r
""" The XLM model was proposed in
XLM_START_DOCSTRING
=
r
""" The XLM model was proposed in
...
...
transformers/modeling_xlm.py
View file @
f19dad61
...
@@ -227,6 +227,16 @@ class XLMPreTrainedModel(PreTrainedModel):
...
@@ -227,6 +227,16 @@ class XLMPreTrainedModel(PreTrainedModel):
def
__init__
(
self
,
*
inputs
,
**
kwargs
):
def
__init__
(
self
,
*
inputs
,
**
kwargs
):
super
(
XLMPreTrainedModel
,
self
).
__init__
(
*
inputs
,
**
kwargs
)
super
(
XLMPreTrainedModel
,
self
).
__init__
(
*
inputs
,
**
kwargs
)
@
property
def
dummy_inputs
(
self
):
inputs_list
=
torch
.
tensor
([[
7
,
6
,
0
,
0
,
1
],
[
1
,
2
,
3
,
0
,
0
],
[
0
,
0
,
0
,
4
,
5
]])
attns_list
=
torch
.
tensor
([[
1
,
1
,
0
,
0
,
1
],
[
1
,
1
,
1
,
0
,
0
],
[
1
,
0
,
0
,
1
,
1
]])
if
self
.
config
.
use_lang_emb
and
self
.
config
.
n_langs
>
1
:
langs_list
=
torch
.
tensor
([[
1
,
1
,
0
,
0
,
1
],
[
1
,
1
,
1
,
0
,
0
],
[
1
,
0
,
0
,
1
,
1
]])
else
:
langs_list
=
None
return
{
'input_ids'
:
inputs_list
,
'attention_mask'
:
attns_list
,
'langs'
:
langs_list
}
def
_init_weights
(
self
,
module
):
def
_init_weights
(
self
,
module
):
""" Initialize the weights. """
""" Initialize the weights. """
if
isinstance
(
module
,
nn
.
Embedding
):
if
isinstance
(
module
,
nn
.
Embedding
):
...
@@ -646,7 +656,7 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
...
@@ -646,7 +656,7 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
langs
=
langs
,
langs
=
langs
,
token_type_ids
=
token_type_ids
,
token_type_ids
=
token_type_ids
,
position_ids
=
position_ids
,
position_ids
=
position_ids
,
lengths
=
lengths
,
lengths
=
lengths
,
cache
=
cache
,
cache
=
cache
,
head_mask
=
head_mask
,
head_mask
=
head_mask
,
inputs_embeds
=
inputs_embeds
)
inputs_embeds
=
inputs_embeds
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment