"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "e78c1103385f2d2f9cd4980f61a8e71baa655356"
Commit 6a083fd4 authored by thomwolf's avatar thomwolf
Browse files

update pt-tf conversion script

parent f6969cc1
...@@ -102,7 +102,7 @@ def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file ...@@ -102,7 +102,7 @@ def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file
tf_model.save_weights(tf_dump_path, save_format='h5') tf_model.save_weights(tf_dump_path, save_format='h5')
def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, compare_with_pt_model=False): def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, compare_with_pt_model=False, use_cached_models=False):
assert os.path.isdir(args.tf_dump_path), "--tf_dump_path should be a directory" assert os.path.isdir(args.tf_dump_path), "--tf_dump_path should be a directory"
if args_model_type is None: if args_model_type is None:
...@@ -126,8 +126,8 @@ def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, compare_with ...@@ -126,8 +126,8 @@ def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, compare_with
if 'finetuned' in shortcut_name: if 'finetuned' in shortcut_name:
print(" Skipping finetuned checkpoint ") print(" Skipping finetuned checkpoint ")
continue continue
config_file = cached_path(aws_config_map[shortcut_name], force_download=True) config_file = cached_path(aws_config_map[shortcut_name], force_download=not use_cached_models)
model_file = cached_path(aws_model_maps[shortcut_name], force_download=True) model_file = cached_path(aws_model_maps[shortcut_name], force_download=not use_cached_models)
convert_pt_checkpoint_to_tf(model_type, convert_pt_checkpoint_to_tf(model_type,
model_file, model_file,
...@@ -165,6 +165,9 @@ if __name__ == "__main__": ...@@ -165,6 +165,9 @@ if __name__ == "__main__":
parser.add_argument("--compare_with_pt_model", parser.add_argument("--compare_with_pt_model",
action='store_true', action='store_true',
help = "Compare Tensorflow and PyTorch model predictions.") help = "Compare Tensorflow and PyTorch model predictions.")
parser.add_argument("--use_cached_models",
action='store_true',
help = "Use cached models if possible instead of updating to latest checkpoint versions.")
args = parser.parse_args() args = parser.parse_args()
if args.pytorch_checkpoint_path is not None: if args.pytorch_checkpoint_path is not None:
...@@ -176,4 +179,5 @@ if __name__ == "__main__": ...@@ -176,4 +179,5 @@ if __name__ == "__main__":
else: else:
convert_all_pt_checkpoints_to_tf(args.model_type.lower() if args.model_type is not None else None, convert_all_pt_checkpoints_to_tf(args.model_type.lower() if args.model_type is not None else None,
args.tf_dump_path, args.tf_dump_path,
compare_with_pt_model=args.compare_with_pt_model) compare_with_pt_model=args.compare_with_pt_model,
use_cached_models=args.use_cached_models)
...@@ -78,6 +78,12 @@ def load_pytorch_state_dict_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None ...@@ -78,6 +78,12 @@ def load_pytorch_state_dict_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None
for old_key, new_key in zip(old_keys, new_keys): for old_key, new_key in zip(old_keys, new_keys):
pt_state_dict[new_key] = pt_state_dict.pop(old_key) pt_state_dict[new_key] = pt_state_dict.pop(old_key)
# Make sure we are able to load PyTorch base models as well as derived models (with heads)
# TF models always have a prefix, some of PyTorch models (base ones) don't
start_prefix_to_remove = ''
if not any(s.startswith(tf_model.base_model_prefix) for s in pt_state_dict.keys()):
start_prefix_to_remove = tf_model.base_model_prefix + '.'
symbolic_weights = tf_model.trainable_weights + tf_model.non_trainable_weights symbolic_weights = tf_model.trainable_weights + tf_model.non_trainable_weights
weight_value_tuples = [] weight_value_tuples = []
...@@ -100,13 +106,23 @@ def load_pytorch_state_dict_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None ...@@ -100,13 +106,23 @@ def load_pytorch_state_dict_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None
if name[-1] == 'beta': if name[-1] == 'beta':
name[-1] = 'bias' name[-1] = 'bias'
# Remove prefix if needed
name = '.'.join(name) name = '.'.join(name)
if start_prefix_to_remove:
name = name.replace(start_prefix_to_remove, '', 1)
# Find associated numpy array in pytorch model state dict
assert name in pt_state_dict, "{} not found in PyTorch model".format(name) assert name in pt_state_dict, "{} not found in PyTorch model".format(name)
array = pt_state_dict[name].numpy() array = pt_state_dict[name].numpy()
if transpose: if transpose:
array = numpy.transpose(array) array = numpy.transpose(array)
if len(symbolic_weight.shape) < len(array.shape):
array = numpy.squeeze(array)
elif len(symbolic_weight.shape) > len(array.shape):
array = numpy.expand_dims(array, axis=0)
try: try:
assert list(symbolic_weight.shape) == list(array.shape) assert list(symbolic_weight.shape) == list(array.shape)
except AssertionError as e: except AssertionError as e:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment