Unverified commit f235498e authored by YAMY, committed by GitHub

DeepSeek-V3.2: Add Adaptive MHA Attention Pathway for Short-Sequence Prefill (#11892)

parent 149dc9aa
...@@ -242,6 +242,30 @@ class Indexer(CustomOp):
return query, key, weights
def _get_k_bf16(
self,
x: torch.Tensor,
positions: torch.Tensor,
enable_dual_stream: bool,
):
# Compute only key, skip query and weights (weights is discarded if fused)
if self.fuse_wk_and_weights_proj:
key, _ = self.fused_wk_and_weights_proj(x)[0].split(
[self.head_dim, self.n_heads], dim=-1
)
else:
key, _ = self.wk(x)
key = self.k_norm(key)
k_rope, _ = torch.split(
key, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1
)
_, k_rope = self.rotary_emb(positions, k_rope, k_rope)
key[..., : self.rope_head_dim] = k_rope
key = rotate_activation(key)
return key
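The key path above applies RoPE only to the leading rope_head_dim channels of the indexer key and leaves the remaining channels untouched before rotate_activation. A minimal sketch of that partial rotation in the rotate-half convention, assuming cos/sin are already gathered per position (the real code delegates to self.rotary_emb, whose layout may differ):

import torch

def apply_partial_rope(key: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor,
                       rope_dim: int) -> torch.Tensor:
    # Split the indexer key into a RoPE part and a pass-through part.
    k_rope, k_pass = key.split([rope_dim, key.shape[-1] - rope_dim], dim=-1)
    half = rope_dim // 2
    k1, k2 = k_rope[..., :half], k_rope[..., half:]
    # Rotate-half RoPE on the leading channels only; cos/sin have shape (..., rope_dim // 2).
    k_rot = torch.cat([k1 * cos - k2 * sin, k2 * cos + k1 * sin], dim=-1)
    return torch.cat([k_rot, k_pass], dim=-1)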
def _get_topk_paged(
self,
forward_batch: ForwardBatch,
...@@ -375,6 +399,45 @@ class Indexer(CustomOp):
topk_result[:offset] = raw_topk_result
return topk_result
def _forward_cuda_k_only(
self,
x: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
layer_id: int,
act_quant,
enable_dual_stream: bool,
metadata: BaseIndexerMetadata,
return_indices: bool = True,
) -> Optional[torch.Tensor]:
# Fast path: only compute and store k cache, skip all q and weights ops
key = self._get_k_bf16(x, positions, enable_dual_stream)
k_fp8, k_scale = act_quant(key, self.block_size, self.scale_fmt)
if not forward_batch.out_cache_loc.is_contiguous():
forward_batch.out_cache_loc = forward_batch.out_cache_loc.contiguous()
forward_batch.token_to_kv_pool.set_index_k_and_scale_buffer(
layer_id=layer_id,
loc=forward_batch.out_cache_loc,
index_k=k_fp8,
index_k_scale=k_scale,
)
# MHA doesn't need topk_indices
if not return_indices:
return None
# MLA: use dummy logits with topk kernel's fast path to generate indices
# When length <= 2048, naive_topk_cuda directly generates [0,1,...,length-1,-1,...]
seq_lens_expanded = metadata.get_seqlens_expanded()
dummy_logits = torch.zeros(
seq_lens_expanded.shape[0],
self.index_topk,
dtype=torch.float32,
device=x.device,
)
return metadata.topk_transform(dummy_logits, self.index_topk)
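When top-k is skipped, the dummy zero logits exist only to reuse the kernel interface; the fast path of the top-k transform is expected to return [0, 1, ..., L-1] padded with -1 for each query token whose KV length L is at most index_topk. A pure-PyTorch reference of that output shape (not the naive_topk_cuda kernel itself), where seq_lens corresponds to metadata.get_seqlens_expanded():

import torch

def reference_topk_indices(seq_lens: torch.Tensor, index_topk: int) -> torch.Tensor:
    # For each query token with KV length L <= index_topk, keep positions 0..L-1, pad with -1.
    arange = torch.arange(index_topk, device=seq_lens.device)
    keep = arange.unsqueeze(0) < seq_lens.unsqueeze(1)  # (num_tokens, index_topk)
    return torch.where(keep, arange.unsqueeze(0), torch.full_like(keep, -1, dtype=torch.long))

# reference_topk_indices(torch.tensor([3, 5]), 8)
#   -> [[0, 1, 2, -1, -1, -1, -1, -1],
#       [0, 1, 2, 3, 4, -1, -1, -1]]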
def forward_indexer(
self,
q_fp8: torch.Tensor,
...@@ -465,6 +528,7 @@ class Indexer(CustomOp):
positions: torch.Tensor,
forward_batch: ForwardBatch,
layer_id: int,
return_indices: bool = True,
) -> Optional[torch.Tensor]:
if is_hip():
from sglang.srt.layers.attention.nsa.tilelang_kernel import act_quant
...@@ -490,6 +554,26 @@ class Indexer(CustomOp):
if metadata is None:
return None
# Determine if should skip topk based on sequence length
should_skip = False
if not forward_batch.forward_mode.is_decode_or_idle():
if forward_batch.seq_lens_cpu is not None:
max_kv_len = forward_batch.seq_lens_cpu.max().item()
should_skip = max_kv_len <= self.index_topk
# Optimization: fast path when skipping topk computation
if should_skip:
return self._forward_cuda_k_only(
x,
positions,
forward_batch,
layer_id,
act_quant,
enable_dual_stream,
metadata,
return_indices,
)
query, key, weights = self._get_q_k_bf16(
q_lora, x, positions, enable_dual_stream
)
...
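The k-only fast path still writes FP8 entries into the index-k pool via act_quant (imported from the tilelang kernel on ROCm, from sgl_kernel otherwise). A simplified reference of per-block absmax quantization to float8_e4m3fn, ignoring scale_fmt packing and assuming block_size divides the last dimension:

import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn

def block_act_quant(x: torch.Tensor, block_size: int = 128):
    # One scale per contiguous block of `block_size` channels along the last dim.
    assert x.shape[-1] % block_size == 0
    xb = x.float().unflatten(-1, (-1, block_size))            # (..., n_blocks, block_size)
    scale = xb.abs().amax(dim=-1, keepdim=True).clamp_min(1e-12) / FP8_MAX
    q = (xb / scale).to(torch.float8_e4m3fn).flatten(-2)      # same shape as x
    return q, scale.squeeze(-1)                               # quantized key, per-block scales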
...@@ -47,7 +47,7 @@ if _is_hip:
"aiter is AMD specific kernel library. Please make sure aiter is installed on your AMD device."
)
else:
from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
@dataclass(frozen=True)
...@@ -823,7 +823,23 @@ class NativeSparseAttnBackend(AttentionBackend):
# For fa3 interface version compatibility, we put new fields into conditional keyword args
kwargs = {}
# Detect MHA mode: multi KV heads (vs MLA with single KV head)
is_mha_mode = (layer.tp_k_head_num == layer.tp_q_head_num) and (
layer.tp_k_head_num > 1
)
# Use MHA kernel if in MHA_ONE_SHOT mode
if is_mha_mode and k is not None and v is not None and q_rope is None:
return self._forward_standard_mha(
q=q,
k=k,
v=v,
layer=layer,
forward_batch=forward_batch,
metadata=metadata,
)
# Do absorbed multi-latent attention (MLA path)
assert q_rope is not None
kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
...@@ -1154,6 +1170,49 @@ class NativeSparseAttnBackend(AttentionBackend):
)
return o
def _forward_standard_mha(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer: RadixAttention,
forward_batch: ForwardBatch,
metadata: NSAMetadata,
) -> torch.Tensor:
"""Standard MHA using FlashAttention varlen for MHA_ONE_SHOT mode."""
q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
k = k.view(-1, layer.tp_k_head_num, layer.head_dim)
v = v.view(-1, layer.tp_v_head_num, layer.v_head_dim)
# MHA_ONE_SHOT: k/v include all tokens (prefix + current)
cu_seqlens_q = metadata.cu_seqlens_q
cu_seqlens_k = metadata.cu_seqlens_k
max_seqlen_k = metadata.max_seq_len_k
causal = True
# Verify batch sizes match (length of cu_seqlens should be batch_size + 1)
assert len(cu_seqlens_q) == len(cu_seqlens_k), (
f"batch_size mismatch: cu_seqlens_q has {len(cu_seqlens_q)-1} requests, "
f"cu_seqlens_k has {len(cu_seqlens_k)-1} requests"
)
# Determine FA version: FA3 for SM90 (Hopper), FA4 for SM100+ (Blackwell and beyond)
device_sm_major = torch.cuda.get_device_capability()[0]
fa_version = 4 if device_sm_major >= 10 else 3
return flash_attn_varlen_func(
q=q,
k=k,
v=v,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=metadata.max_seq_len_q,
max_seqlen_k=max_seqlen_k,
softmax_scale=layer.scaling,
causal=causal,
ver=fa_version,
)
def _forward_tilelang(
self,
q_all: torch.Tensor,
...
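_forward_standard_mha packs all requests into one varlen batch delimited by cu_seqlens prefix sums before calling the FlashAttention varlen kernel. A standalone sketch of that calling convention using the public flash-attn package (whose signature has no ver switch, unlike the sgl_kernel wrapper used in the diff):

import torch
from flash_attn import flash_attn_varlen_func  # public flash-attn package

# Two requests of lengths 3 and 5 packed into one varlen batch.
seq_lens = torch.tensor([3, 5], dtype=torch.int32, device="cuda")
cu_seqlens = torch.nn.functional.pad(seq_lens.cumsum(0, dtype=torch.int32), (1, 0))  # [0, 3, 8]
total, heads, head_dim = int(seq_lens.sum()), 16, 128

q = torch.randn(total, heads, head_dim, dtype=torch.bfloat16, device="cuda")
k = torch.randn(total, heads, head_dim, dtype=torch.bfloat16, device="cuda")
v = torch.randn(total, heads, head_dim, dtype=torch.bfloat16, device="cuda")

out = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
    max_seqlen_q=int(seq_lens.max()), max_seqlen_k=int(seq_lens.max()),
    softmax_scale=head_dim ** -0.5,  # layer.scaling in the real code
    causal=True,
)  # (total, heads, head_dim)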
...@@ -398,6 +398,34 @@ def handle_attention_aiter(attn, forward_batch):
def handle_attention_nsa(attn, forward_batch):
"""
Select MHA or MLA based on sequence length for optimal performance.
- Decode: MLA (avoids per-token decompression)
- Prefill <= 2048: MHA (topk ineffective, MHA has lower FLOPs)
- Prefill > 2048: MLA (topk filtering reduces computation significantly)
TODO: B200 (SM100) MHA path is temporarily disabled due to FA4 GPQA accuracy issues.
"""
if forward_batch.forward_mode.is_decode_or_idle():
return AttnForwardMethod.MLA
if _is_extend_without_speculative(forward_batch):
assert forward_batch.seq_lens_cpu is not None
max_kv_len = forward_batch.seq_lens_cpu.max().item()
# B200 (SM100) is temporarily disabled for MHA due to FA4 accuracy issues
# Currently only H200 (SM90) with FA3 is allowed to use MHA path
is_hopper = _device_sm == 90
if max_kv_len <= attn.indexer.index_topk and is_hopper:
# NSA backend uses varlen kernel which supports MHA_ONE_SHOT
# Check if total sequence length fits in chunk capacity
sum_seq_lens = sum(forward_batch.seq_lens_cpu)
# Use MHA_ONE_SHOT for best performance
if sum_seq_lens <= forward_batch.get_max_chunk_capacity():
return AttnForwardMethod.MHA_ONE_SHOT
return AttnForwardMethod.MLA
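Distilled, the heuristic above reduces to a small decision function. The sketch below uses placeholder types in place of SGLang's ForwardBatch and AttnForwardMethod, with index_topk defaulting to DeepSeek-V3.2's 2048 and the chunk capacity taken as an input rather than a fixed value:

from dataclasses import dataclass
from enum import Enum, auto

class AttnMethod(Enum):
    MLA = auto()
    MHA_ONE_SHOT = auto()

@dataclass
class BatchInfo:                     # stand-in for ForwardBatch
    is_decode: bool
    seq_lens: list[int]              # per-request KV lengths
    max_chunk_capacity: int          # forward_batch.get_max_chunk_capacity() in SGLang

def choose_attention(batch: BatchInfo, index_topk: int = 2048,
                     is_hopper: bool = True) -> AttnMethod:
    if batch.is_decode:
        return AttnMethod.MLA        # decode: absorbed MLA avoids per-token decompression
    if is_hopper and max(batch.seq_lens) <= index_topk:
        # Top-k cannot drop any token at these lengths, so dense MHA prefill is cheaper,
        # provided the whole batch fits in one varlen chunk.
        if sum(batch.seq_lens) <= batch.max_chunk_capacity:
            return AttnMethod.MHA_ONE_SHOT
    return AttnMethod.MLA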
...@@ -1466,8 +1494,21 @@ class DeepseekV2AttentionMLA(nn.Module):
q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
)
q_lora = self.q_a_layernorm(q)
q = self.q_b_proj(q_lora)[0].view(
-1, self.num_local_heads, self.qk_head_dim
)
# NSA Indexer: cache quantized keys, auto-skip topk for sequences <= nsa_index_topk
if self.use_nsa and _is_extend_without_speculative(forward_batch):
_ = self.indexer(
x=hidden_states,
q_lora=q_lora,
positions=positions,
forward_batch=forward_batch,
layer_id=self.layer_id,
return_indices=False,
)
else:
q = self.q_proj(hidden_states)[0].view(
-1, self.num_local_heads, self.qk_head_dim
...
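Even when prefill takes the MHA path, the indexer is still invoked with return_indices=False so the FP8 index-key cache is filled for every prefix token; later decode steps return to the sparse MLA path and read those entries. A toy, self-contained illustration of that invariant (IndexKCache is illustrative only and stands in for the token_to_kv_pool index-k buffer):

import torch

class IndexKCache:
    def __init__(self, num_slots: int, head_dim: int):
        self.k = torch.zeros(num_slots, head_dim)
        self.written = torch.zeros(num_slots, dtype=torch.bool)

    def set(self, loc: torch.Tensor, k: torch.Tensor):
        self.k[loc] = k
        self.written[loc] = True

# Prefill of a 6-token request via the MHA fast path still writes slots 0..5,
# so a later decode step (sparse MLA) finds index keys for the whole prefix.
cache = IndexKCache(num_slots=16, head_dim=4)
cache.set(torch.arange(6), torch.randn(6, 4))
assert bool(cache.written[:6].all())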