"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "48a05026def1e94ae08037a252472c030409857e"
Commit 9f81f1cb authored by VictorSanh's avatar VictorSanh
Browse files

fix convert pt_to_tf2 for custom weights

parent 1615360c
...@@ -173,10 +173,12 @@ def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, model_shortc ...@@ -173,10 +173,12 @@ def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, model_shortc
else: else:
model_file = cached_path(model_shortcut_name, force_download=not use_cached_models) model_file = cached_path(model_shortcut_name, force_download=not use_cached_models)
convert_pt_checkpoint_to_tf(model_type, if os.path.isfile(model_shortcut_name):
model_file, model_shortcut_name = 'converted_model'
config_file, convert_pt_checkpoint_to_tf(model_type=model_type,
os.path.join(tf_dump_path, model_shortcut_name + '-tf_model.h5'), pytorch_checkpoint_path=model_file,
config_file=config_file,
tf_dump_path=os.path.join(tf_dump_path, model_shortcut_name + '-tf_model.h5'),
compare_with_pt_model=compare_with_pt_model) compare_with_pt_model=compare_with_pt_model)
os.remove(config_file) os.remove(config_file)
os.remove(model_file) os.remove(model_file)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment