Unverified Commit 3a5de7d2 authored by Jee Jee Li's avatar Jee Jee Li Committed by GitHub
Browse files

[Bugfix] Fix KDA output (#27905)


Signed-off-by: default avatarJee Jee Li <pandaleefree@gmail.com>
parent bc4486d6
...@@ -259,7 +259,7 @@ class KimiDeltaAttention(nn.Module, MambaBase): ...@@ -259,7 +259,7 @@ class KimiDeltaAttention(nn.Module, MambaBase):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
output: torch.Tensor, output: torch.Tensor,
) -> torch.Tensor: ) -> None:
num_tokens = hidden_states.size(0) num_tokens = hidden_states.size(0)
q = self.q_proj(hidden_states)[0] q = self.q_proj(hidden_states)[0]
k = self.k_proj(hidden_states)[0] k = self.k_proj(hidden_states)[0]
...@@ -291,8 +291,7 @@ class KimiDeltaAttention(nn.Module, MambaBase): ...@@ -291,8 +291,7 @@ class KimiDeltaAttention(nn.Module, MambaBase):
) )
core_attn_out = self.o_norm(core_attn_out, g2) core_attn_out = self.o_norm(core_attn_out, g2)
core_attn_out = rearrange(core_attn_out, "1 n h d -> n (h d)") core_attn_out = rearrange(core_attn_out, "1 n h d -> n (h d)")
output[:] = self.o_proj(core_attn_out)[0]
return self.o_proj(core_attn_out)[0]
def _forward( def _forward(
self, self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment