Commit d0adab2c authored by Chris's avatar Chris
Browse files

fn change; pytorch_model_dir required=False

parent a309459b
...@@ -22,7 +22,7 @@ import tensorflow as tf ...@@ -22,7 +22,7 @@ import tensorflow as tf
from pytorch_pretrained_bert.modeling import BertConfig, BertModel from pytorch_pretrained_bert.modeling import BertConfig, BertModel
def convert_hf_checkpoint_to_tf(model:type(BertModel), ckpt_dir:str): def convert_pytorch_checkpoint_to_tf(model:type(BertModel), ckpt_dir:str):
""" """
:param model:BertModel Pytorch model instance to be converted :param model:BertModel Pytorch model instance to be converted
...@@ -107,4 +107,4 @@ if __name__ == "__main__": ...@@ -107,4 +107,4 @@ if __name__ == "__main__":
model = BertModel( model = BertModel(
config=BertConfig(args.config_file_path) config=BertConfig(args.config_file_path)
).from_pretrained(args.pytorch_model_name, cache_dir=args.pytorch_model_dir) ).from_pretrained(args.pytorch_model_name, cache_dir=args.pytorch_model_dir)
convert_hf_checkpoint_to_tf(model=model, ckpt_dir=args.tf_checkpoint_dir) convert_pytorch_checkpoint_to_tf(model=model, ckpt_dir=args.tf_checkpoint_dir)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment