Commit 716cc1c4 authored by chrislarson1's avatar chrislarson1
Browse files

added main() for programmatic call to convert pytorch->tf

parent a8e071c6
...@@ -17,16 +17,18 @@ ...@@ -17,16 +17,18 @@
import os import os
import argparse import argparse
import torch
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from pytorch_pretrained_bert.modeling import BertConfig, BertModel from pytorch_pretrained_bert.modeling import BertModel
def convert_pytorch_checkpoint_to_tf(model:BertModel, ckpt_dir:str): def convert_pytorch_checkpoint_to_tf(model:BertModel, ckpt_dir:str, model_name:str):
""" """
:param model:BertModel Pytorch model instance to be converted :param model:BertModel Pytorch model instance to be converted
:param ckpt_dir: directory to save Tensorflow model :param ckpt_dir: Tensorflow model directory
:param model_name: model name
:return: :return:
Currently supported HF models: Currently supported HF models:
...@@ -87,35 +89,42 @@ def convert_pytorch_checkpoint_to_tf(model:BertModel, ckpt_dir:str): ...@@ -87,35 +89,42 @@ def convert_pytorch_checkpoint_to_tf(model:BertModel, ckpt_dir:str):
print("{0}{1}initialized".format(tf_name, " " * (60 - len(tf_name)))) print("{0}{1}initialized".format(tf_name, " " * (60 - len(tf_name))))
saver = tf.train.Saver(tf_vars) saver = tf.train.Saver(tf_vars)
saver.save(session, os.path.join(ckpt_dir, args.pytorch_model_name)) saver.save(session, os.path.join(ckpt_dir, model_name.replace("-", "_") + ".ckpt"))
if __name__ == "__main__": def main(raw_args=None):
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--pytorch_model_dir", parser.add_argument("--model_name",
default=None,
type=str,
required=False,
help="Directory containing pytorch model")
parser.add_argument("--pytorch_model_name",
default=None,
type=str, type=str,
required=True, required=True,
help="model name (e.g. bert-base-uncased)") help="model name e.g. bert-base-uncased")
parser.add_argument("--config_file_path", parser.add_argument("--cache_dir",
type=str,
default=None, default=None,
required=False,
help="Directory containing pytorch model")
parser.add_argument("--pytorch_model_path",
type=str, type=str,
required=True, required=True,
help="Path to bert config file") help="/path/to/<pytorch-model-name>.bin")
parser.add_argument("--tf_checkpoint_dir", parser.add_argument("--tf_cache_dir",
default="",
type=str, type=str,
required=True, required=True,
help="Directory in which to save tensorflow model") help="Directory in which to save tensorflow model")
args = parser.parse_args() args = parser.parse_args(raw_args)
model = BertModel.from_pretrained(
pretrained_model_name_or_path=args.model_name,
state_dict=torch.load(args.pytorch_model_path),
cache_dir=args.cache_dir
)
convert_pytorch_checkpoint_to_tf(
model=model,
ckpt_dir=args.tf_cache_dir,
model_name=args.model_name
)
model = BertModel( if __name__ == "__main__":
config=BertConfig(args.config_file_path) main()
).from_pretrained(args.pytorch_model_name, cache_dir=args.pytorch_model_dir)
convert_pytorch_checkpoint_to_tf(model=model, ckpt_dir=args.tf_checkpoint_dir)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment