Commit 7de17404 authored by thomwolf

add ability to restore fine-tuned TF model

parent 7334bf6c
@@ -24,16 +24,27 @@ import torch
 from pytorch_pretrained_bert.modeling_xlnet import (CONFIG_NAME, WEIGHTS_NAME,
                                                     XLNetConfig, XLNetRunConfig,
-                                                    XLNetLMHeadModel, load_tf_weights_in_xlnet)
+                                                    XLNetLMHeadModel, XLNetForQuestionAnswering,
+                                                    XLNetForSequenceClassification,
+                                                    load_tf_weights_in_xlnet)
+
+GLUE_TASKS = ["cola", "mnli", "mnli-mm", "mrpc", "sst-2", "sts-b", "qqp", "qnli", "rte", "wnli"]
 
-def convert_xlnet_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_folder_path):
+def convert_xlnet_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_folder_path, finetuning_task=None):
     # Initialise PyTorch model
     config = XLNetConfig.from_json_file(bert_config_file)
-    print("Building PyTorch model from configuration: {}".format(str(config)))
-    model = XLNetLMHeadModel(config)
+
+    if finetuning_task is not None and finetuning_task.lower() in GLUE_TASKS:
+        model_class = XLNetForSequenceClassification
+    elif finetuning_task is not None and 'squad' in finetuning_task.lower():
+        model_class = XLNetForQuestionAnswering
+    else:
+        model_class = XLNetLMHeadModel
+
+    print("Building PyTorch model {} from configuration: {}".format(str(model_class), str(config)))
+    model = model_class(config)
 
     # Load weights from tf checkpoint
-    load_tf_weights_in_xlnet(model, config, tf_checkpoint_path)
+    load_tf_weights_in_xlnet(model, config, tf_checkpoint_path, finetuning_task)
 
     # Save pytorch-model
     pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME)
@@ -59,12 +70,17 @@ if __name__ == "__main__":
                         required = True,
                         help = "The config json file corresponding to the pre-trained XLNet model. \n"
                                "This specifies the model architecture.")
     parser.add_argument("--pytorch_dump_folder_path",
                         default = None,
                         type = str,
                         required = True,
                         help = "Path to the folder to store the PyTorch model or dataset/vocab.")
+    parser.add_argument("--finetuning_task",
+                        default = None,
+                        type = str,
+                        help = "Name of a task on which the XLNet TensorFlow model was fine-tuned")
     args = parser.parse_args()
     convert_xlnet_checkpoint_to_pytorch(args.tf_checkpoint_path,
                                         args.xlnet_config_file,
-                                        args.pytorch_dump_folder_path)
+                                        args.pytorch_dump_folder_path,
+                                        args.finetuning_task)
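
For orientation, a minimal usage sketch of the updated converter. It only calls the convert_xlnet_checkpoint_to_pytorch function changed above, with its parameters as named in this diff; all paths are placeholders and the function is assumed to be in scope (for example, when running inside the conversion script itself).

# Sketch: restore an XLNet checkpoint fine-tuned on STS-B (a GLUE regression task)
# and dump it as a PyTorch model. All paths below are placeholders.
convert_xlnet_checkpoint_to_pytorch(
    tf_checkpoint_path="/path/to/finetuned/model.ckpt-best",
    bert_config_file="/path/to/xlnet_config.json",
    pytorch_dump_folder_path="/path/to/pytorch_dump",
    finetuning_task="sts-b",  # a GLUE task, so XLNetForSequenceClassification is built
)

Equivalently, the same conversion can be requested from the command line through the new --finetuning_task argument added above.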
@@ -46,7 +46,7 @@ XLNET_CONFIG_NAME = 'xlnet_config.json'
 TF_WEIGHTS_NAME = 'model.ckpt'
 
-def build_tf_xlnet_to_pytorch_map(model, config, tf_weights=None):
+def build_tf_xlnet_to_pytorch_map(model, config, tf_weights=None, finetuning_task=None):
     """ A map of modules from TF to PyTorch.
         I use a map to keep the PyTorch model as
         identical to the original PyTorch model as possible.
@@ -62,14 +62,16 @@ def build_tf_xlnet_to_pytorch_map(model, config, tf_weights=None):
         # We will load also the sequence summary
         tf_to_pt_map['model/sequnece_summary/summary/kernel'] = model.sequence_summary.summary.weight
         tf_to_pt_map['model/sequnece_summary/summary/bias'] = model.sequence_summary.summary.bias
-    elif hasattr(model, 'proj_loss') and any('model/regression' in name for name in tf_weights.keys()):
-        raise NotImplementedError
+    elif hasattr(model, 'logits_proj') and finetuning_task is not None and any('model/regression' in name for name in tf_weights.keys()):
+        tf_to_pt_map['model/regression_{}/logit/kernel'.format(finetuning_task)] = model.logits_proj.weight
+        tf_to_pt_map['model/regression_{}/logit/bias'.format(finetuning_task)] = model.logits_proj.bias
 
     # Now load the rest of the transformer
     model = model.transformer
 
     # Embeddings and output
     tf_to_pt_map.update({'model/transformer/word_embedding/lookup_table': model.word_embedding.weight,
                          'model/transformer/mask_emb/mask_emb': model.mask_emb})
 
     # Transformer blocks
     for i, b in enumerate(model.layer):
@@ -113,7 +115,7 @@ def build_tf_xlnet_to_pytorch_map(model, config, tf_weights=None):
                          'model/transformer/seg_embed': seg_embed_list})
     return tf_to_pt_map
 
-def load_tf_weights_in_xlnet(model, config, tf_path):
+def load_tf_weights_in_xlnet(model, config, tf_path, finetuning_task=None):
     """ Load tf checkpoints in a pytorch model
     """
     try:
@@ -132,7 +134,7 @@ def load_tf_weights_in_xlnet(model, config, tf_path):
         tf_weights[name] = array
 
     # Build TF to PyTorch weights loading map
-    tf_to_pt_map = build_tf_xlnet_to_pytorch_map(model, config, tf_weights)
+    tf_to_pt_map = build_tf_xlnet_to_pytorch_map(model, config, tf_weights, finetuning_task)
 
     for name, pointer in tf_to_pt_map.items():
         print("Importing {}".format(name))
@@ -1338,7 +1340,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
         self.sequence_summary = XLNetSequenceSummary(config, summary_type=summary_type,
                                                      use_proj=use_proj, output_attentions=output_attentions,
                                                      keep_multihead_output=keep_multihead_output)
-        self.loss_proj = nn.Linear(config.d_model, num_labels if not is_regression else 1)
+        self.logits_proj = nn.Linear(config.d_model, num_labels if not is_regression else 1)
         self.apply(self.init_weights)
 
     def forward(self, inp_k, token_type_ids=None, input_mask=None, attention_mask=None,
@@ -1376,7 +1378,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
                                            output_all_encoded_layers, head_mask)
 
         output = self.sequence_summary(output)
-        logits = self.loss_proj(output)
+        logits = self.logits_proj(output)
 
         if target is not None:
             if self.is_regression:
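
To make the new mapping branch concrete, a small hedged illustration of the extra TF variable names that build_tf_xlnet_to_pytorch_map now handles for a regression fine-tuning task; the task name is a placeholder and the actual tensors come from the checkpoint.

# Illustration only: the two additional map entries produced for finetuning_task="sts-b".
finetuning_task = "sts-b"
extra_keys = [
    'model/regression_{}/logit/kernel'.format(finetuning_task),  # -> model.logits_proj.weight
    'model/regression_{}/logit/bias'.format(finetuning_task),    # -> model.logits_proj.bias
]
print(extra_keys)
# ['model/regression_sts-b/logit/kernel', 'model/regression_sts-b/logit/bias']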