"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "7b2bd1fbbd50e57cf28013e2d0737912ecc0f2eb"
Unverified Commit faae8d82 authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

Update PT/Flax weight conversion after #24030 (#24556)



* fix

* fix

* fix

---------
Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent 33b5ef5c
...@@ -120,6 +120,16 @@ def rename_key_and_reshape_tensor( ...@@ -120,6 +120,16 @@ def rename_key_and_reshape_tensor(
if pt_tuple_key[-1] == "beta": if pt_tuple_key[-1] == "beta":
return renamed_pt_tuple_key, pt_tensor return renamed_pt_tuple_key, pt_tensor
# New `weight_norm` from https://github.com/huggingface/transformers/pull/24030
name = None
if pt_tuple_key[-3::2] == ("parametrizations", "original0"):
name = pt_tuple_key[-2] + "_g"
elif pt_tuple_key[-3::2] == ("parametrizations", "original1"):
name = pt_tuple_key[-2] + "_v"
if name is not None:
renamed_pt_tuple_key = pt_tuple_key[:-3] + (name,)
return renamed_pt_tuple_key, pt_tensor
return pt_tuple_key, pt_tensor return pt_tuple_key, pt_tensor
...@@ -372,6 +382,24 @@ def load_flax_weights_in_pytorch_model(pt_model, flax_state): ...@@ -372,6 +382,24 @@ def load_flax_weights_in_pytorch_model(pt_model, flax_state):
else: else:
flax_key = ".".join(flax_key_tuple) flax_key = ".".join(flax_key_tuple)
# We also need to look at `pt_model_dict` and see if there are keys requiring further transformation.
special_pt_names = {}
# New `weight_norm` from https://github.com/huggingface/transformers/pull/24030
for key in pt_model_dict:
key_components = key.split(".")
name = None
if key_components[-3::2] == ["parametrizations", "original0"]:
name = key_components[-2] + "_g"
elif key_components[-3::2] == ["parametrizations", "original1"]:
name = key_components[-2] + "_v"
if name is not None:
key_components = key_components[:-3] + [name]
key_to_check = ".".join(key_components)
special_pt_names[key_to_check] = key
if flax_key in special_pt_names:
flax_key = special_pt_names[flax_key]
if flax_key in pt_model_dict: if flax_key in pt_model_dict:
if flax_tensor.shape != pt_model_dict[flax_key].shape: if flax_tensor.shape != pt_model_dict[flax_key].shape:
raise ValueError( raise ValueError(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment