"vscode:/vscode.git/clone" did not exist on "b2505f7db7b727cddb748d80a4b76e5895a8ed85"
Commit d0adab2c authored by Chris's avatar Chris
Browse files

fn change; pytorch_model_dir required=False

parent a309459b
......@@ -22,7 +22,7 @@ import tensorflow as tf
from pytorch_pretrained_bert.modeling import BertConfig, BertModel
def convert_hf_checkpoint_to_tf(model:type(BertModel), ckpt_dir:str):
def convert_pytorch_checkpoint_to_tf(model:type(BertModel), ckpt_dir:str):
"""
:param model:BertModel Pytorch model instance to be converted
......@@ -107,4 +107,4 @@ if __name__ == "__main__":
model = BertModel(
config=BertConfig(args.config_file_path)
).from_pretrained(args.pytorch_model_name, cache_dir=args.pytorch_model_dir)
convert_hf_checkpoint_to_tf(model=model, ckpt_dir=args.tf_checkpoint_dir)
convert_pytorch_checkpoint_to_tf(model=model, ckpt_dir=args.tf_checkpoint_dir)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment