Unverified commit 2dd34371, authored by Isotr0py and committed by GitHub
Browse files

[Bugfix] Fix RMSNorm forward in InternViT attention qk_layernorm (#6992)

parent 7e0861bd
@@ -113,10 +113,10 @@ class InternAttention(nn.Module):
         if self.qk_normalization:
             B_, H_, N_, D_ = q.shape
-            q = self.q_norm(q.transpose(1, 2).flatten(-2, -1)).view(
-                B_, N_, H_, D_).transpose(1, 2)
-            k = self.k_norm(k.transpose(1, 2).flatten(-2, -1)).view(
-                B_, N_, H_, D_).transpose(1, 2)
+            q = self.q_norm.forward_native(q.transpose(1, 2).flatten(
+                -2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
+            k = self.k_norm.forward_native(k.transpose(1, 2).flatten(
+                -2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
         x = F.scaled_dot_product_attention(q, k, v, scale=self.scale)
         x = x.transpose(1, 2).reshape(B, N, C)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment