Unverified commit f235498e authored by YAMY, committed by GitHub

DeepSeek-V3.2: Add Adaptive MHA Attention Pathway for Short-Sequence Prefill (#11892)

parent 149dc9aa
@@ -242,6 +242,30 @@ class Indexer(CustomOp):
        return query, key, weights

    def _get_k_bf16(
        self,
        x: torch.Tensor,
        positions: torch.Tensor,
        enable_dual_stream: bool,
    ):
        # Compute only the key; skip query and weights (weights are discarded if fused)
        if self.fuse_wk_and_weights_proj:
            key, _ = self.fused_wk_and_weights_proj(x)[0].split(
                [self.head_dim, self.n_heads], dim=-1
            )
        else:
            key, _ = self.wk(x)
        key = self.k_norm(key)
        k_rope, _ = torch.split(
            key, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1
        )
        _, k_rope = self.rotary_emb(positions, k_rope, k_rope)
        key[..., : self.rope_head_dim] = k_rope
        key = rotate_activation(key)
        return key
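As an aside, the split / rotate / write-back pattern above can be exercised in isolation. The sketch below uses made-up sizes, and torch.roll merely stands in for the real rotary embedding; none of these values come from the diff.

import torch

# Illustrative only: placeholder sizes; torch.roll stands in for self.rotary_emb.
head_dim, rope_head_dim, num_tokens = 128, 64, 4
key = torch.randn(num_tokens, head_dim)
k_rope, _ = torch.split(key, [rope_head_dim, head_dim - rope_head_dim], dim=-1)
k_rope = torch.roll(k_rope, shifts=1, dims=-1)  # placeholder for the rotary embedding
key[..., :rope_head_dim] = k_rope               # write the rotated slice back in place
print(key.shape)  # torch.Size([4, 128])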
    def _get_topk_paged(
        self,
        forward_batch: ForwardBatch,
@@ -375,6 +399,45 @@ class Indexer(CustomOp):
        topk_result[:offset] = raw_topk_result
        return topk_result
    def _forward_cuda_k_only(
        self,
        x: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        layer_id: int,
        act_quant,
        enable_dual_stream: bool,
        metadata: BaseIndexerMetadata,
        return_indices: bool = True,
    ) -> Optional[torch.Tensor]:
        # Fast path: only compute and store k cache, skip all q and weights ops
        key = self._get_k_bf16(x, positions, enable_dual_stream)
        k_fp8, k_scale = act_quant(key, self.block_size, self.scale_fmt)

        if not forward_batch.out_cache_loc.is_contiguous():
            forward_batch.out_cache_loc = forward_batch.out_cache_loc.contiguous()

        forward_batch.token_to_kv_pool.set_index_k_and_scale_buffer(
            layer_id=layer_id,
            loc=forward_batch.out_cache_loc,
            index_k=k_fp8,
            index_k_scale=k_scale,
        )

        # MHA doesn't need topk_indices
        if not return_indices:
            return None

        # MLA: use dummy logits with topk kernel's fast path to generate indices
        # When length <= 2048, naive_topk_cuda directly generates [0,1,...,length-1,-1,...]
        seq_lens_expanded = metadata.get_seqlens_expanded()
        dummy_logits = torch.zeros(
            seq_lens_expanded.shape[0],
            self.index_topk,
            dtype=torch.float32,
            device=x.device,
        )
        return metadata.topk_transform(dummy_logits, self.index_topk)
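To make the fast-path comment above concrete, this is the per-request index pattern one would expect when the dummy zero logits hit that path. This is a sketch under the assumption that naive_topk_cuda behaves exactly as the comment describes; the helper name below is illustrative, not part of the PR.

import torch

def expected_fast_path_indices(length: int, index_topk: int) -> torch.Tensor:
    # [0, 1, ..., length-1] followed by -1 padding, as described in the comment above.
    indices = torch.full((index_topk,), -1, dtype=torch.int32)
    indices[:length] = torch.arange(length, dtype=torch.int32)
    return indices

print(expected_fast_path_indices(5, 8))
# tensor([ 0,  1,  2,  3,  4, -1, -1, -1], dtype=torch.int32)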
    def forward_indexer(
        self,
        q_fp8: torch.Tensor,
@@ -465,6 +528,7 @@ class Indexer(CustomOp):
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        layer_id: int,
        return_indices: bool = True,
    ) -> Optional[torch.Tensor]:
        if is_hip():
            from sglang.srt.layers.attention.nsa.tilelang_kernel import act_quant
@@ -490,6 +554,26 @@ class Indexer(CustomOp):
        if metadata is None:
            return None

        # Determine whether to skip topk based on the sequence length
        should_skip = False
        if not forward_batch.forward_mode.is_decode_or_idle():
            if forward_batch.seq_lens_cpu is not None:
                max_kv_len = forward_batch.seq_lens_cpu.max().item()
                should_skip = max_kv_len <= self.index_topk

        # Optimization: fast path when skipping the topk computation
        if should_skip:
            return self._forward_cuda_k_only(
                x,
                positions,
                forward_batch,
                layer_id,
                act_quant,
                enable_dual_stream,
                metadata,
                return_indices,
            )

        query, key, weights = self._get_q_k_bf16(
            q_lora, x, positions, enable_dual_stream
        )
......
@@ -47,7 +47,7 @@ if _is_hip:
        "aiter is AMD specific kernel library. Please make sure aiter is installed on your AMD device."
    )
else:
    from sgl_kernel.flash_attn import flash_attn_with_kvcache
    from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
@dataclass(frozen=True)
@@ -823,7 +823,23 @@ class NativeSparseAttnBackend(AttentionBackend):
        # For fa3 interface version compatibility, we put new fields into conditional keyword args
        kwargs = {}

        # Do absorbed multi-latent attention
        # Detect MHA mode: multiple KV heads (vs. MLA, which uses a single KV head)
        is_mha_mode = (layer.tp_k_head_num == layer.tp_q_head_num) and (
            layer.tp_k_head_num > 1
        )

        # Use the MHA kernel if in MHA_ONE_SHOT mode
        if is_mha_mode and k is not None and v is not None and q_rope is None:
            return self._forward_standard_mha(
                q=q,
                k=k,
                v=v,
                layer=layer,
                forward_batch=forward_batch,
                metadata=metadata,
            )

        # Do absorbed multi-latent attention (MLA path)
        assert q_rope is not None
        kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
@@ -1154,6 +1170,49 @@ class NativeSparseAttnBackend(AttentionBackend):
        )
        return o
    def _forward_standard_mha(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        metadata: NSAMetadata,
    ) -> torch.Tensor:
        """Standard MHA using FlashAttention varlen for MHA_ONE_SHOT mode."""
        q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
        k = k.view(-1, layer.tp_k_head_num, layer.head_dim)
        v = v.view(-1, layer.tp_v_head_num, layer.v_head_dim)

        # MHA_ONE_SHOT: k/v include all tokens (prefix + current)
        cu_seqlens_q = metadata.cu_seqlens_q
        cu_seqlens_k = metadata.cu_seqlens_k
        max_seqlen_k = metadata.max_seq_len_k
        causal = True

        # Verify batch sizes match (length of cu_seqlens should be batch_size + 1)
        assert len(cu_seqlens_q) == len(cu_seqlens_k), (
            f"batch_size mismatch: cu_seqlens_q has {len(cu_seqlens_q) - 1} requests, "
            f"cu_seqlens_k has {len(cu_seqlens_k) - 1} requests"
        )

        # Determine FA version: FA3 for SM90 (Hopper), FA4 for SM100+ (Blackwell and beyond)
        device_sm_major = torch.cuda.get_device_capability()[0]
        fa_version = 4 if device_sm_major >= 10 else 3

        return flash_attn_varlen_func(
            q=q,
            k=k,
            v=v,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=metadata.max_seq_len_q,
            max_seqlen_k=max_seqlen_k,
            softmax_scale=layer.scaling,
            causal=causal,
            ver=fa_version,
        )
    def _forward_tilelang(
        self,
        q_all: torch.Tensor,
......
@@ -398,6 +398,34 @@ def handle_attention_aiter(attn, forward_batch):
def handle_attention_nsa(attn, forward_batch):
    """
    Select MHA or MLA based on sequence length for optimal performance.

    - Decode: MLA (avoids per-token decompression)
    - Prefill <= 2048: MHA (topk is ineffective, and MHA has lower FLOPs)
    - Prefill > 2048: MLA (topk filtering reduces computation significantly)

    TODO: The B200 (SM100) MHA path is temporarily disabled due to FA4 GPQA accuracy issues.
    """
    if forward_batch.forward_mode.is_decode_or_idle():
        return AttnForwardMethod.MLA

    if _is_extend_without_speculative(forward_batch):
        assert forward_batch.seq_lens_cpu is not None
        max_kv_len = forward_batch.seq_lens_cpu.max().item()

        # B200 (SM100) is temporarily disabled for MHA due to FA4 accuracy issues.
        # Currently only H200 (SM90) with FA3 is allowed to use the MHA path.
        is_hopper = _device_sm == 90

        if max_kv_len <= attn.indexer.index_topk and is_hopper:
            # The NSA backend uses a varlen kernel, which supports MHA_ONE_SHOT.
            # Check whether the total sequence length fits within the chunk capacity.
            sum_seq_lens = sum(forward_batch.seq_lens_cpu)
            # Use MHA_ONE_SHOT for best performance
            if sum_seq_lens <= forward_batch.get_max_chunk_capacity():
                return AttnForwardMethod.MHA_ONE_SHOT

    return AttnForwardMethod.MLA
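For concreteness, here is a self-contained sketch of the selection policy described in the docstring above. The topk threshold, chunk capacity, and SM value are illustrative stand-ins for attn.indexer.index_topk, forward_batch.get_max_chunk_capacity(), and _device_sm.

INDEX_TOPK = 2048       # stand-in for attn.indexer.index_topk
CHUNK_CAPACITY = 16384  # stand-in for forward_batch.get_max_chunk_capacity()
DEVICE_SM = 90          # Hopper (H200); SM100 currently stays on MLA

def choose_method(is_decode: bool, seq_lens: list[int]) -> str:
    if is_decode:
        return "MLA"
    max_kv_len = max(seq_lens)
    if max_kv_len <= INDEX_TOPK and DEVICE_SM == 90 and sum(seq_lens) <= CHUNK_CAPACITY:
        return "MHA_ONE_SHOT"
    return "MLA"

print(choose_method(False, [512, 1500]))  # MHA_ONE_SHOT: topk would keep every token anyway
print(choose_method(False, [8192]))       # MLA: topk filtering pays off for long prefill
print(choose_method(True, [8192]))        # MLA: decode always uses the absorbed path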
@@ -1466,8 +1494,21 @@ class DeepseekV2AttentionMLA(nn.Module):
            q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
            )
            q = self.q_a_layernorm(q)
            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
            q_lora = self.q_a_layernorm(q)
            q = self.q_b_proj(q_lora)[0].view(
                -1, self.num_local_heads, self.qk_head_dim
            )

            # NSA Indexer: cache quantized keys, auto-skip topk for sequences <= nsa_index_topk
            if self.use_nsa and _is_extend_without_speculative(forward_batch):
                _ = self.indexer(
                    x=hidden_states,
                    q_lora=q_lora,
                    positions=positions,
                    forward_batch=forward_batch,
                    layer_id=self.layer_id,
                    return_indices=False,
                )
        else:
            q = self.q_proj(hidden_states)[0].view(
                -1, self.num_local_heads, self.qk_head_dim
......