Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
f19dad61
Commit
f19dad61
authored
Dec 12, 2019
by
thomwolf
Browse files
fixing XLM conversion tests with dummy input
parent
fafd4c86
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
17 additions
and
3 deletions
+17
-3
transformers/modeling_tf_pytorch_utils.py
transformers/modeling_tf_pytorch_utils.py
+5
-1
transformers/modeling_tf_xlm.py
transformers/modeling_tf_xlm.py
+1
-1
transformers/modeling_xlm.py
transformers/modeling_xlm.py
+11
-1
No files found.
transformers/modeling_tf_pytorch_utils.py
View file @
f19dad61
...
@@ -78,6 +78,7 @@ def load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_i
...
@@ -78,6 +78,7 @@ def load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_i
logger
.
info
(
"Loading PyTorch weights from {}"
.
format
(
pt_path
))
logger
.
info
(
"Loading PyTorch weights from {}"
.
format
(
pt_path
))
pt_state_dict
=
torch
.
load
(
pt_path
,
map_location
=
'cpu'
)
pt_state_dict
=
torch
.
load
(
pt_path
,
map_location
=
'cpu'
)
logger
.
info
(
"PyTorch checkpoint contains {:,} parameters"
.
format
(
sum
(
t
.
numel
()
for
t
in
pt_state_dict
.
values
())))
return
load_pytorch_weights_in_tf2_model
(
tf_model
,
pt_state_dict
,
tf_inputs
=
tf_inputs
,
allow_missing_keys
=
allow_missing_keys
)
return
load_pytorch_weights_in_tf2_model
(
tf_model
,
pt_state_dict
,
tf_inputs
=
tf_inputs
,
allow_missing_keys
=
allow_missing_keys
)
...
@@ -134,7 +135,7 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
...
@@ -134,7 +135,7 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
start_prefix_to_remove
=
tf_model
.
base_model_prefix
+
'.'
start_prefix_to_remove
=
tf_model
.
base_model_prefix
+
'.'
symbolic_weights
=
tf_model
.
trainable_weights
+
tf_model
.
non_trainable_weights
symbolic_weights
=
tf_model
.
trainable_weights
+
tf_model
.
non_trainable_weights
tf_loaded_numel
=
0
weight_value_tuples
=
[]
weight_value_tuples
=
[]
all_pytorch_weights
=
set
(
list
(
pt_state_dict
.
keys
()))
all_pytorch_weights
=
set
(
list
(
pt_state_dict
.
keys
()))
for
symbolic_weight
in
symbolic_weights
:
for
symbolic_weight
in
symbolic_weights
:
...
@@ -159,6 +160,7 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
...
@@ -159,6 +160,7 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
e
.
args
+=
(
symbolic_weight
.
shape
,
array
.
shape
)
e
.
args
+=
(
symbolic_weight
.
shape
,
array
.
shape
)
raise
e
raise
e
tf_loaded_numel
+=
array
.
size
# logger.warning("Initialize TF weight {}".format(symbolic_weight.name))
# logger.warning("Initialize TF weight {}".format(symbolic_weight.name))
weight_value_tuples
.
append
((
symbolic_weight
,
array
))
weight_value_tuples
.
append
((
symbolic_weight
,
array
))
...
@@ -169,6 +171,8 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
...
@@ -169,6 +171,8 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
if
tf_inputs
is
not
None
:
if
tf_inputs
is
not
None
:
tfo
=
tf_model
(
tf_inputs
,
training
=
False
)
# Make sure restore ops are run
tfo
=
tf_model
(
tf_inputs
,
training
=
False
)
# Make sure restore ops are run
logger
.
info
(
"Loaded {:,} parameters in the TF 2.0 model."
.
format
(
tf_loaded_numel
))
logger
.
info
(
"Weights or buffers not loaded from PyTorch model: {}"
.
format
(
all_pytorch_weights
))
logger
.
info
(
"Weights or buffers not loaded from PyTorch model: {}"
.
format
(
all_pytorch_weights
))
return
tf_model
return
tf_model
...
...
transformers/modeling_tf_xlm.py
View file @
f19dad61
...
@@ -460,7 +460,7 @@ class TFXLMPreTrainedModel(TFPreTrainedModel):
...
@@ -460,7 +460,7 @@ class TFXLMPreTrainedModel(TFPreTrainedModel):
langs_list
=
tf
.
constant
([[
1
,
1
,
0
,
0
,
1
],
[
1
,
1
,
1
,
0
,
0
],
[
1
,
0
,
0
,
1
,
1
]])
langs_list
=
tf
.
constant
([[
1
,
1
,
0
,
0
,
1
],
[
1
,
1
,
1
,
0
,
0
],
[
1
,
0
,
0
,
1
,
1
]])
else
:
else
:
langs_list
=
None
langs_list
=
None
return
[
inputs_list
,
attns_list
,
langs_list
]
return
{
'input_ids'
:
inputs_list
,
'attention_mask'
:
attns_list
,
'langs'
:
langs_list
}
XLM_START_DOCSTRING
=
r
""" The XLM model was proposed in
XLM_START_DOCSTRING
=
r
""" The XLM model was proposed in
...
...
transformers/modeling_xlm.py
View file @
f19dad61
...
@@ -227,6 +227,16 @@ class XLMPreTrainedModel(PreTrainedModel):
...
@@ -227,6 +227,16 @@ class XLMPreTrainedModel(PreTrainedModel):
def
__init__
(
self
,
*
inputs
,
**
kwargs
):
def
__init__
(
self
,
*
inputs
,
**
kwargs
):
super
(
XLMPreTrainedModel
,
self
).
__init__
(
*
inputs
,
**
kwargs
)
super
(
XLMPreTrainedModel
,
self
).
__init__
(
*
inputs
,
**
kwargs
)
@
property
def
dummy_inputs
(
self
):
inputs_list
=
torch
.
tensor
([[
7
,
6
,
0
,
0
,
1
],
[
1
,
2
,
3
,
0
,
0
],
[
0
,
0
,
0
,
4
,
5
]])
attns_list
=
torch
.
tensor
([[
1
,
1
,
0
,
0
,
1
],
[
1
,
1
,
1
,
0
,
0
],
[
1
,
0
,
0
,
1
,
1
]])
if
self
.
config
.
use_lang_emb
and
self
.
config
.
n_langs
>
1
:
langs_list
=
torch
.
tensor
([[
1
,
1
,
0
,
0
,
1
],
[
1
,
1
,
1
,
0
,
0
],
[
1
,
0
,
0
,
1
,
1
]])
else
:
langs_list
=
None
return
{
'input_ids'
:
inputs_list
,
'attention_mask'
:
attns_list
,
'langs'
:
langs_list
}
def
_init_weights
(
self
,
module
):
def
_init_weights
(
self
,
module
):
""" Initialize the weights. """
""" Initialize the weights. """
if
isinstance
(
module
,
nn
.
Embedding
):
if
isinstance
(
module
,
nn
.
Embedding
):
...
@@ -646,7 +656,7 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
...
@@ -646,7 +656,7 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
langs
=
langs
,
langs
=
langs
,
token_type_ids
=
token_type_ids
,
token_type_ids
=
token_type_ids
,
position_ids
=
position_ids
,
position_ids
=
position_ids
,
lengths
=
lengths
,
lengths
=
lengths
,
cache
=
cache
,
cache
=
cache
,
head_mask
=
head_mask
,
head_mask
=
head_mask
,
inputs_embeds
=
inputs_embeds
)
inputs_embeds
=
inputs_embeds
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment