Unverified commit 3b9e16ce authored by Atream, committed by GitHub

Update attention.py

parent 94476ce5
@@ -435,6 +435,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
             kv_len_arr = torch.tensor([position_ids[0, -1].item()+1], dtype=torch.int32, device=self.device)
             self.mla_wrapper.plan(qo_indptr,None,None,
                                   kv_len_arr,
+                                  None,
                                   self.num_heads,
                                   self.kv_lora_rank,
                                   self.qk_rope_head_dim,
@@ -849,4 +850,4 @@ class flashinfer_attn(BaseInjectedModule, DeepseekV2Attention):
         attn_output = attn_output.transpose(0, 1)
         attn_output = attn_output.reshape(q_len, self.num_heads * self.v_head_dim)
         attn_output = self.o_proj(attn_output, num_tokens_tensors)
         return attn_output
\ No newline at end of file
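
For readers skimming the diff: the only functional change in the first hunk is an extra positional argument (None) inserted into the self.mla_wrapper.plan(...) call, between kv_len_arr and self.num_heads. The sketch below is a minimal stand-in showing where that argument lands in the call; the class name MLAWrapperSketch, the parameter name extra_arg, and the example sizes are all hypothetical and are not ktransformers' actual API.

```python
import torch


class MLAWrapperSketch:
    """Hypothetical stand-in for the project's MLA wrapper (illustrative only)."""

    def plan(self, qo_indptr, kv_indptr, kv_indices, kv_len_arr, extra_arg,
             num_heads, kv_lora_rank, qk_rope_head_dim):
        # The commit inserts one more positional slot (here `extra_arg`)
        # between kv_len_arr and num_heads; the updated call site passes None.
        self.qo_indptr = qo_indptr
        self.kv_len_arr = kv_len_arr
        self.extra_arg = extra_arg
        self.num_heads = num_heads
        self.kv_lora_rank = kv_lora_rank
        self.qk_rope_head_dim = qk_rope_head_dim


# Usage mirroring the updated call site (shapes and sizes are placeholders).
wrapper = MLAWrapperSketch()
position_ids = torch.arange(8).unsqueeze(0)
kv_len_arr = torch.tensor([position_ids[0, -1].item() + 1], dtype=torch.int32)
qo_indptr = torch.tensor([0, 1], dtype=torch.int32)
wrapper.plan(qo_indptr, None, None,
             kv_len_arr,
             None,              # the argument added by this commit
             num_heads=128,
             kv_lora_rank=512,
             qk_rope_head_dim=64)
```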