Unverified commit 3a5de7d2, authored by Jee Jee Li and committed by GitHub
Browse files

[Bugfix] Fix KDA output (#27905)


Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
parent bc4486d6
......@@ -259,7 +259,7 @@ class KimiDeltaAttention(nn.Module, MambaBase):
hidden_states: torch.Tensor,
positions: torch.Tensor,
output: torch.Tensor,
-) -> torch.Tensor:
+) -> None:
num_tokens = hidden_states.size(0)
q = self.q_proj(hidden_states)[0]
k = self.k_proj(hidden_states)[0]
......@@ -291,8 +291,7 @@ class KimiDeltaAttention(nn.Module, MambaBase):
)
core_attn_out = self.o_norm(core_attn_out, g2)
core_attn_out = rearrange(core_attn_out, "1 n h d -> n (h d)")
-return self.o_proj(core_attn_out)[0]
+output[:] = self.o_proj(core_attn_out)[0]
def _forward(
self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment