Unverified commit 9dd02d85, authored by Siyuan Li and committed by GitHub
Browse files

[Bug] Fix usage of `.transpose()` and `.view()` consecutively. (#11979)

parent f7b3ba82
...@@ -230,7 +230,7 @@ class MultiHeadAttention(nn.Module): ...@@ -230,7 +230,7 @@ class MultiHeadAttention(nn.Module):
value, value,
scale=self.scale) scale=self.scale)
out = out.transpose(1, 2) out = out.transpose(1, 2)
return out.view(bsz, q_len, -1) return out.reshape(bsz, q_len, -1)
def unified_attention( def unified_attention(
......
...@@ -271,7 +271,7 @@ class InternSdpaAttention(nn.Module): ...@@ -271,7 +271,7 @@ class InternSdpaAttention(nn.Module):
v = v.transpose(1, 2) v = v.transpose(1, 2)
x = F.scaled_dot_product_attention(q, k, v, scale=self.scale) x = F.scaled_dot_product_attention(q, k, v, scale=self.scale)
x = x.transpose(1, 2).view(B, N, -1) x = x.transpose(1, 2).reshape(B, N, -1)
x = self.proj(x) x = self.proj(x)
return x return x
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment