"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "44286b94d3376f56ee7ef039790d40798d5f9e7d"
Commit d0adab2c authored by Chris's avatar Chris
Browse files

fn change; pytorch_model_dir required=False

parent a309459b
......@@ -22,7 +22,7 @@ import tensorflow as tf
from pytorch_pretrained_bert.modeling import BertConfig, BertModel
def convert_hf_checkpoint_to_tf(model:type(BertModel), ckpt_dir:str):
def convert_pytorch_checkpoint_to_tf(model:type(BertModel), ckpt_dir:str):
"""
:param model:BertModel Pytorch model instance to be converted
......@@ -107,4 +107,4 @@ if __name__ == "__main__":
model = BertModel(
config=BertConfig(args.config_file_path)
).from_pretrained(args.pytorch_model_name, cache_dir=args.pytorch_model_dir)
convert_hf_checkpoint_to_tf(model=model, ckpt_dir=args.tf_checkpoint_dir)
convert_pytorch_checkpoint_to_tf(model=model, ckpt_dir=args.tf_checkpoint_dir)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment