Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
7de17404
Commit
7de17404
authored
Jun 25, 2019
by
thomwolf
Browse files
add ability to restore fine-tuned TF model
parent
7334bf6c
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
33 additions
and
15 deletions
+33
-15
pytorch_pretrained_bert/convert_xlnet_checkpoint_to_pytorch.py
...ch_pretrained_bert/convert_xlnet_checkpoint_to_pytorch.py
+23
-7
pytorch_pretrained_bert/modeling_xlnet.py
pytorch_pretrained_bert/modeling_xlnet.py
+10
-8
No files found.
pytorch_pretrained_bert/convert_xlnet_checkpoint_to_pytorch.py
View file @
7de17404
...
@@ -24,16 +24,27 @@ import torch
...
@@ -24,16 +24,27 @@ import torch
from
pytorch_pretrained_bert.modeling_xlnet
import
(
CONFIG_NAME
,
WEIGHTS_NAME
,
from
pytorch_pretrained_bert.modeling_xlnet
import
(
CONFIG_NAME
,
WEIGHTS_NAME
,
XLNetConfig
,
XLNetRunConfig
,
XLNetConfig
,
XLNetRunConfig
,
XLNetLMHeadModel
,
load_tf_weights_in_xlnet
)
XLNetLMHeadModel
,
XLNetForQuestionAnswering
,
XLNetForSequenceClassification
,
load_tf_weights_in_xlnet
)
def
convert_xlnet_checkpoint_to_pytorch
(
tf_checkpoint_path
,
bert_config_file
,
pytorch_dump_folder_path
):
GLUE_TASKS
=
[
"cola"
,
"mnli"
,
"mnli-mm"
,
"mrpc"
,
"sst-2"
,
"sts-b"
,
"qqp"
,
"qnli"
,
"rte"
,
"wnli"
]
def
convert_xlnet_checkpoint_to_pytorch
(
tf_checkpoint_path
,
bert_config_file
,
pytorch_dump_folder_path
,
finetuning_task
=
None
):
# Initialise PyTorch model
# Initialise PyTorch model
config
=
XLNetConfig
.
from_json_file
(
bert_config_file
)
config
=
XLNetConfig
.
from_json_file
(
bert_config_file
)
print
(
"Building PyTorch model from configuration: {}"
.
format
(
str
(
config
)))
if
finetuning_task
is
not
None
and
finetuning_task
.
lower
()
in
GLUE_TASKS
:
model
=
XLNetLMHeadModel
(
config
)
model_class
=
XLNetLMHeadModel
elif
finetuning_task
is
not
None
and
'squad'
in
finetuning_task
.
lower
():
model_class
=
XLNetForQuestionAnswering
else
:
model_class
=
XLNetLMHeadModel
print
(
"Building PyTorch model {} from configuration: {}"
.
format
(
str
(
model_class
),
str
(
config
)))
model
=
model_class
(
config
)
# Load weights from tf checkpoint
# Load weights from tf checkpoint
load_tf_weights_in_xlnet
(
model
,
config
,
tf_checkpoint_path
)
load_tf_weights_in_xlnet
(
model
,
config
,
tf_checkpoint_path
,
finetuning_task
)
# Save pytorch-model
# Save pytorch-model
pytorch_weights_dump_path
=
os
.
path
.
join
(
pytorch_dump_folder_path
,
WEIGHTS_NAME
)
pytorch_weights_dump_path
=
os
.
path
.
join
(
pytorch_dump_folder_path
,
WEIGHTS_NAME
)
...
@@ -59,12 +70,17 @@ if __name__ == "__main__":
...
@@ -59,12 +70,17 @@ if __name__ == "__main__":
required
=
True
,
required
=
True
,
help
=
"The config json file corresponding to the pre-trained XLNet model.
\n
"
help
=
"The config json file corresponding to the pre-trained XLNet model.
\n
"
"This specifies the model architecture."
)
"This specifies the model architecture."
)
parser
.
add_argument
(
"--pytorch_dump_folder_path"
,
parser
.
add_argument
(
"--pytorch_dump_folder_path"
,
finetuning_task
default
=
None
,
default
=
None
,
type
=
str
,
type
=
str
,
required
=
True
,
required
=
True
,
help
=
"Path to the folder to store the PyTorch model or dataset/vocab."
)
help
=
"Path to the folder to store the PyTorch model or dataset/vocab."
)
parser
.
add_argument
(
"--finetuning_task"
,
default
=
None
,
type
=
str
,
help
=
"Name of a task on which the XLNet TensorFloaw model was fine-tuned"
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
convert_xlnet_checkpoint_to_pytorch
(
args
.
tf_checkpoint_path
,
convert_xlnet_checkpoint_to_pytorch
(
args
.
tf_checkpoint_path
,
args
.
xlnet_config_file
,
args
.
xlnet_config_file
,
args
.
pytorch_dump_folder_path
)
args
.
pytorch_dump_folder_path
,
args
.
finetuning_task
)
pytorch_pretrained_bert/modeling_xlnet.py
View file @
7de17404
...
@@ -46,7 +46,7 @@ XLNET_CONFIG_NAME = 'xlnet_config.json'
...
@@ -46,7 +46,7 @@ XLNET_CONFIG_NAME = 'xlnet_config.json'
TF_WEIGHTS_NAME
=
'model.ckpt'
TF_WEIGHTS_NAME
=
'model.ckpt'
def
build_tf_xlnet_to_pytorch_map
(
model
,
config
,
tf_weights
=
None
):
def
build_tf_xlnet_to_pytorch_map
(
model
,
config
,
tf_weights
=
None
,
finetuning_task
=
None
):
""" A map of modules from TF to PyTorch.
""" A map of modules from TF to PyTorch.
I use a map to keep the PyTorch model as
I use a map to keep the PyTorch model as
identical to the original PyTorch model as possible.
identical to the original PyTorch model as possible.
...
@@ -62,8 +62,10 @@ def build_tf_xlnet_to_pytorch_map(model, config, tf_weights=None):
...
@@ -62,8 +62,10 @@ def build_tf_xlnet_to_pytorch_map(model, config, tf_weights=None):
# We will load also the sequence summary
# We will load also the sequence summary
tf_to_pt_map
[
'model/sequnece_summary/summary/kernel'
]
=
model
.
sequence_summary
.
summary
.
weight
tf_to_pt_map
[
'model/sequnece_summary/summary/kernel'
]
=
model
.
sequence_summary
.
summary
.
weight
tf_to_pt_map
[
'model/sequnece_summary/summary/bias'
]
=
model
.
sequence_summary
.
summary
.
bias
tf_to_pt_map
[
'model/sequnece_summary/summary/bias'
]
=
model
.
sequence_summary
.
summary
.
bias
elif
hasattr
(
model
,
'proj_loss'
)
and
any
(
'model/regression'
in
name
for
name
in
tf_weights
.
keys
()):
elif
hasattr
(
model
,
'logits_proj'
)
and
finetuning_task
is
not
None
and
any
(
'model/regression'
in
name
for
name
in
tf_weights
.
keys
()):
raise
NotImplementedError
tf_to_pt_map
[
'model/regression_{}/logit/kernel'
.
format
(
finetuning_task
)]
=
model
.
logits_proj
.
weight
tf_to_pt_map
[
'model/regression_{}/logit/bias'
.
format
(
finetuning_task
)]
=
model
.
logits_proj
.
bias
# Now load the rest of the transformer
# Now load the rest of the transformer
model
=
model
.
transformer
model
=
model
.
transformer
...
@@ -113,7 +115,7 @@ def build_tf_xlnet_to_pytorch_map(model, config, tf_weights=None):
...
@@ -113,7 +115,7 @@ def build_tf_xlnet_to_pytorch_map(model, config, tf_weights=None):
'model/transformer/seg_embed'
:
seg_embed_list
})
'model/transformer/seg_embed'
:
seg_embed_list
})
return
tf_to_pt_map
return
tf_to_pt_map
def
load_tf_weights_in_xlnet
(
model
,
config
,
tf_path
):
def
load_tf_weights_in_xlnet
(
model
,
config
,
tf_path
,
finetuning_task
=
None
):
""" Load tf checkpoints in a pytorch model
""" Load tf checkpoints in a pytorch model
"""
"""
try
:
try
:
...
@@ -132,7 +134,7 @@ def load_tf_weights_in_xlnet(model, config, tf_path):
...
@@ -132,7 +134,7 @@ def load_tf_weights_in_xlnet(model, config, tf_path):
tf_weights
[
name
]
=
array
tf_weights
[
name
]
=
array
# Build TF to PyTorch weights loading map
# Build TF to PyTorch weights loading map
tf_to_pt_map
=
build_tf_xlnet_to_pytorch_map
(
model
,
config
,
tf_weights
)
tf_to_pt_map
=
build_tf_xlnet_to_pytorch_map
(
model
,
config
,
tf_weights
,
finetuning_task
)
for
name
,
pointer
in
tf_to_pt_map
.
items
():
for
name
,
pointer
in
tf_to_pt_map
.
items
():
print
(
"Importing {}"
.
format
(
name
))
print
(
"Importing {}"
.
format
(
name
))
...
@@ -1338,7 +1340,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
...
@@ -1338,7 +1340,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
self
.
sequence_summary
=
XLNetSequenceSummary
(
config
,
summary_type
=
summary_type
,
self
.
sequence_summary
=
XLNetSequenceSummary
(
config
,
summary_type
=
summary_type
,
use_proj
=
use_proj
,
output_attentions
=
output_attentions
,
use_proj
=
use_proj
,
output_attentions
=
output_attentions
,
keep_multihead_output
=
keep_multihead_output
)
keep_multihead_output
=
keep_multihead_output
)
self
.
lo
s
s_proj
=
nn
.
Linear
(
config
.
d_model
,
num_labels
if
not
is_regression
else
1
)
self
.
lo
git
s_proj
=
nn
.
Linear
(
config
.
d_model
,
num_labels
if
not
is_regression
else
1
)
self
.
apply
(
self
.
init_weights
)
self
.
apply
(
self
.
init_weights
)
def
forward
(
self
,
inp_k
,
token_type_ids
=
None
,
input_mask
=
None
,
attention_mask
=
None
,
def
forward
(
self
,
inp_k
,
token_type_ids
=
None
,
input_mask
=
None
,
attention_mask
=
None
,
...
@@ -1376,7 +1378,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
...
@@ -1376,7 +1378,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
output_all_encoded_layers
,
head_mask
)
output_all_encoded_layers
,
head_mask
)
output
=
self
.
sequence_summary
(
output
)
output
=
self
.
sequence_summary
(
output
)
logits
=
self
.
lo
s
s_proj
(
output
)
logits
=
self
.
lo
git
s_proj
(
output
)
if
target
is
not
None
:
if
target
is
not
None
:
if
self
.
is_regression
:
if
self
.
is_regression
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment