Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
7de17404
Commit
7de17404
authored
Jun 25, 2019
by
thomwolf
Browse files
add ability to restore fine-tuned TF model
parent
7334bf6c
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
33 additions
and
15 deletions
+33
-15
pytorch_pretrained_bert/convert_xlnet_checkpoint_to_pytorch.py
...ch_pretrained_bert/convert_xlnet_checkpoint_to_pytorch.py
+23
-7
pytorch_pretrained_bert/modeling_xlnet.py
pytorch_pretrained_bert/modeling_xlnet.py
+10
-8
No files found.
pytorch_pretrained_bert/convert_xlnet_checkpoint_to_pytorch.py
View file @
7de17404
...
@@ -24,16 +24,27 @@ import torch
...
@@ -24,16 +24,27 @@ import torch
from
pytorch_pretrained_bert.modeling_xlnet
import
(
CONFIG_NAME
,
WEIGHTS_NAME
,
from
pytorch_pretrained_bert.modeling_xlnet
import
(
CONFIG_NAME
,
WEIGHTS_NAME
,
XLNetConfig
,
XLNetRunConfig
,
XLNetConfig
,
XLNetRunConfig
,
XLNetLMHeadModel
,
load_tf_weights_in_xlnet
)
XLNetLMHeadModel
,
XLNetForQuestionAnswering
,
XLNetForSequenceClassification
,
load_tf_weights_in_xlnet
)
def
convert_xlnet_checkpoint_to_pytorch
(
tf_checkpoint_path
,
bert_config_file
,
pytorch_dump_folder_path
):
GLUE_TASKS
=
[
"cola"
,
"mnli"
,
"mnli-mm"
,
"mrpc"
,
"sst-2"
,
"sts-b"
,
"qqp"
,
"qnli"
,
"rte"
,
"wnli"
]
def
convert_xlnet_checkpoint_to_pytorch
(
tf_checkpoint_path
,
bert_config_file
,
pytorch_dump_folder_path
,
finetuning_task
=
None
):
# Initialise PyTorch model
# Initialise PyTorch model
config
=
XLNetConfig
.
from_json_file
(
bert_config_file
)
config
=
XLNetConfig
.
from_json_file
(
bert_config_file
)
print
(
"Building PyTorch model from configuration: {}"
.
format
(
str
(
config
)))
if
finetuning_task
is
not
None
and
finetuning_task
.
lower
()
in
GLUE_TASKS
:
model
=
XLNetLMHeadModel
(
config
)
model_class
=
XLNetLMHeadModel
elif
finetuning_task
is
not
None
and
'squad'
in
finetuning_task
.
lower
():
model_class
=
XLNetForQuestionAnswering
else
:
model_class
=
XLNetLMHeadModel
print
(
"Building PyTorch model {} from configuration: {}"
.
format
(
str
(
model_class
),
str
(
config
)))
model
=
model_class
(
config
)
# Load weights from tf checkpoint
# Load weights from tf checkpoint
load_tf_weights_in_xlnet
(
model
,
config
,
tf_checkpoint_path
)
load_tf_weights_in_xlnet
(
model
,
config
,
tf_checkpoint_path
,
finetuning_task
)
# Save pytorch-model
# Save pytorch-model
pytorch_weights_dump_path
=
os
.
path
.
join
(
pytorch_dump_folder_path
,
WEIGHTS_NAME
)
pytorch_weights_dump_path
=
os
.
path
.
join
(
pytorch_dump_folder_path
,
WEIGHTS_NAME
)
...
@@ -59,12 +70,17 @@ if __name__ == "__main__":
...
@@ -59,12 +70,17 @@ if __name__ == "__main__":
required
=
True
,
required
=
True
,
help
=
"The config json file corresponding to the pre-trained XLNet model.
\n
"
help
=
"The config json file corresponding to the pre-trained XLNet model.
\n
"
"This specifies the model architecture."
)
"This specifies the model architecture."
)
parser
.
add_argument
(
"--pytorch_dump_folder_path"
,
parser
.
add_argument
(
"--pytorch_dump_folder_path"
,
finetuning_task
default
=
None
,
default
=
None
,
type
=
str
,
type
=
str
,
required
=
True
,
required
=
True
,
help
=
"Path to the folder to store the PyTorch model or dataset/vocab."
)
help
=
"Path to the folder to store the PyTorch model or dataset/vocab."
)
parser
.
add_argument
(
"--finetuning_task"
,
default
=
None
,
type
=
str
,
help
=
"Name of a task on which the XLNet TensorFloaw model was fine-tuned"
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
convert_xlnet_checkpoint_to_pytorch
(
args
.
tf_checkpoint_path
,
convert_xlnet_checkpoint_to_pytorch
(
args
.
tf_checkpoint_path
,
args
.
xlnet_config_file
,
args
.
xlnet_config_file
,
args
.
pytorch_dump_folder_path
)
args
.
pytorch_dump_folder_path
,
args
.
finetuning_task
)
pytorch_pretrained_bert/modeling_xlnet.py
View file @
7de17404
...
@@ -46,7 +46,7 @@ XLNET_CONFIG_NAME = 'xlnet_config.json'
...
@@ -46,7 +46,7 @@ XLNET_CONFIG_NAME = 'xlnet_config.json'
TF_WEIGHTS_NAME
=
'model.ckpt'
TF_WEIGHTS_NAME
=
'model.ckpt'
def
build_tf_xlnet_to_pytorch_map
(
model
,
config
,
tf_weights
=
None
):
def
build_tf_xlnet_to_pytorch_map
(
model
,
config
,
tf_weights
=
None
,
finetuning_task
=
None
):
""" A map of modules from TF to PyTorch.
""" A map of modules from TF to PyTorch.
I use a map to keep the PyTorch model as
I use a map to keep the PyTorch model as
identical to the original PyTorch model as possible.
identical to the original PyTorch model as possible.
...
@@ -62,8 +62,10 @@ def build_tf_xlnet_to_pytorch_map(model, config, tf_weights=None):
...
@@ -62,8 +62,10 @@ def build_tf_xlnet_to_pytorch_map(model, config, tf_weights=None):
# We will load also the sequence summary
# We will load also the sequence summary
tf_to_pt_map
[
'model/sequnece_summary/summary/kernel'
]
=
model
.
sequence_summary
.
summary
.
weight
tf_to_pt_map
[
'model/sequnece_summary/summary/kernel'
]
=
model
.
sequence_summary
.
summary
.
weight
tf_to_pt_map
[
'model/sequnece_summary/summary/bias'
]
=
model
.
sequence_summary
.
summary
.
bias
tf_to_pt_map
[
'model/sequnece_summary/summary/bias'
]
=
model
.
sequence_summary
.
summary
.
bias
elif
hasattr
(
model
,
'proj_loss'
)
and
any
(
'model/regression'
in
name
for
name
in
tf_weights
.
keys
()):
elif
hasattr
(
model
,
'logits_proj'
)
and
finetuning_task
is
not
None
and
any
(
'model/regression'
in
name
for
name
in
tf_weights
.
keys
()):
raise
NotImplementedError
tf_to_pt_map
[
'model/regression_{}/logit/kernel'
.
format
(
finetuning_task
)]
=
model
.
logits_proj
.
weight
tf_to_pt_map
[
'model/regression_{}/logit/bias'
.
format
(
finetuning_task
)]
=
model
.
logits_proj
.
bias
# Now load the rest of the transformer
# Now load the rest of the transformer
model
=
model
.
transformer
model
=
model
.
transformer
...
@@ -113,7 +115,7 @@ def build_tf_xlnet_to_pytorch_map(model, config, tf_weights=None):
...
@@ -113,7 +115,7 @@ def build_tf_xlnet_to_pytorch_map(model, config, tf_weights=None):
'model/transformer/seg_embed'
:
seg_embed_list
})
'model/transformer/seg_embed'
:
seg_embed_list
})
return
tf_to_pt_map
return
tf_to_pt_map
def
load_tf_weights_in_xlnet
(
model
,
config
,
tf_path
):
def
load_tf_weights_in_xlnet
(
model
,
config
,
tf_path
,
finetuning_task
=
None
):
""" Load tf checkpoints in a pytorch model
""" Load tf checkpoints in a pytorch model
"""
"""
try
:
try
:
...
@@ -132,7 +134,7 @@ def load_tf_weights_in_xlnet(model, config, tf_path):
...
@@ -132,7 +134,7 @@ def load_tf_weights_in_xlnet(model, config, tf_path):
tf_weights
[
name
]
=
array
tf_weights
[
name
]
=
array
# Build TF to PyTorch weights loading map
# Build TF to PyTorch weights loading map
tf_to_pt_map
=
build_tf_xlnet_to_pytorch_map
(
model
,
config
,
tf_weights
)
tf_to_pt_map
=
build_tf_xlnet_to_pytorch_map
(
model
,
config
,
tf_weights
,
finetuning_task
)
for
name
,
pointer
in
tf_to_pt_map
.
items
():
for
name
,
pointer
in
tf_to_pt_map
.
items
():
print
(
"Importing {}"
.
format
(
name
))
print
(
"Importing {}"
.
format
(
name
))
...
@@ -1338,7 +1340,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
...
@@ -1338,7 +1340,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
self
.
sequence_summary
=
XLNetSequenceSummary
(
config
,
summary_type
=
summary_type
,
self
.
sequence_summary
=
XLNetSequenceSummary
(
config
,
summary_type
=
summary_type
,
use_proj
=
use_proj
,
output_attentions
=
output_attentions
,
use_proj
=
use_proj
,
output_attentions
=
output_attentions
,
keep_multihead_output
=
keep_multihead_output
)
keep_multihead_output
=
keep_multihead_output
)
self
.
lo
s
s_proj
=
nn
.
Linear
(
config
.
d_model
,
num_labels
if
not
is_regression
else
1
)
self
.
lo
git
s_proj
=
nn
.
Linear
(
config
.
d_model
,
num_labels
if
not
is_regression
else
1
)
self
.
apply
(
self
.
init_weights
)
self
.
apply
(
self
.
init_weights
)
def
forward
(
self
,
inp_k
,
token_type_ids
=
None
,
input_mask
=
None
,
attention_mask
=
None
,
def
forward
(
self
,
inp_k
,
token_type_ids
=
None
,
input_mask
=
None
,
attention_mask
=
None
,
...
@@ -1376,7 +1378,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
...
@@ -1376,7 +1378,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
output_all_encoded_layers
,
head_mask
)
output_all_encoded_layers
,
head_mask
)
output
=
self
.
sequence_summary
(
output
)
output
=
self
.
sequence_summary
(
output
)
logits
=
self
.
lo
s
s_proj
(
output
)
logits
=
self
.
lo
git
s_proj
(
output
)
if
target
is
not
None
:
if
target
is
not
None
:
if
self
.
is_regression
:
if
self
.
is_regression
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment