Unverified commit e6cff60b, authored by Thomas Wolf, committed by GitHub

Merge pull request #2069 from huggingface/cleaner-pt-tf-conversion

clean up PT <=> TF conversion
parents 4b82c485 1d87b37d
--- a/transformers/convert_pytorch_checkpoint_to_tf2.py
+++ b/transformers/convert_pytorch_checkpoint_to_tf2.py
@@ -119,10 +119,11 @@ def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file
         tf_inputs = tf.constant(inputs_list)
         tfo = tf_model(tf_inputs, training=False)  # build the network
 
-        pt_model = pt_model_class.from_pretrained(None,
-                                                  config=config,
-                                                  state_dict=torch.load(pytorch_checkpoint_path,
-                                                                        map_location='cpu'))
+        state_dict = torch.load(pytorch_checkpoint_path, map_location='cpu')
+        pt_model = pt_model_class.from_pretrained(pretrained_model_name_or_path=None,
+                                                  config=config,
+                                                  state_dict=state_dict)
+
         pt_inputs = torch.tensor(inputs_list)
         with torch.no_grad():
             pto = pt_model(pt_inputs)
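For readers following along outside the diff, here is a minimal, self-contained sketch of the loading pattern the new code uses. The `./pt_model/...` paths and the choice of `BertConfig`/`BertModel` are illustrative, not part of this commit:

```python
import torch
from transformers import BertConfig, BertModel

# Illustrative paths; any PyTorch checkpoint plus its matching config works.
pytorch_checkpoint_path = "./pt_model/pytorch_model.bin"
config = BertConfig.from_json_file("./pt_model/config.json")

# Load the weights onto the CPU first, then hand them to from_pretrained.
# With pretrained_model_name_or_path=None, no download or cache lookup
# happens: the model is built purely from `config` and the explicit
# `state_dict`, which is exactly what the conversion script needs.
state_dict = torch.load(pytorch_checkpoint_path, map_location='cpu')
pt_model = BertModel.from_pretrained(pretrained_model_name_or_path=None,
                                     config=config,
                                     state_dict=state_dict)

pt_model.eval()  # the script compares outputs under torch.no_grad()
```

Pulling `torch.load` out of the call also makes the `map_location='cpu'` argument easier to see, which matters when converting GPU-trained checkpoints on a CPU-only machine.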
@@ -139,7 +140,7 @@ def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file
 def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, model_shortcut_names_or_path=None, config_shortcut_names_or_path=None,
-                                     compare_with_pt_model=False, use_cached_models=False, only_convert_finetuned_models=False):
+                                     compare_with_pt_model=False, use_cached_models=False, remove_cached_files=False, only_convert_finetuned_models=False):
     assert os.path.isdir(args.tf_dump_path), "--tf_dump_path should be a directory"
     if args_model_type is None:
@@ -187,11 +188,13 @@ def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, model_shortc
         if os.path.isfile(model_shortcut_name):
             model_shortcut_name = 'converted_model'
 
         convert_pt_checkpoint_to_tf(model_type=model_type,
                                     pytorch_checkpoint_path=model_file,
                                     config_file=config_file,
                                     tf_dump_path=os.path.join(tf_dump_path, model_shortcut_name + '-tf_model.h5'),
                                     compare_with_pt_model=compare_with_pt_model)
-        os.remove(config_file)
-        os.remove(model_file)
+        if remove_cached_files:
+            os.remove(config_file)
+            os.remove(model_file)
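The new flag simply guards the cleanup, but the intent is worth spelling out: when converting many checkpoints in one run, deleting each cached file right after its conversion keeps disk usage bounded. A hypothetical sketch of that batch pattern, with `convert_one` standing in for `convert_pt_checkpoint_to_tf` and `jobs` an invented list of file pairs:

```python
import os

def convert_in_batches(jobs, convert_one, remove_cached_files=False):
    # `jobs` is a hypothetical list of (config_file, model_file) pairs;
    # `convert_one` stands in for convert_pt_checkpoint_to_tf.
    for config_file, model_file in jobs:
        convert_one(config_file, model_file)
        if remove_cached_files:
            # Drop the cached PyTorch files as soon as the TF dump exists,
            # so a long batch run never holds more than one checkpoint.
            os.remove(config_file)
            os.remove(model_file)
```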
...@@ -226,6 +229,9 @@ if __name__ == "__main__": ...@@ -226,6 +229,9 @@ if __name__ == "__main__":
parser.add_argument("--use_cached_models", parser.add_argument("--use_cached_models",
action='store_true', action='store_true',
help = "Use cached models if possible instead of updating to latest checkpoint versions.") help = "Use cached models if possible instead of updating to latest checkpoint versions.")
parser.add_argument("--remove_cached_files",
action='store_true',
help = "Remove pytorch models after conversion (save memory when converting in batches).")
parser.add_argument("--only_convert_finetuned_models", parser.add_argument("--only_convert_finetuned_models",
action='store_true', action='store_true',
help = "Only convert finetuned models.") help = "Only convert finetuned models.")
@@ -245,4 +251,5 @@ if __name__ == "__main__":
                                     config_shortcut_names_or_path=[args.config_file] if args.config_file is not None else None,
                                     compare_with_pt_model=args.compare_with_pt_model,
                                     use_cached_models=args.use_cached_models,
+                                    remove_cached_files=args.remove_cached_files,
                                     only_convert_finetuned_models=args.only_convert_finetuned_models)
--- a/transformers/modeling_utils.py
+++ b/transformers/modeling_utils.py
@@ -318,7 +318,8 @@ class PreTrainedModel(nn.Module):
             model = BertModel.from_pretrained('./tf_model/my_tf_checkpoint.ckpt.index', from_tf=True, config=config)
 
         """
-        if "albert" in pretrained_model_name_or_path and "v2" in pretrained_model_name_or_path:
+        if pretrained_model_name_or_path is not None and (
+                "albert" in pretrained_model_name_or_path and "v2" in pretrained_model_name_or_path):
             logger.warning("There is currently an upstream reproducibility issue with ALBERT v2 models. Please see " +
                            "https://github.com/google-research/google-research/issues/119 for more information.")