Unverified Commit 56b8d49d authored by Alara Dirik's avatar Alara Dirik Committed by GitHub
Browse files

Fix ConvNext V2 parameter naming issue (#23122)

Fixes the parameter naming issue in ConvNextV2GRN module
parent b53004fd
......@@ -99,6 +99,10 @@ def rename_key(name):
if "stages" in name and "downsampling_layer" not in name:
# stages.0.0. for instance should be renamed to stages.0.layers.0.
name = name[: len("stages.0")] + ".layers" + name[len("stages.0") :]
if "gamma" in name:
name = name.replace("gamma", "weight")
if "beta" in name:
name = name.replace("beta", "bias")
if "stages" in name:
name = name.replace("stages", "encoder.stages")
if "norm" in name:
......
......@@ -100,14 +100,14 @@ class ConvNextV2GRN(nn.Module):
def __init__(self, dim: int):
super().__init__()
self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))
self.weight = nn.Parameter(torch.zeros(1, 1, 1, dim))
self.bias = nn.Parameter(torch.zeros(1, 1, 1, dim))
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
# Compute and normalize global spatial feature maps
global_features = torch.norm(hidden_states, p=2, dim=(1, 2), keepdim=True)
norm_features = global_features / (global_features.mean(dim=-1, keepdim=True) + 1e-6)
hidden_states = self.gamma * (hidden_states * norm_features) + self.beta + hidden_states
hidden_states = self.weight * (hidden_states * norm_features) + self.bias + hidden_states
return hidden_states
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment