Unverified Commit 92399283 authored by Atream, committed by GitHub

Update attention.py

parent d90749d3
@@ -262,7 +262,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
             """
             # flash attn doesn't support head_dim bigger than 256
-            # use vLLM triton attention kernel for MQA
+            # use triton attention kernel adapted from vLLM and SGLang for MQA
             decode_attention_fwd_grouped(query_states, compressed_kv_with_k_pe, compressed_kv, attn_output,
                                          page_table,
                                          position_ids.squeeze(0).to(torch.int32), attn_logits,
@@ -551,4 +551,4 @@ class KLlamaAttention(BaseInjectedModule):
         if not output_attentions:
             attn_weights = None
-        return attn_output, attn_weights, past_key_value
\ No newline at end of file
+        return attn_output, attn_weights, past_key_value
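For context, the hunk above swaps flash-attn for a grouped-decode (MQA) triton kernel because flash-attn does not support head_dim above 256. Below is a minimal PyTorch sketch of what such a grouped/MQA decode-attention step computes for a single decode token; it is not the triton kernel from the diff, and the shapes (128 query heads, 576-dim keys from compressed_kv plus k_pe, 512-dim values) are assumptions for illustration.

import torch

def mqa_decode_attention_reference(q, k_cache, v_cache, seq_len, sm_scale):
    # Illustrative reference only; not decode_attention_fwd_grouped itself.
    # q:       [num_q_heads, head_dim]   query for the single decode token
    # k_cache: [max_seq_len, head_dim]   shared key cache (one KV head, MQA)
    # v_cache: [max_seq_len, v_head_dim] shared value cache
    k = k_cache[:seq_len]                         # [seq_len, head_dim]
    v = v_cache[:seq_len]                         # [seq_len, v_head_dim]
    # every query head attends to the same single KV head
    scores = (q @ k.T) * sm_scale                 # [num_q_heads, seq_len]
    probs = torch.softmax(scores, dim=-1)
    return probs @ v                              # [num_q_heads, v_head_dim]

# assumed example shapes: head_dim > 256, which rules out flash-attn
q = torch.randn(128, 576)
k_cache = torch.randn(4096, 576)
v_cache = torch.randn(4096, 512)
out = mqa_decode_attention_reference(q, k_cache, v_cache, seq_len=1024, sm_scale=576 ** -0.5)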