"vscode:/vscode.git/clone" did not exist on "ad9d7ce4763f8fb2a9e620bff017830c26086c36"
Unverified commit 4793ec7d, authored by Yongfei Xu, committed by GitHub

Optimize MHA chunked prefix: merge prefix and extend KV cache to run MHA once (#10953)

parent 92009bd2
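Background for the change (an illustrative aside, not part of this commit): the chunked-prefix path runs MHA separately over the prefix KV and the extend KV, then merges the partial outputs using their log-sum-exp (LSE) values; the one-shot path gathers prefix and extend KV together and runs a single MHA, skipping the merge. A minimal torch sketch of why the two are equivalent, with made-up shapes and helper names (the toy ignores the causal mask, which does not affect the merge identity):

import torch

def naive_attn(q, k, v, scale):
    # q: [num_heads, head_dim]; k, v: [seq_len, head_dim]
    scores = (q @ k.T) * scale
    lse = torch.logsumexp(scores, dim=-1)
    out = torch.softmax(scores, dim=-1) @ v
    return out, lse

def merge_with_lse(o1, lse1, o2, lse2):
    # Standard LSE merge of two partial attention results over disjoint KV chunks.
    lse = torch.logaddexp(lse1, lse2)
    return (
        o1 * torch.exp(lse1 - lse).unsqueeze(-1)
        + o2 * torch.exp(lse2 - lse).unsqueeze(-1)
    )

torch.manual_seed(0)
q = torch.randn(16, 64)                                        # 16 heads, head_dim 64
k_prefix, v_prefix = torch.randn(8, 64), torch.randn(8, 64)    # cached prefix KV
k_extend, v_extend = torch.randn(4, 64), torch.randn(4, 64)    # newly extended KV
scale = 64 ** -0.5

o_p, lse_p = naive_attn(q, k_prefix, v_prefix, scale)          # chunk 1: prefix only
o_e, lse_e = naive_attn(q, k_extend, v_extend, scale)          # chunk 2: extend only
chunked = merge_with_lse(o_p, lse_p, o_e, lse_e)

one_shot, _ = naive_attn(
    q, torch.cat([k_prefix, k_extend]), torch.cat([v_prefix, v_extend]), scale
)
assert torch.allclose(chunked, one_shot, atol=1e-5)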
@@ -855,14 +855,24 @@ class FlashAttentionBackend(AttentionBackend):
             )
         else:
             # MHA for extend part of sequence without attending prefix kv cache
+            cu_seqlens_k = (
+                metadata.cu_seqlens_q
+                if not forward_batch.mha_one_shot
+                else metadata.cu_seqlens_k
+            )
+            max_seqlen_k = (
+                metadata.max_seq_len_q
+                if not forward_batch.mha_one_shot
+                else metadata.max_seq_len_k
+            )
             output = flash_attn_varlen_func(
                 q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
                 k=k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
                 v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype),
                 cu_seqlens_q=metadata.cu_seqlens_q,
-                cu_seqlens_k=metadata.cu_seqlens_q,
+                cu_seqlens_k=cu_seqlens_k,
                 max_seqlen_q=metadata.max_seq_len_q,
-                max_seqlen_k=metadata.max_seq_len_q,
+                max_seqlen_k=max_seqlen_k,
                 softmax_scale=layer.scaling,
                 causal=True,
                 return_softmax_lse=forward_batch.mha_return_lse,
......
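For the FlashAttention hunk above, the only difference in the one-shot case is that the K-side cumulative lengths and maximum length describe the full sequences (prefix plus extend) rather than only the extend part. A toy illustration with made-up lengths (not taken from the diff):

import torch
import torch.nn.functional as F

# Two requests: cached prefixes of length 5 and 2, extend parts of length 4 and 3.
extend_lens = torch.tensor([4, 3])
seq_lens = torch.tensor([9, 5])  # prefix + extend

# Q-side cumulative lengths always cover only the extend tokens.
cu_seqlens_q = F.pad(torch.cumsum(extend_lens, dim=0), (1, 0)).int()  # tensor([0, 4, 7])

# Chunked path: cu_seqlens_k == cu_seqlens_q (prefix is handled separately).
# One-shot path: the K side covers the full sequences, so max_seqlen_k is 9 here.
cu_seqlens_k = F.pad(torch.cumsum(seq_lens, dim=0), (1, 0)).int()     # tensor([0, 9, 14])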
@@ -82,6 +82,7 @@ class FlashInferMhaChunkKVRunner:
         # Buffers and wrappers
         self.qo_indptr = attn_backend.qo_indptr
+        self.kv_indptr = attn_backend.kv_indptr
         self.workspace_buffer = attn_backend.workspace_buffer
         self.fmha_backend = attn_backend.fmha_backend
@@ -132,9 +133,14 @@ class FlashInferMhaChunkKVRunner:
            )
        # ragged prefill
        if not disable_flashinfer_ragged:
+            kv_indptr = (
+                qo_indptr
+                if not forward_batch.mha_one_shot
+                else self.kv_indptr[: bs + 1]
+            )
            self.ragged_wrapper.begin_forward(
                qo_indptr=qo_indptr,
-                kv_indptr=qo_indptr,
+                kv_indptr=kv_indptr,
                num_qo_heads=self.num_local_heads,
                num_kv_heads=self.num_local_heads,
                head_dim_qk=self.qk_nope_head_dim + self.qk_rope_head_dim,
@@ -156,7 +162,7 @@ class FlashInferMhaChunkKVRunner:
            chunk_idx = forward_batch.prefix_chunk_idx
            assert chunk_idx >= 0
            wrapper = self.chunk_ragged_wrappers[chunk_idx]
-            o1, s1 = wrapper.forward_return_lse(
+            o = wrapper.forward_return_lse(
                q.view(-1, layer.tp_q_head_num, layer.head_dim),
                k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
                v.view(-1, layer.tp_v_head_num, layer.v_head_dim).to(q.dtype),
@@ -165,7 +171,12 @@ class FlashInferMhaChunkKVRunner:
                logits_soft_cap=logits_soft_cap,
            )
        else:
-            o1, s1 = self.ragged_wrapper.forward_return_lse(
+            forward = (
+                self.ragged_wrapper.forward_return_lse
+                if forward_batch.mha_return_lse
+                else self.ragged_wrapper.forward
+            )
+            o = forward(
                q.view(-1, layer.tp_q_head_num, layer.head_dim),
                k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
                v.view(-1, layer.tp_v_head_num, layer.v_head_dim).to(q.dtype),
@@ -173,8 +184,7 @@ class FlashInferMhaChunkKVRunner:
                sm_scale=layer.scaling,
                logits_soft_cap=logits_soft_cap,
            )
-
-        return o1, s1
+        return o


class FlashInferMLAAttnBackend(AttentionBackend):
@@ -512,15 +522,13 @@ class FlashInferMLAAttnBackend(AttentionBackend):
        q_rope: Optional[torch.Tensor] = None,
        k_rope: Optional[torch.Tensor] = None,
    ):
-        if (
-            forward_batch.attn_attend_prefix_cache is not None
-            and forward_batch.mha_return_lse
+        if forward_batch.attn_attend_prefix_cache is not None and any(
+            forward_batch.extend_prefix_lens_cpu
        ):  # MHA Chunk
            assert self.enable_chunk_kv
            assert q_rope is None
            assert k_rope is None
-            o1, s1 = self.mha_chunk_kv_cache.forward(q, k, v, layer, forward_batch)
-            return o1, s1
+            return self.mha_chunk_kv_cache.forward(q, k, v, layer, forward_batch)

        cache_loc = forward_batch.out_cache_loc
        logits_soft_cap = layer.logit_cap
......
+import torch
 import triton
 import triton.language as tl
@@ -101,3 +102,80 @@ def create_flashmla_kv_indices_triton(
        data // PAGED_SIZE,
        mask=mask_out,
    )
+
+
+@triton.jit
+def concat_and_cast_mha_k_kernel(
+    k_ptr,
+    k_nope_ptr,
+    k_rope_ptr,
+    head_cnt: tl.constexpr,
+    k_stride0: tl.constexpr,
+    k_stride1: tl.constexpr,
+    nope_stride0: tl.constexpr,
+    nope_stride1: tl.constexpr,
+    rope_stride0: tl.constexpr,
+    nope_dim: tl.constexpr,
+    rope_dim: tl.constexpr,
+):
+    pid_loc = tl.program_id(0)
+    head_range = tl.arange(0, head_cnt)
+    k_head_ptr = k_ptr + pid_loc * k_stride0 + head_range[:, None] * k_stride1
+
+    nope_offs = tl.arange(0, nope_dim)
+    src_nope_ptr = (
+        k_nope_ptr
+        + pid_loc * nope_stride0
+        + head_range[:, None] * nope_stride1
+        + nope_offs[None, :]
+    )
+    dst_nope_ptr = k_head_ptr + nope_offs[None, :]
+    src_nope = tl.load(src_nope_ptr)
+    tl.store(dst_nope_ptr, src_nope)
+
+    rope_offs = tl.arange(0, rope_dim)
+    src_rope_ptr = k_rope_ptr + pid_loc * rope_stride0 + rope_offs[None, :]
+    dst_rope_ptr = k_head_ptr + nope_dim + rope_offs[None, :]
+    src_rope = tl.load(src_rope_ptr)
+    tl.store(dst_rope_ptr, src_rope)
+
+
+def concat_and_cast_mha_k_triton(
+    k: torch.Tensor,
+    k_nope: torch.Tensor,
+    k_rope: torch.Tensor,
+):
+    # The source data type will be implicitly converted to the target data type.
+    assert (
+        len(k.shape) == 3 and len(k_nope.shape) == 3 and len(k_rope.shape) == 3
+    ), f"shape should be 3d, but got {k.shape=}, {k_nope.shape=}, {k_rope.shape=}"
+    assert (
+        k.shape[0] == k_nope.shape[0] and k.shape[0] == k_rope.shape[0]
+    ), f"invalid shape, got {k.shape=}, {k_nope.shape=}, {k_rope.shape=}"
+    assert (
+        k.shape[1] == k_nope.shape[1] and 1 == k_rope.shape[1]
+    ), f"invalid shape, got {k.shape=}, {k_nope.shape=}, {k_rope.shape=}"
+    assert (
+        k.shape[-1] == k_nope.shape[-1] + k_rope.shape[-1]
+    ), f"invalid shape, got {k.shape=}, {k_nope.shape=}, {k_rope.shape=}"
+
+    nope_dim = k_nope.shape[-1]
+    rope_dim = k_rope.shape[-1]
+    grid = (k.shape[0],)
+    concat_and_cast_mha_k_kernel[grid](
+        k,
+        k_nope,
+        k_rope,
+        k.shape[1],
+        k.stride(0),
+        k.stride(1),
+        k_nope.stride(0),
+        k_nope.stride(1),
+        k_rope.stride(0),
+        nope_dim,
+        rope_dim,
+    )
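A quick way to read the new kernel: each program handles one token, copies the per-head nope part, and broadcasts the single rope head across all heads while casting to k's dtype. A hypothetical sanity check one could run on a CUDA device (shapes follow the asserts in concat_and_cast_mha_k_triton; not part of the diff):

import torch
from sglang.srt.layers.attention.utils import concat_and_cast_mha_k_triton

T, H, NOPE, ROPE = 8, 128, 128, 64
k_nope = torch.randn(T, H, NOPE, device="cuda", dtype=torch.bfloat16)
k_rope = torch.randn(T, 1, ROPE, device="cuda", dtype=torch.bfloat16)
k = torch.empty(T, H, NOPE + ROPE, device="cuda", dtype=torch.bfloat16)

concat_and_cast_mha_k_triton(k, k_nope, k_rope)

# Reference: broadcast the single rope head across all H heads, then concatenate.
ref = torch.cat([k_nope, k_rope.expand(T, H, ROPE)], dim=-1)
assert torch.equal(k, ref)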
......
@@ -1213,6 +1213,65 @@ def set_mla_kv_buffer_triton(
    )


+@triton.jit
+def get_mla_kv_buffer_kernel(
+    kv_buffer_ptr,
+    cache_k_nope_ptr,
+    cache_k_rope_ptr,
+    loc_ptr,
+    buffer_stride: tl.constexpr,
+    nope_stride: tl.constexpr,
+    rope_stride: tl.constexpr,
+    nope_dim: tl.constexpr,
+    rope_dim: tl.constexpr,
+):
+    pid_loc = tl.program_id(0)
+    loc = tl.load(loc_ptr + pid_loc)
+    loc_src_ptr = kv_buffer_ptr + loc * buffer_stride
+
+    nope_offs = tl.arange(0, nope_dim)
+    nope_src_ptr = loc_src_ptr + nope_offs
+    nope_src = tl.load(nope_src_ptr)
+    tl.store(
+        cache_k_nope_ptr + pid_loc * nope_stride + nope_offs,
+        nope_src,
+    )
+
+    rope_offs = tl.arange(0, rope_dim)
+    rope_src_ptr = loc_src_ptr + nope_dim + rope_offs
+    rope_src = tl.load(rope_src_ptr)
+    tl.store(
+        cache_k_rope_ptr + pid_loc * rope_stride + rope_offs,
+        rope_src,
+    )
+
+
+def get_mla_kv_buffer_triton(
+    kv_buffer: torch.Tensor,
+    loc: torch.Tensor,
+    cache_k_nope: torch.Tensor,
+    cache_k_rope: torch.Tensor,
+):
+    # The source data type will be implicitly converted to the target data type.
+    nope_dim = cache_k_nope.shape[-1]  # 512
+    rope_dim = cache_k_rope.shape[-1]  # 64
+    n_loc = loc.numel()
+    grid = (n_loc,)
+
+    get_mla_kv_buffer_kernel[grid](
+        kv_buffer,
+        cache_k_nope,
+        cache_k_rope,
+        loc,
+        kv_buffer.stride(0),
+        cache_k_nope.stride(0),
+        cache_k_rope.stride(0),
+        nope_dim,
+        rope_dim,
+    )


 class MLATokenToKVPool(KVCache):
     def __init__(
         self,
@@ -1363,6 +1422,29 @@ class MLATokenToKVPool(KVCache):
            cache_k_rope,
        )

+    def get_mla_kv_buffer(
+        self,
+        layer: RadixAttention,
+        loc: torch.Tensor,
+        dst_dtype: Optional[torch.dtype] = None,
+    ):
+        # Get k nope and k rope from the kv buffer, and optionally cast them to dst_dtype.
+        layer_id = layer.layer_id
+        kv_buffer = self.get_key_buffer(layer_id)
+        dst_dtype = dst_dtype or self.dtype
+
+        cache_k_nope = torch.empty(
+            (loc.shape[0], 1, self.kv_lora_rank),
+            dtype=dst_dtype,
+            device=kv_buffer.device,
+        )
+        cache_k_rope = torch.empty(
+            (loc.shape[0], 1, self.qk_rope_head_dim),
+            dtype=dst_dtype,
+            device=kv_buffer.device,
+        )
+        get_mla_kv_buffer_triton(kv_buffer, loc, cache_k_nope, cache_k_rope)
+        return cache_k_nope, cache_k_rope
+
    def get_cpu_copy(self, indices):
        torch.cuda.synchronize()
        kv_cache_cpu = []
......
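Putting the two memory-pool additions above together: get_mla_kv_buffer_triton gathers the selected rows of the latent KV buffer and splits them into the nope (kv_lora_rank) and rope parts in one pass, and MLATokenToKVPool.get_mla_kv_buffer wraps it with freshly allocated outputs. A hypothetical CUDA-only check with DeepSeek-like sizes (512/64), not part of the diff, assuming get_mla_kv_buffer_triton from the hunk above is in scope:

import torch

N, NOPE, ROPE = 1024, 512, 64
kv_buffer = torch.randn(N, 1, NOPE + ROPE, device="cuda", dtype=torch.bfloat16)
loc = torch.randint(0, N, (16,), device="cuda")
k_nope = torch.empty(16, 1, NOPE, device="cuda", dtype=torch.bfloat16)
k_rope = torch.empty(16, 1, ROPE, device="cuda", dtype=torch.bfloat16)

get_mla_kv_buffer_triton(kv_buffer, loc, k_nope, k_rope)  # defined in the hunk above

gathered = kv_buffer[loc]  # [16, 1, 576]
assert torch.equal(k_nope, gathered[..., :NOPE])
assert torch.equal(k_rope, gathered[..., NOPE:])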
@@ -39,6 +39,7 @@ import triton
 import triton.language as tl

 from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
+from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.dp_attention import (
     DpPaddingMode,
     get_attention_dp_rank,
@@ -250,6 +251,8 @@ class ForwardBatch:
    # For MLA chunked prefix cache used in chunked prefill
    # Tell attention backend whether lse needs to be returned
    mha_return_lse: Optional[bool] = None
+    mha_one_shot_kv_indices: Optional[torch.Tensor] = None
+    mha_one_shot: Optional[bool] = None

    # For multimodal
    mm_inputs: Optional[List[MultimodalInputs]] = None
@@ -863,6 +866,10 @@ class ForwardBatch:
            self.token_to_kv_pool, MLATokenToKVPool
        ), "Currently chunked prefix cache can only be used by Deepseek models"

+        if not any(self.extend_prefix_lens_cpu):
+            self.num_prefix_chunks = 0
+            return
+
        if self.prefix_chunk_len is not None:
            # Chunked kv cache info already prepared by prior modules
            return
@@ -917,6 +924,34 @@ class ForwardBatch:
    def can_run_tbo(self):
        return self.tbo_split_seq_index is not None

+    def fetch_mha_one_shot_kv_indices(self):
+        if self.mha_one_shot_kv_indices is not None:
+            return self.mha_one_shot_kv_indices
+        batch_size = self.batch_size
+        paged_kernel_lens_sum = sum(self.seq_lens_cpu)
+        kv_indices = torch.empty(
+            paged_kernel_lens_sum,
+            dtype=torch.int32,
+            device=self.req_pool_indices.device,
+        )
+        kv_indptr = torch.zeros(
+            batch_size + 1,
+            dtype=torch.int32,
+            device=self.req_pool_indices.device,
+        )
+        kv_indptr[1:] = torch.cumsum(self.seq_lens, dim=0)
+        create_flashinfer_kv_indices_triton[(self.batch_size,)](
+            self.req_to_token_pool.req_to_token,
+            self.req_pool_indices,
+            self.seq_lens,
+            kv_indptr,
+            None,
+            kv_indices,
+            self.req_to_token_pool.req_to_token.shape[1],
+        )
+        self.mha_one_shot_kv_indices = kv_indices
+        return kv_indices
+

 def enable_num_token_non_padded(server_args):
     return get_moe_expert_parallel_world_size() > 1
......
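The new ForwardBatch.fetch_mha_one_shot_kv_indices flattens, per request, the KV-cache slot indices of the full sequence (prefix plus extend) into one index array, with kv_indptr marking request boundaries; the result is cached on the batch. A torch-only toy of the layout it produces (my own example, not using the Triton kernel):

import torch

req_to_token = torch.tensor([[10, 11, 12, 0],
                             [20, 21, 0, 0]], dtype=torch.int32)  # toy slot table
req_pool_indices = torch.tensor([0, 1])
seq_lens = torch.tensor([3, 2])  # full lengths: prefix + extend

kv_indptr = torch.zeros(len(seq_lens) + 1, dtype=torch.int32)
kv_indptr[1:] = torch.cumsum(seq_lens, dim=0)
kv_indices = torch.cat(
    [req_to_token[r, :l] for r, l in zip(req_pool_indices.tolist(), seq_lens.tolist())]
)

print(kv_indptr.tolist())   # [0, 3, 5]
print(kv_indices.tolist())  # [10, 11, 12, 20, 21]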
@@ -57,6 +57,7 @@ from sglang.srt.layers.attention.npu_ops.mla_preprocess import (
     is_mla_preprocess_enabled,
 )
 from sglang.srt.layers.attention.nsa.nsa_indexer import Indexer
+from sglang.srt.layers.attention.utils import concat_and_cast_mha_k_triton
 from sglang.srt.layers.communicator import (
     LayerCommunicator,
     LayerScatterModes,
@@ -241,6 +242,10 @@ class AttnForwardMethod(IntEnum):
    # This method can avoid OOM when prefix lengths are long.
    MHA_CHUNKED_KV = auto()

+    # Use multi-head attention, executing MHA over the prefix and extend kv in one shot
+    # when the total sequence length is below the threshold.
+    MHA_ONE_SHOT = auto()
+
    # Use MLA but with fused RoPE
    MLA_FUSED_ROPE = auto()
@@ -306,6 +311,14 @@ def _is_extend_without_speculative(forward_batch):
    )


+def _support_mha_one_shot(attn: DeepseekV2AttentionMLA, forward_batch, backend_name):
+    attn_supported = backend_name in ["fa3", "flashinfer", "flashmla"]
+    sum_seq_lens = (
+        sum(forward_batch.seq_lens_cpu) if forward_batch.seq_lens_cpu is not None else 0
+    )
+    return attn_supported and sum_seq_lens <= forward_batch.get_max_chunk_capacity()
+
+
def _handle_attention_backend(
    attn: DeepseekV2AttentionMLA, forward_batch, backend_name
):
@@ -325,6 +338,8 @@ def _handle_attention_backend(
            or sum_extend_prefix_lens == 0
        )
    ):
+        if _support_mha_one_shot(attn, forward_batch, backend_name):
+            return AttnForwardMethod.MHA_ONE_SHOT
        return AttnForwardMethod.MHA_CHUNKED_KV
    else:
        return _dispatch_mla_subtype(attn, forward_batch)
@@ -1062,6 +1077,7 @@ class DeepseekV2AttentionMLA(nn.Module):
        self.scaling = self.qk_head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings
+        self.kv_cache_dtype = get_global_server_args().kv_cache_dtype

        # NOTE modification to rope_scaling must be done early enough, b/c e.g. Indexer needs it
        if rope_scaling:
@@ -1359,6 +1375,10 @@ class DeepseekV2AttentionMLA(nn.Module):
            inner_state = self.forward_normal_chunked_kv_prepare(
                positions, hidden_states, forward_batch, zero_allocator
            )
+        elif attn_forward_method == AttnForwardMethod.MHA_ONE_SHOT:
+            inner_state = self.forward_normal_one_shot_prepare(
+                positions, hidden_states, forward_batch, zero_allocator
+            )
        elif attn_forward_method == AttnForwardMethod.MLA:
            if not self.is_mla_preprocess_enabled:
                inner_state = self.forward_absorb_prepare(
@@ -1410,6 +1430,8 @@ class DeepseekV2AttentionMLA(nn.Module):
            return self.forward_normal_core(*inner_state)
        elif attn_forward_method == AttnForwardMethod.MHA_CHUNKED_KV:
            return self.forward_normal_chunked_kv_core(*inner_state)
+        elif attn_forward_method == AttnForwardMethod.MHA_ONE_SHOT:
+            return self.forward_normal_one_shot_core(*inner_state)
        elif attn_forward_method == AttnForwardMethod.MLA:
            return self.forward_absorb_core(*inner_state)
        elif attn_forward_method == AttnForwardMethod.NPU_MLA_SPARSE:
@@ -1444,41 +1466,24 @@ class DeepseekV2AttentionMLA(nn.Module):
        kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        latent_cache = latent_cache.unsqueeze(1)
        kv_a = self.kv_a_layernorm(kv_a)
-        kv = self.kv_b_proj(kv_a)[0]
-        kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
-        k_nope = kv[..., : self.qk_nope_head_dim]
-        v = kv[..., self.qk_nope_head_dim :]
        k_pe = latent_cache[:, :, self.kv_lora_rank :]

        q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
        q[..., self.qk_nope_head_dim :] = q_pe
-        k = torch.empty_like(q)

-        # Temporary for DeepSeek V3/R1 only, but can generalize if needed
+        self._set_mla_kv_buffer(latent_cache, kv_a, k_pe, forward_batch)
        if (
-            _is_cuda
-            and (self.num_local_heads == 128)
-            and (self.qk_nope_head_dim == 128)
-            and (self.qk_rope_head_dim == 64)
+            forward_batch.mha_one_shot
+            and sum(forward_batch.extend_prefix_lens_cpu) != 0
        ):
-            concat_mla_k(k=k, k_nope=k_nope, k_rope=k_pe)
-        else:
-            k[..., : self.qk_nope_head_dim] = k_nope
-            k[..., self.qk_nope_head_dim :] = k_pe
-
-        if not _is_npu:
-            latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
-            latent_cache[:, :, self.kv_lora_rank :] = k_pe
-
-            # Save latent cache
-            forward_batch.token_to_kv_pool.set_kv_buffer(
-                self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
-            )
-        else:
-            # To reduce a time-costing split operation
-            forward_batch.token_to_kv_pool.set_kv_buffer(
-                self.attn_mha, forward_batch.out_cache_loc, kv_a.unsqueeze(1), k_pe
+            kv_a, k_pe = self._get_mla_kv_buffer(
+                forward_batch.fetch_mha_one_shot_kv_indices(), q.dtype, forward_batch
            )
+        kv = self.kv_b_proj(kv_a)[0]
+        kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
+        k_nope = kv[..., : self.qk_nope_head_dim]
+        v = kv[..., self.qk_nope_head_dim :]
+        k = self._concat_and_cast_mha_k(k_nope, k_pe, forward_batch)
        return q, k, v, forward_batch

    def forward_normal_core(self, q, k, v, forward_batch):
@@ -2288,20 +2293,11 @@ class DeepseekV2AttentionMLA(nn.Module):
        for i in range(forward_batch.num_prefix_chunks):
            forward_batch.set_prefix_chunk_idx(i)

+            kv_indices = forward_batch.prefix_chunk_kv_indices[i]
            # Fetch latent cache from memory pool with precomputed chunked kv indices
-            latent_cache_buf = forward_batch.token_to_kv_pool.get_key_buffer(
-                self.attn_mha.layer_id
-            )
-            latent_cache = (
-                latent_cache_buf[forward_batch.prefix_chunk_kv_indices[i]]
-                .contiguous()
-                .to(q.dtype)
-            )
-
-            kv_a_normed, k_pe = latent_cache.split(
-                [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
+            kv_a_normed, k_pe = self._get_mla_kv_buffer(
+                kv_indices, q.dtype, forward_batch
            )
-            kv_a_normed = kv_a_normed.squeeze(1).contiguous()
            kv = self.kv_b_proj(kv_a_normed)[0]
            kv = kv.view(
                -1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim
@@ -2376,6 +2372,107 @@ class DeepseekV2AttentionMLA(nn.Module):
        output, _ = self.o_proj(attn_output)
        return output

+    def forward_normal_one_shot_prepare(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        zero_allocator: BumpAllocator,
+    ):
+        forward_batch.mha_one_shot = True
+        return self.forward_normal_prepare(
+            positions, hidden_states, forward_batch, zero_allocator
+        )
+
+    def forward_normal_one_shot_core(self, q, k, v, forward_batch):
+        has_extend_prefix = any(forward_batch.extend_prefix_lens_cpu)
+        # Only initialize the info once
+        if has_extend_prefix and forward_batch.num_prefix_chunks is None:
+            forward_batch.num_prefix_chunks = 0
+            if hasattr(forward_batch.attn_backend, "init_mha_chunk_metadata"):
+                forward_batch.attn_backend.init_mha_chunk_metadata(forward_batch)
+        forward_batch.mha_return_lse = False
+        # Do MHA for the extend part without attending the prefix separately
+        forward_batch.set_attn_attend_prefix_cache(False)
+        return self.forward_normal_core(q, k, v, forward_batch)
+
+    def _set_mla_kv_buffer(
+        self,
+        latent_cache: torch.Tensor,
+        kv_a: torch.Tensor,
+        k_pe: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ):
+        if _is_cuda:
+            # Save latent cache
+            forward_batch.token_to_kv_pool.set_mla_kv_buffer(
+                self.attn_mha, forward_batch.out_cache_loc, kv_a.unsqueeze(1), k_pe
+            )
+        elif _is_npu:
+            # To reduce a time-costing split operation
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                self.attn_mha, forward_batch.out_cache_loc, kv_a.unsqueeze(1), k_pe
+            )
+        else:
+            latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
+            latent_cache[:, :, self.kv_lora_rank :] = k_pe
+
+            # Save latent cache
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
+            )
+
+    def _get_mla_kv_buffer(
+        self,
+        kv_indices: torch.Tensor,
+        dst_dtype: torch.dtype,
+        forward_batch: ForwardBatch,
+    ):
+        if _is_cuda:
+            kv_a, k_pe = forward_batch.token_to_kv_pool.get_mla_kv_buffer(
+                self.attn_mha, kv_indices, dst_dtype
+            )
+            kv_a = kv_a.squeeze(1)
+        else:
+            latent_cache_buf = forward_batch.token_to_kv_pool.get_key_buffer(
+                self.attn_mha.layer_id
+            )
+            latent_cache = latent_cache_buf[kv_indices].contiguous().to(dst_dtype)
+
+            kv_a, k_pe = latent_cache.split(
+                [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
+            )
+            kv_a = kv_a.squeeze(1).contiguous()
+        return kv_a, k_pe
+
+    def _concat_and_cast_mha_k(self, k_nope, k_pe, forward_batch):
+        # Temporary for DeepSeek V3/R1 only, but can generalize if needed
+        k_shape = (k_nope.shape[0], self.num_local_heads, self.qk_head_dim)
+        if (
+            _is_cuda
+            and (self.num_local_heads == 128)
+            and (self.qk_nope_head_dim == 128)
+            and (self.qk_rope_head_dim == 64)
+        ):
+            k = k_nope.new_empty(*k_shape)
+            concat_mla_k(k=k, k_nope=k_nope, k_rope=k_pe)
+        elif _is_cuda:
+            # fa3 MHA supports fp8 inputs
+            if (
+                self.current_attention_backend == "fa3"
+                and self.kv_cache_dtype != "auto"
+            ):
+                attn_dtype = forward_batch.token_to_kv_pool.dtype
+            else:
+                attn_dtype = k_nope.dtype
+            k = k_nope.new_empty(*k_shape, dtype=attn_dtype)
+            concat_and_cast_mha_k_triton(k, k_nope, k_pe)
+        else:
+            k = k_nope.new_empty(*k_shape)
+            k[..., : self.qk_nope_head_dim] = k_nope
+            k[..., self.qk_nope_head_dim :] = k_pe
+        return k
+
    @staticmethod
    def _get_q_b_proj_quant_config(quant_config):
        if get_bool_env_var("SGLANG_NVFP4_CKPT_FP8_GEMM_IN_ATTN"):
......