Commit a84adddd authored by thomwolf's avatar thomwolf
Browse files

convert all models

parent 969d3ae9
...@@ -18,10 +18,11 @@ from __future__ import absolute_import ...@@ -18,10 +18,11 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import os
import argparse import argparse
import tensorflow as tf import tensorflow as tf
from pytorch_transformers import is_torch_available from pytorch_transformers import is_torch_available, cached_path
from pytorch_transformers import (BertConfig, TFBertForPreTraining, load_bert_pt_weights_in_tf2, from pytorch_transformers import (BertConfig, TFBertForPreTraining, load_bert_pt_weights_in_tf2,
GPT2Config, TFGPT2LMHeadModel, load_gpt2_pt_weights_in_tf2, GPT2Config, TFGPT2LMHeadModel, load_gpt2_pt_weights_in_tf2,
...@@ -31,26 +32,36 @@ from pytorch_transformers import (BertConfig, TFBertForPreTraining, load_bert_pt ...@@ -31,26 +32,36 @@ from pytorch_transformers import (BertConfig, TFBertForPreTraining, load_bert_pt
if is_torch_available(): if is_torch_available():
import torch import torch
import numpy as np import numpy as np
from pytorch_transformers import BertForPreTraining, GPT2LMHeadModel, XLNetLMHeadModel, XLMWithLMHeadModel from pytorch_transformers import (BertForPreTraining, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
GPT2LMHeadModel, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,
XLNetLMHeadModel, XLNET_PRETRAINED_MODEL_ARCHIVE_MAP, XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
XLMWithLMHeadModel, XLM_PRETRAINED_MODEL_ARCHIVE_MAP, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,)
else: else:
BertForPreTraining, GPT2LMHeadModel = None, None (BertForPreTraining, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
GPT2LMHeadModel, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,
XLNetLMHeadModel, XLNET_PRETRAINED_MODEL_ARCHIVE_MAP, XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
XLMWithLMHeadModel, XLM_PRETRAINED_MODEL_ARCHIVE_MAP, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,) = (
None, None, None,
None, None, None,
None, None, None,
None, None, None,)
import logging import logging
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
MODEL_CLASSES = { MODEL_CLASSES = {
'bert': (BertConfig, TFBertForPreTraining, load_bert_pt_weights_in_tf2, BertForPreTraining), 'bert': (BertConfig, TFBertForPreTraining, load_bert_pt_weights_in_tf2, BertForPreTraining, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
'gpt2': (GPT2Config, TFGPT2LMHeadModel, load_gpt2_pt_weights_in_tf2, GPT2LMHeadModel), 'gpt2': (GPT2Config, TFGPT2LMHeadModel, load_gpt2_pt_weights_in_tf2, GPT2LMHeadModel, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP),
'xlnet': (XLNetConfig, TFXLNetLMHeadModel, load_xlnet_pt_weights_in_tf2, XLNetLMHeadModel), 'xlnet': (XLNetConfig, TFXLNetLMHeadModel, load_xlnet_pt_weights_in_tf2, XLNetLMHeadModel, XLNET_PRETRAINED_MODEL_ARCHIVE_MAP, XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP),
'xlm': (XLMConfig, TFXLMWithLMHeadModel, load_xlm_pt_weights_in_tf2, XLMWithLMHeadModel), 'xlm': (XLMConfig, TFXLMWithLMHeadModel, load_xlm_pt_weights_in_tf2, XLMWithLMHeadModel, XLM_PRETRAINED_MODEL_ARCHIVE_MAP, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP),
} }
def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file, tf_dump_path, compare_with_pt_model=False): def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file, tf_dump_path, compare_with_pt_model=False):
if model_type not in MODEL_CLASSES: if model_type not in MODEL_CLASSES:
raise ValueError("Unrecognized model type, should be one of {}.".format(list(MODEL_CLASSES.keys()))) raise ValueError("Unrecognized model type, should be one of {}.".format(list(MODEL_CLASSES.keys())))
config_class, model_class, loading_fct, pt_model_class = MODEL_CLASSES[model_type] config_class, model_class, loading_fct, pt_model_class, aws_model_maps, aws_config_map = MODEL_CLASSES[model_type]
# Initialise TF model # Initialise TF model
config = config_class.from_json_file(config_file) config = config_class.from_json_file(config_file)
...@@ -68,8 +79,8 @@ def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file ...@@ -68,8 +79,8 @@ def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file
tfo = tf_model(tf_inputs, training=False) # build the network tfo = tf_model(tf_inputs, training=False) # build the network
pt_model = pt_model_class.from_pretrained(None, pt_model = pt_model_class.from_pretrained(None,
config=config, config=config,
state_dict=torch.load(pytorch_checkpoint_path, state_dict=torch.load(pytorch_checkpoint_path,
map_location='cpu')) map_location='cpu'))
pt_inputs = torch.tensor(inputs_list) pt_inputs = torch.tensor(inputs_list)
with torch.no_grad(): with torch.no_grad():
...@@ -79,42 +90,80 @@ def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file ...@@ -79,42 +90,80 @@ def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file
np_tf = tfo[0].numpy() np_tf = tfo[0].numpy()
diff = np.amax(np.abs(np_pt - np_tf)) diff = np.amax(np.abs(np_pt - np_tf))
print("Max absolute difference between models outputs {}".format(diff)) print("Max absolute difference between models outputs {}".format(diff))
assert diff <= 1e-3, "Error, model absolute difference is >1e-3"
# Save pytorch-model # Save pytorch-model
print("Save TensorFlow model to {}".format(tf_dump_path)) print("Save TensorFlow model to {}".format(tf_dump_path))
tf_model.save_weights(tf_dump_path) tf_model.save_weights(tf_dump_path, save_format='h5')
def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, compare_with_pt_model=False):
assert os.path.isdir(args.tf_dump_path), "--tf_dump_path should be a directory"
if args_model_type is None:
model_types = list(MODEL_CLASSES.keys())
else:
model_types = [args_model_type]
for j, model_type in enumerate(model_types, start=1):
print("=" * 100)
print(" Converting model type {}/{}: {}".format(j, len(model_types), model_type))
print("=" * 100)
if model_type not in MODEL_CLASSES:
raise ValueError("Unrecognized model type {}, should be one of {}.".format(model_type, list(MODEL_CLASSES.keys())))
config_class, model_class, loading_fct, pt_model_class, aws_model_maps, aws_config_map = MODEL_CLASSES[model_type]
for i, shortcut_name in enumerate(aws_config_map.keys(), start=1):
print("-" * 100)
print(" Converting checkpoint {}/{}: {}".format(i, len(aws_config_map), shortcut_name))
print("-" * 100)
config_file = cached_path(aws_config_map[shortcut_name], force_download=True)
model_file = cached_path(aws_model_maps[shortcut_name], force_download=True)
convert_pt_checkpoint_to_tf(model_type,
model_file,
config_file,
os.path.join(tf_dump_path, shortcut_name + '-tf_model.h5'),
compare_with_pt_model=compare_with_pt_model)
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
## Required parameters ## Required parameters
parser.add_argument("--model_type", parser.add_argument("--tf_dump_path",
default = None, default = None,
type = str, type = str,
required = True, required = True,
help = "Model type selcted in the list of {}.".format(list(MODEL_CLASSES.keys()))) help = "Path to the output Tensorflow dump file.")
parser.add_argument("--model_type",
default = None,
type = str,
help = "Model type selected in the list of {}. If not given, will download and convert all the models from AWS.".format(list(MODEL_CLASSES.keys())))
parser.add_argument("--pytorch_checkpoint_path", parser.add_argument("--pytorch_checkpoint_path",
default = None, default = None,
type = str, type = str,
required = True, help = "Path to the PyTorch checkpoint path or shortcut name to download from AWS. "
help = "Path to the PyTorch checkpoint path.") "If not given, will download and convert all the checkpoints from AWS.")
parser.add_argument("--config_file", parser.add_argument("--config_file",
default = None, default = None,
type = str, type = str,
required = True,
help = "The config json file corresponding to the pre-trained model. \n" help = "The config json file corresponding to the pre-trained model. \n"
"This specifies the model architecture.") "This specifies the model architecture. If not given and "
parser.add_argument("--tf_dump_path", "--pytorch_checkpoint_path is not given or is a shortcut name"
default = None, "use the configuration associated to teh shortcut name on the AWS")
type = str,
required = True,
help = "Path to the output Tensorflow dump file.")
parser.add_argument("--compare_with_pt_model", parser.add_argument("--compare_with_pt_model",
action='store_true', action='store_true',
help = "Compare Tensorflow and PyTorch model predictions.") help = "Compare Tensorflow and PyTorch model predictions.")
args = parser.parse_args() args = parser.parse_args()
convert_pt_checkpoint_to_tf(args.model_type.lower(),
args.pytorch_checkpoint_path, if args.pytorch_checkpoint_path is not None:
args.config_file, convert_pt_checkpoint_to_tf(args.model_type.lower(),
args.tf_dump_path, args.pytorch_checkpoint_path,
compare_with_pt_model=args.compare_with_pt_model) args.config_file,
args.tf_dump_path,
compare_with_pt_model=args.compare_with_pt_model)
else:
convert_all_pt_checkpoints_to_tf(args.model_type.lower() if args.model_type is not None else None,
args.tf_dump_path,
compare_with_pt_model=args.compare_with_pt_model)
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment