Unverified commit 6c57ce15, authored by Yih-Dar and committed by GitHub
Browse files

Update PT/TF weight conversion after #24030 (#24547)



* fix

* fix

* fix

---------
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent c5e29d43
...@@ -273,6 +273,18 @@ def load_pytorch_state_dict_in_tf2_model( ...@@ -273,6 +273,18 @@ def load_pytorch_state_dict_in_tf2_model(
new_key = key.replace("running_var", "moving_variance") new_key = key.replace("running_var", "moving_variance")
if "running_mean" in key: if "running_mean" in key:
new_key = key.replace("running_mean", "moving_mean") new_key = key.replace("running_mean", "moving_mean")
# New `weight_norm` from https://github.com/huggingface/transformers/pull/24030
key_components = key.split(".")
name = None
if key_components[-3::2] == ["parametrizations", "original0"]:
name = key_components[-2] + "_g"
elif key_components[-3::2] == ["parametrizations", "original1"]:
name = key_components[-2] + "_v"
if name is not None:
key_components = key_components[:-3] + [name]
new_key = ".".join(key_components)
if new_key is None: if new_key is None:
new_key = key new_key = key
tf_keys_to_pt_keys[new_key] = key tf_keys_to_pt_keys[new_key] = key
...@@ -499,15 +511,27 @@ def load_tf2_state_dict_in_pytorch_model(pt_model, tf_state_dict, allow_missing_ ...@@ -499,15 +511,27 @@ def load_tf2_state_dict_in_pytorch_model(pt_model, tf_state_dict, allow_missing_
new_pt_params_dict[pt_weight_name] = loaded_pt_weights_data_ptr[pt_weight.data_ptr()] new_pt_params_dict[pt_weight_name] = loaded_pt_weights_data_ptr[pt_weight.data_ptr()]
continue continue
pt_weight_name_to_check = pt_weight_name
# New `weight_norm` from https://github.com/huggingface/transformers/pull/24030
key_components = pt_weight_name.split(".")
name = None
if key_components[-3::2] == ["parametrizations", "original0"]:
name = key_components[-2] + "_g"
elif key_components[-3::2] == ["parametrizations", "original1"]:
name = key_components[-2] + "_v"
if name is not None:
key_components = key_components[:-3] + [name]
pt_weight_name_to_check = ".".join(key_components)
# Find associated numpy array in pytorch model state dict # Find associated numpy array in pytorch model state dict
if pt_weight_name not in tf_weights_map: if pt_weight_name_to_check not in tf_weights_map:
if allow_missing_keys: if allow_missing_keys:
missing_keys_pt.append(pt_weight_name) missing_keys_pt.append(pt_weight_name)
continue continue
raise AttributeError(f"{pt_weight_name} not found in TF 2.0 model") raise AttributeError(f"{pt_weight_name} not found in TF 2.0 model")
array, transpose = tf_weights_map[pt_weight_name] array, transpose = tf_weights_map[pt_weight_name_to_check]
array = apply_transpose(transpose, array, pt_weight.shape, pt_to_tf=False) array = apply_transpose(transpose, array, pt_weight.shape, pt_to_tf=False)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment