Unverified Commit f9f9f746 authored by Atream, committed by GitHub

Merge pull request #315 from kvcache-ai/Atream-add-adapted

Atream add adapted
parents 1548c992 92399283
@@ -262,7 +262,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
"""
# flash attn doesn't support head_dim bigger than 256
-# use vLLM triton attention kernel for MQA
+# use triton attention kernel adapted from vLLM and SGLang for MQA
decode_attention_fwd_grouped(query_states, compressed_kv_with_k_pe, compressed_kv, attn_output,
page_table,
position_ids.squeeze(0).to(torch.int32), attn_logits,
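# Illustrative aside, not part of this diff: a minimal pure-PyTorch reference of
# what a grouped (MQA/GQA) decode-attention kernel computes for a single decoded
# token. Buffer names and shapes below are assumptions made for the sketch, not
# the layout that decode_attention_fwd_grouped actually uses.
import torch

def grouped_decode_attention_ref(q, k_cache, v_cache, num_kv_heads):
    # q: [num_q_heads, head_dim] query for the one token being decoded
    # k_cache, v_cache: [seq_len, num_kv_heads, head_dim] cached keys/values
    num_q_heads, head_dim = q.shape
    group = num_q_heads // num_kv_heads              # query heads per KV head
    scale = head_dim ** -0.5
    out = torch.empty_like(q)
    for h in range(num_q_heads):
        kv_h = h // group                            # KV head shared by this query head
        scores = (k_cache[:, kv_h] @ q[h]) * scale   # [seq_len]
        probs = torch.softmax(scores, dim=-1)
        out[h] = probs @ v_cache[:, kv_h]            # [head_dim]
    return out

# Example: pure MQA, one KV head shared by 16 query heads, with a head_dim above
# the 256 limit mentioned in the comment above.
q = torch.randn(16, 576)
kv = torch.randn(1024, 1, 576)
o = grouped_decode_attention_ref(q, kv, kv, num_kv_heads=1)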
@@ -551,4 +551,4 @@ class KLlamaAttention(BaseInjectedModule):
if not output_attentions:
attn_weights = None
-return attn_output, attn_weights, past_key_value
\ No newline at end of file
+return attn_output, attn_weights, past_key_value
+# Adapted from
+# https://github.com/sgl-project/sglang/blob/9f635ea50de920aa507f486daafba26a5b837574/python/sglang/srt/layers/attention/triton_ops/decode_attention.py
+# which was originally adapted from
+# https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage1.py
+# https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage2.py
import triton
import triton.language as tl
@@ -376,4 +382,4 @@ def decode_attention_fwd_grouped(
)
_decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, b_seq_len,
-num_kv_splits)
\ No newline at end of file
+num_kv_splits)
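# Illustrative aside, not part of this diff: the kernel follows the flash-decoding
# (split-KV) pattern: stage 1 produces a partial output and a log-sum-exp per KV
# split, and _decode_softmax_reducev_fwd merges the splits. Below is a hedged
# PyTorch sketch of that merge; the tensor layout here is an assumption and need
# not match the kernel's actual attn_logits buffer.
import torch

def reduce_kv_splits(partial_out, partial_lse):
    # partial_out: [num_kv_splits, head_dim], softmax-normalised output of each split
    # partial_lse: [num_kv_splits], log of each split's softmax denominator
    max_lse = partial_lse.max()
    w = torch.exp(partial_lse - max_lse)   # each split's denominator, rescaled for stability
    w = w / w.sum()                        # relative weight of each split
    return (w[:, None] * partial_out).sum(dim=0)   # equals a single softmax over all splits

# Example: merge 4 KV splits of a 512-dim value head.
out = reduce_kv_splits(torch.randn(4, 512), torch.randn(4))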