Commit 0a8b4d65 authored by Chris's avatar Chris
Browse files

added file to convert pytorch->tf

parent 968c1b44
...@@ -21,7 +21,6 @@ import numpy as np ...@@ -21,7 +21,6 @@ import numpy as np
from pytorch_pretrained_bert.modeling import BertConfig, BertModel from pytorch_pretrained_bert.modeling import BertConfig, BertModel
def convert_hf_checkpoint_to_tf(model:BertModel, ckpt_dir:str): def convert_hf_checkpoint_to_tf(model:BertModel, ckpt_dir:str):
""" """
...@@ -129,4 +128,4 @@ if __name__ == "__main__": ...@@ -129,4 +128,4 @@ if __name__ == "__main__":
config=BertConfig(args.config_file_path) config=BertConfig(args.config_file_path)
).from_pretrained(args.pretrained_model_name_or_path) ).from_pretrained(args.pretrained_model_name_or_path)
convert_hf_checkpoint_to_tf(model=model, ckpt_dir=args.cache_) convert_hf_checkpoint_to_tf(model=model, ckpt_dir=args.cache_)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment