"test/vscode:/vscode.git/clone" did not exist on "fc78640e00e39520fa7126789d23369d2f104d0c"
Unverified Commit e6cff60b authored by Thomas Wolf's avatar Thomas Wolf Committed by GitHub
Browse files

Merge pull request #2069 from huggingface/cleaner-pt-tf-conversion

clean up PT <=> TF conversion
parents 4b82c485 1d87b37d
......@@ -119,10 +119,11 @@ def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file
tf_inputs = tf.constant(inputs_list)
tfo = tf_model(tf_inputs, training=False) # build the network
pt_model = pt_model_class.from_pretrained(None,
state_dict = torch.load(pytorch_checkpoint_path, map_location='cpu')
pt_model = pt_model_class.from_pretrained(pretrained_model_name_or_path=None,
config=config,
state_dict=torch.load(pytorch_checkpoint_path,
map_location='cpu'))
state_dict=state_dict)
pt_inputs = torch.tensor(inputs_list)
with torch.no_grad():
pto = pt_model(pt_inputs)
......@@ -139,7 +140,7 @@ def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file
def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, model_shortcut_names_or_path=None, config_shortcut_names_or_path=None,
compare_with_pt_model=False, use_cached_models=False, only_convert_finetuned_models=False):
compare_with_pt_model=False, use_cached_models=False, remove_cached_files=False, only_convert_finetuned_models=False):
assert os.path.isdir(args.tf_dump_path), "--tf_dump_path should be a directory"
if args_model_type is None:
......@@ -187,11 +188,13 @@ def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, model_shortc
if os.path.isfile(model_shortcut_name):
model_shortcut_name = 'converted_model'
convert_pt_checkpoint_to_tf(model_type=model_type,
pytorch_checkpoint_path=model_file,
config_file=config_file,
tf_dump_path=os.path.join(tf_dump_path, model_shortcut_name + '-tf_model.h5'),
compare_with_pt_model=compare_with_pt_model)
if remove_cached_files:
os.remove(config_file)
os.remove(model_file)
......@@ -226,6 +229,9 @@ if __name__ == "__main__":
parser.add_argument("--use_cached_models",
action='store_true',
help = "Use cached models if possible instead of updating to latest checkpoint versions.")
parser.add_argument("--remove_cached_files",
action='store_true',
help = "Remove pytorch models after conversion (save memory when converting in batches).")
parser.add_argument("--only_convert_finetuned_models",
action='store_true',
help = "Only convert finetuned models.")
......@@ -245,4 +251,5 @@ if __name__ == "__main__":
config_shortcut_names_or_path=[args.config_file] if args.config_file is not None else None,
compare_with_pt_model=args.compare_with_pt_model,
use_cached_models=args.use_cached_models,
remove_cached_files=args.remove_cached_files,
only_convert_finetuned_models=args.only_convert_finetuned_models)
......@@ -318,7 +318,8 @@ class PreTrainedModel(nn.Module):
model = BertModel.from_pretrained('./tf_model/my_tf_checkpoint.ckpt.index', from_tf=True, config=config)
"""
if "albert" in pretrained_model_name_or_path and "v2" in pretrained_model_name_or_path:
if pretrained_model_name_or_path is not None and (
"albert" in pretrained_model_name_or_path and "v2" in pretrained_model_name_or_path):
logger.warning("There is currently an upstream reproducibility issue with ALBERT v2 models. Please see " +
"https://github.com/google-research/google-research/issues/119 for more information.")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment