"...git@developer.sourcefind.cn:modelzoo/llama_paddle.git" did not exist on "10f294ff34d709a74db087cddd7127691a439d49"
Unverified Commit a8fa6fbb authored by Thien Tran's avatar Thien Tran Committed by GitHub
Browse files

Fix Wav2Vec2 Fairseq conversion (weight norm state dict keys) (#31714)

* handle new weight norm

* fix

* fix trailing space
parent a01b033c
...@@ -94,7 +94,16 @@ def set_recursively(key, value, full_name, weight_type, hf_pointer): ...@@ -94,7 +94,16 @@ def set_recursively(key, value, full_name, weight_type, hf_pointer):
hf_param_name = PARAM_MAPPING[full_name.split(".")[-1]] hf_param_name = PARAM_MAPPING[full_name.split(".")[-1]]
weight_type = "param" weight_type = "param"
# fairseq uses nn.utils.weight_norm() while transformers switches to nn.utils.parametrizations.weight_norm()
# the mapping between two versions:
# https://github.com/pytorch/pytorch/blob/56935684c3dfad7841c83c719eeebecb560fe466/torch/nn/utils/parametrizations.py#L389-L395
if weight_type is not None and weight_type != "param": if weight_type is not None and weight_type != "param":
if weight_type == "weight_g" and not hasattr(hf_pointer, "weight_g"):
hf_shape = hf_pointer.parametrizations.weight.original0.shape
elif weight_type == "weight_v" and not hasattr(hf_pointer, "weight_v"):
hf_shape = hf_pointer.parametrizations.weight.original1.shape
else:
hf_shape = getattr(hf_pointer, weight_type).shape hf_shape = getattr(hf_pointer, weight_type).shape
elif weight_type is not None and weight_type == "param": elif weight_type is not None and weight_type == "param":
shape_pointer = hf_pointer shape_pointer = hf_pointer
...@@ -116,9 +125,15 @@ def set_recursively(key, value, full_name, weight_type, hf_pointer): ...@@ -116,9 +125,15 @@ def set_recursively(key, value, full_name, weight_type, hf_pointer):
if weight_type == "weight": if weight_type == "weight":
hf_pointer.weight.data = value hf_pointer.weight.data = value
elif weight_type == "weight_g": elif weight_type == "weight_g":
if hasattr(hf_pointer, "weight_g"):
hf_pointer.weight_g.data = value hf_pointer.weight_g.data = value
else:
hf_pointer.parametrizations.weight.original0.data = value
elif weight_type == "weight_v": elif weight_type == "weight_v":
if hasattr(hf_pointer, "weight_v"):
hf_pointer.weight_v.data = value hf_pointer.weight_v.data = value
else:
hf_pointer.parametrizations.weight.original1.data = value
elif weight_type == "bias": elif weight_type == "bias":
hf_pointer.bias.data = value hf_pointer.bias.data = value
elif weight_type == "param": elif weight_type == "param":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment