Unverified Commit a42736bb authored by Baizhou Zhang, committed by GitHub

Support MHA with chunked prefix cache for DeepSeek chunked prefill (#5113)

parent dd83e7e9
......@@ -195,3 +195,4 @@ Please consult the documentation below to learn more about the parameters you ma
* `triton_attention_num_kv_splits`: Use to adjust the number of KV splits in triton kernels. Default is 8.
* `enable_flashinfer_mla`: Use the attention backend with FlashInfer MLA wrapper for DeepSeek models. **This argument will be deprecated in the next release. Please use `--attention_backend flashinfer` instead to enable FlashInfer MLA.**
* `flashinfer_mla_disable_ragged`: Disable the use of the ragged prefill wrapper for the FlashInfer MLA attention backend. Only use it when FlashInfer is being used as the MLA backend.
* `disable_chunked_prefix_cache`: Disable the use of chunked prefix cache for DeepSeek models. Only use it when FA3 is the attention backend.
......@@ -92,13 +92,15 @@ Please refer to [the example](https://github.com/sgl-project/sglang/tree/main/be
- **CUDA Graph & Torch.compile**: Both MLA and Mixture of Experts (MoE) are compatible with CUDA Graph and Torch.compile, which reduces latency and accelerates decoding speed for small batch sizes.
- **Chunked Prefix Cache**: Chunked prefix cache optimization can increase throughput by cutting the prefix cache into chunks, processing them with multi-head attention, and merging their states. The improvement can be significant when doing chunked prefill on long sequences. Currently this optimization is only available for the FlashAttention3 backend (see the sketch at the end of this section).
Overall, with these optimizations, we have achieved up to **7x** acceleration in output throughput compared to the previous version.
<p align="center">
<img src="https://lmsys.org/images/blog/sglang_v0_3/deepseek_mla.svg" alt="Multi-head Latent Attention for DeepSeek Series Models">
</p>
**Usage**: MLA optimization is enabled by default, to disable, use `--disable-mla`.
**Usage**: MLA optimization is enabled by default. To disable it, use `--disable-mla`. To disable the chunked prefix cache feature for MLA, use `--disable-chunked-prefix-cache`.
**Reference**: Check [Blog](https://lmsys.org/blog/2024-09-04-sglang-v0-3/#deepseek-multi-head-latent-attention-mla-throughput-optimizations) and [Slides](https://github.com/sgl-project/sgl-learning-materials/blob/main/slides/lmsys_1st_meetup_deepseek_mla.pdf) for more details.
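The chunked prefix cache idea, as a minimal PyTorch-style sketch (the helper names `merge_states`, `attend_with_chunked_prefix`, and `attn_fn` below are illustrative, not the actual SGLang API): the prefix KV cache is split into fixed-size chunks, each chunk is attended with standard multi-head attention, and the per-chunk results are combined through their log-sum-exp (LSE) values so the merged output matches attending the full prefix at once.

```python
import torch

def merge_states(o_a, lse_a, o_b, lse_b):
    # Numerically stable merge of two partial attention results.
    # o_*: (tokens, heads, head_dim); lse_*: (tokens, heads)
    lse = torch.logaddexp(lse_a, lse_b)
    o = o_a * (lse_a - lse).exp().unsqueeze(-1) + o_b * (lse_b - lse).exp().unsqueeze(-1)
    return o, lse

def attend_with_chunked_prefix(q, extend_k, extend_v, prefix_kv_chunks, attn_fn):
    # Causal attention over the newly extended tokens only.
    o, lse = attn_fn(q, extend_k, extend_v, causal=True)
    # Attend each prefix chunk (non-causal) and fold it into the running state.
    for k_chunk, v_chunk in prefix_kv_chunks:
        o_c, lse_c = attn_fn(q, k_chunk, v_chunk, causal=False)
        o, lse = merge_states(o, lse, o_c, lse_c)
    return o
```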
......
......@@ -16,7 +16,7 @@ if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import ModelRunner
from sgl_kernel.flash_attn import flash_attn_with_kvcache
from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
@dataclass
......@@ -593,41 +593,87 @@ class FlashAttentionBackend(AttentionBackend):
k_descale=k_descale,
v_descale=v_descale,
)
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
else:
# Do absorbed multi-latent attention
kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
k_rope = kv_cache[:, :, layer.v_head_dim :]
c_kv = kv_cache[:, :, : layer.v_head_dim]
k_rope_cache = k_rope.view(
-1,
self.page_size,
layer.tp_k_head_num,
layer.head_dim - layer.v_head_dim,
)
c_kv_cache = c_kv.view(
-1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
)
q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
q_nope = q_all[:, :, : layer.v_head_dim]
q_rope = q_all[:, :, layer.v_head_dim :]
o = flash_attn_with_kvcache(
q=q_rope,
k_cache=k_rope_cache,
v_cache=c_kv_cache,
qv=q_nope,
page_table=page_table,
cache_seqlens=cache_seqlens,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
max_seqlen_q=max_seqlen_q,
softmax_scale=layer.scaling,
causal=True,
softcap=layer.logit_cap,
k_descale=k_descale,
v_descale=v_descale,
)
if (
not global_server_args_dict["disable_chunked_prefix_cache"]
and forward_batch.attn_attend_prefix_cache is not None
and not forward_batch.forward_mode.is_target_verify()
and not forward_batch.forward_mode.is_draft_extend()
):
# Do multi-head attention with chunked prefix cache
if forward_batch.attn_attend_prefix_cache:
# MHA for chunked prefix kv cache when running model with MLA
assert forward_batch.prefix_chunk_idx is not None
assert forward_batch.prefix_chunk_cu_seq_lens is not None
assert forward_batch.prefix_chunk_max_seq_lens is not None
chunk_idx = forward_batch.prefix_chunk_idx
assert chunk_idx >= 0
output, lse, *rest = flash_attn_varlen_func(
q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
k=k.view(-1, layer.tp_k_head_num, layer.head_dim),
v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim),
cu_seqlens_q=metadata.cu_seqlens_q,
cu_seqlens_k=forward_batch.prefix_chunk_cu_seq_lens[chunk_idx],
max_seqlen_q=metadata.max_seq_len_q,
max_seqlen_k=forward_batch.prefix_chunk_max_seq_lens[chunk_idx],
softmax_scale=layer.scaling,
causal=False,
return_softmax_lse=True,
)
else:
# MHA for extend part of sequence without attending prefix kv cache
output, lse, *rest = flash_attn_varlen_func(
q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
k=k.view(-1, layer.tp_k_head_num, layer.head_dim),
v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim),
cu_seqlens_q=metadata.cu_seqlens_q,
cu_seqlens_k=metadata.cu_seqlens_q,
max_seqlen_q=metadata.max_seq_len_q,
max_seqlen_k=metadata.max_seq_len_q,
softmax_scale=layer.scaling,
causal=True,
return_softmax_lse=True,
)
return output, lse
else:
# Do absorbed multi-latent attention
kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
k_rope = kv_cache[:, :, layer.v_head_dim :]
c_kv = kv_cache[:, :, : layer.v_head_dim]
k_rope_cache = k_rope.view(
-1,
self.page_size,
layer.tp_k_head_num,
layer.head_dim - layer.v_head_dim,
)
c_kv_cache = c_kv.view(
-1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
)
q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
q_nope = q_all[:, :, : layer.v_head_dim]
q_rope = q_all[:, :, layer.v_head_dim :]
o = flash_attn_with_kvcache(
q=q_rope,
k_cache=k_rope_cache,
v_cache=c_kv_cache,
qv=q_nope,
page_table=page_table,
cache_seqlens=cache_seqlens,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
max_seqlen_q=max_seqlen_q,
softmax_scale=layer.scaling,
causal=True,
softcap=layer.logit_cap,
k_descale=k_descale,
v_descale=v_descale,
)
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
def forward_decode(
self,
......
......@@ -83,6 +83,7 @@ global_server_args_dict = {
"chunked_prefill_size": ServerArgs.chunked_prefill_size,
"n_share_experts_fusion": ServerArgs.n_share_experts_fusion,
"disable_shared_experts_fusion": ServerArgs.disable_shared_experts_fusion,
"disable_chunked_prefix_cache": ServerArgs.disable_chunked_prefix_cache,
}
logger = logging.getLogger(__name__)
......
......@@ -181,6 +181,28 @@ class ForwardBatch:
extend_logprob_start_lens_cpu: Optional[List[int]] = None
extend_input_logprob_token_ids_gpu: Optional[torch.Tensor] = None
# For MLA chunked prefix cache used in chunked prefill
# Tell the attention backend whether the prefix kv cache needs to be attended to in the current pass
attn_attend_prefix_cache: Optional[bool] = None
# Number of prefix cache chunks
num_prefix_chunks: Optional[int] = None
# Index of current chunk, used by attention backend
prefix_chunk_idx: Optional[int] = None
# Maximum number of tokens in each chunk per sequence. Computed from maximum chunk capacity
prefix_chunk_len: Optional[int] = None
# Start positions of prefix cache for each chunk, (num_prefix_chunks, batch_size)
prefix_chunk_starts: Optional[torch.Tensor] = None
# Lengths of prefix cache for each chunk, (num_prefix_chunks, batch_size)
prefix_chunk_seq_lens: Optional[torch.Tensor] = None
# Accumulated lengths of prefix cache for each chunk, (num_prefix_chunks, batch_size + 1)
prefix_chunk_cu_seq_lens: Optional[torch.Tensor] = None
# Max lengths of prefix cache for each chunk, (num_prefix_chunks,)
prefix_chunk_max_seq_lens: Optional[List[int]] = None
# Number of tokens in each prefix cache chunk, (num_prefix_chunks,)
prefix_chunk_num_tokens: Optional[List[int]] = None
# KV Indices for each chunk
prefix_chunk_kv_indices: Optional[List[torch.Tensor]] = None
# For multimodal
mm_inputs: Optional[List[MultimodalInputs]] = None
......@@ -484,6 +506,128 @@ class ForwardBatch:
)
self.mrope_positions = self.mrope_positions.to(torch.int64)
def get_max_chunk_capacity(self):
# Maximum number of tokens in each chunk
# TODO: Should be changed to a better value, maybe passed through server args
return 128 * 1024
def set_prefix_chunk_idx(self, idx: int):
self.prefix_chunk_idx = idx
def set_attn_attend_prefix_cache(self, attn_attend_prefix_cache: bool):
self.attn_attend_prefix_cache = attn_attend_prefix_cache
def prepare_chunked_kv_indices(self, device: torch.device):
self.prefix_chunk_kv_indices = []
for idx in range(self.num_prefix_chunks):
chunk_starts = self.prefix_chunk_starts[idx]
chunk_seq_lens = self.prefix_chunk_seq_lens[idx]
chunk_cu_seq_lens = self.prefix_chunk_cu_seq_lens[idx]
num_chunk_tokens = self.prefix_chunk_num_tokens[idx]
chunk_kv_indices = torch.empty(
num_chunk_tokens, dtype=torch.int32, device=device
)
create_chunked_prefix_cache_kv_indices[(self.batch_size,)](
self.req_to_token_pool.req_to_token,
self.req_pool_indices,
chunk_starts,
chunk_seq_lens,
chunk_cu_seq_lens,
chunk_kv_indices,
self.req_to_token_pool.req_to_token.shape[1],
)
self.prefix_chunk_kv_indices.append(chunk_kv_indices)
# Here we assume the length of each chunk is equal
# For example, if we have 4 sequences with prefix length [256, 512, 768, 1024], prefix_chunk_len = 256
# num_prefix_chunks = cdiv(1024, 256) = 4
# prefix_chunk_starts = [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512], [768, 768, 768, 768]]
# prefix_chunk_ends = [[256, 256, 256, 256], [256, 512, 512, 512], [256, 512, 768, 768], [256, 512, 768, 1024]]
# prefix_chunk_seq_lens = [[256, 256, 256, 256], [0, 256, 256, 256], [0, 0, 256, 256], [0, 0, 0, 256]]
# TODO: Implement a better way to allocate chunk lengths that uses memory spaces more efficiently.
def get_prefix_chunk_seq_lens(
self, prefix_lens: torch.Tensor, num_prefix_chunks: int, prefix_chunk_len: int
):
device = prefix_lens.device
prefix_chunk_starts = (
torch.arange(num_prefix_chunks, device=device, dtype=torch.int32)
.unsqueeze(1)
.expand(-1, self.batch_size)
* prefix_chunk_len
)
prefix_chunk_ends = torch.min(
prefix_lens.unsqueeze(0),
prefix_chunk_starts + prefix_chunk_len,
).to(torch.int32)
prefix_chunk_seq_lens = (
(prefix_chunk_ends - prefix_chunk_starts).clamp(min=0).to(torch.int32)
)
return prefix_chunk_starts, prefix_chunk_seq_lens
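# Quick check against the example in the comment above (a hypothetical standalone
# call; in the real code this runs on a ForwardBatch with batch_size = 4):
#   prefix_lens = torch.tensor([256, 512, 768, 1024])
#   starts, seq_lens = self.get_prefix_chunk_seq_lens(prefix_lens, 4, 256)
#   starts[1]   -> [256, 256, 256, 256]
#   seq_lens[1] -> [  0, 256, 256, 256]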
# Called before each attention module if using chunked kv cache for prefill
# Some of the code is adapted from https://github.com/vllm-project/vllm/blob/main/vllm/v1/attention/backends/mla/common.py
def prepare_chunked_prefix_cache_info(self, device: torch.device):
from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool
assert isinstance(
self.token_to_kv_pool, MLATokenToKVPool
), "Currently chunked prefix cache can only be used by Deepseek models"
if self.prefix_chunk_len is not None:
# Chunked kv cache info already prepared by prior modules
return
self.prefix_chunk_idx = -1
# chunk_capacity is the maximum number of tokens in each chunk
chunk_capacity = self.get_max_chunk_capacity()
self.prefix_chunk_len = chunk_capacity // self.batch_size
self.num_prefix_chunks = (
max(self.extend_prefix_lens_cpu) + self.prefix_chunk_len - 1
) // self.prefix_chunk_len
# Here we compute chunk lens twice to avoid stream sync, once on gpu and once on cpu.
prefix_chunk_starts_cuda, prefix_chunk_seq_lens_cuda = (
self.get_prefix_chunk_seq_lens(
self.extend_prefix_lens,
self.num_prefix_chunks,
self.prefix_chunk_len,
)
)
_, prefix_chunk_seq_lens_cpu = self.get_prefix_chunk_seq_lens(
torch.tensor(self.extend_prefix_lens_cpu),
self.num_prefix_chunks,
self.prefix_chunk_len,
)
self.prefix_chunk_starts = prefix_chunk_starts_cuda
self.prefix_chunk_seq_lens = prefix_chunk_seq_lens_cuda
# Metadata for attention backend
self.prefix_chunk_cu_seq_lens = torch.zeros(
self.num_prefix_chunks,
self.batch_size + 1,
device=device,
dtype=torch.int32,
)
self.prefix_chunk_cu_seq_lens[:, 1:] = prefix_chunk_seq_lens_cuda.cumsum(
dim=1
).to(torch.int32)
self.prefix_chunk_max_seq_lens = prefix_chunk_seq_lens_cpu.max(
dim=1
).values.tolist()
self.prefix_chunk_num_tokens = prefix_chunk_seq_lens_cpu.sum(dim=1).tolist()
assert max(self.prefix_chunk_num_tokens) <= self.get_max_chunk_capacity()
# Precompute the kv indices for each chunk
self.prepare_chunked_kv_indices(device)
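# Worked example of the chunking arithmetic above (illustrative numbers only):
# with batch_size = 4, extend_prefix_lens_cpu = [256, 512, 768, 1024] and the
# default chunk capacity of 128 * 1024 tokens:
#   prefix_chunk_len  = 131072 // 4 = 32768
#   num_prefix_chunks = ceil(1024 / 32768) = 1
# Multiple chunks only appear when max(prefix_lens) exceeds capacity // batch_size.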
def compute_position_triton(
extend_prefix_lens: torch.Tensor, extend_seq_lens: torch.Tensor, extend_seq_lens_sum
......@@ -561,3 +705,40 @@ def compute_position_torch(
@torch.compile(dynamic=True, backend=get_compiler_backend())
def clamp_position(seq_lens):
return torch.clamp((seq_lens - 1), min=0).to(torch.int64)
@triton.jit
def create_chunked_prefix_cache_kv_indices(
req_to_token_ptr, # (max_batch, max_context_len,)
req_pool_indices_ptr, # (batch_size,)
chunk_start_idx_ptr, # (batch_size,)
chunk_seq_lens_ptr, # (batch_size,)
chunk_cu_seq_lens_ptr, # (batch_size + 1,)
chunk_kv_indices_ptr, # (num_chunk_tokens,)
req_to_token_ptr_stride: tl.constexpr,
):
BLOCK_SIZE: tl.constexpr = 512
pid = tl.program_id(axis=0)
# Find the req pool index, which maps this batch element to its row in req_to_token
req_pool_index = tl.load(req_pool_indices_ptr + pid)
chunk_kv_indices_offset = tl.load(chunk_cu_seq_lens_ptr + pid)
# get the token positions of current chunk
chunk_start_pos = tl.load(chunk_start_idx_ptr + pid).to(tl.int32)
chunk_seq_len = tl.load(chunk_seq_lens_ptr + pid).to(tl.int32)
num_loop = tl.cdiv(chunk_seq_len, BLOCK_SIZE)
for i in range(num_loop):
offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
mask = offset < chunk_seq_len
data = tl.load(
req_to_token_ptr
+ req_pool_index * req_to_token_ptr_stride
+ chunk_start_pos
+ offset,
mask=mask,
)
tl.store(
chunk_kv_indices_ptr + chunk_kv_indices_offset + offset, data, mask=mask
)
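# For reference, a pure-PyTorch sketch (not used in the actual code path) of what
# the kernel above computes for each batch element pid:
#   req    = req_pool_indices[pid]
#   start  = chunk_start_idx[pid]
#   length = chunk_seq_lens[pid]
#   offset = chunk_cu_seq_lens[pid]
#   chunk_kv_indices[offset : offset + length] = req_to_token[req, start : start + length]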
......@@ -167,6 +167,7 @@ class ModelRunner:
"debug_tensor_dump_inject": server_args.debug_tensor_dump_inject,
"n_share_experts_fusion": server_args.n_share_experts_fusion,
"disable_shared_experts_fusion": server_args.disable_shared_experts_fusion,
"disable_chunked_prefix_cache": server_args.disable_chunked_prefix_cache,
"use_mla_backend": self.use_mla_backend,
}
)
......@@ -318,6 +319,16 @@ class ModelRunner:
if server_args.enable_deepep_moe:
logger.info(f"DeepEP is turned on. DeepEP mode: {server_args.deepep_mode}")
if not self.use_mla_backend:
logger.info("Disable chunked prefix cache for non-MLA backend.")
server_args.disable_chunked_prefix_cache = True
elif self.page_size > 1:
logger.info("Disable chunked prefix cache when page size > 1.")
server_args.disable_chunked_prefix_cache = True
if not server_args.disable_chunked_prefix_cache:
logger.info("Chunked prefix cache is turned on.")
def init_torch_distributed(self):
logger.info("Init torch distributed begin.")
......
......@@ -18,6 +18,7 @@
import logging
import os
from enum import IntEnum, auto
from typing import Any, Dict, Iterable, Optional, Tuple
import torch
......@@ -78,7 +79,7 @@ _is_hip = is_hip()
_is_cuda = is_cuda()
if _is_cuda:
from sgl_kernel import awq_dequantize, bmm_fp8
from sgl_kernel import awq_dequantize, bmm_fp8, merge_state_v2
from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
else:
......@@ -94,6 +95,19 @@ expert_distribution_recorder = ExpertDistributionRecorder()
logger = logging.getLogger(__name__)
class AttnForwardMethod(IntEnum):
# Use multi-head attention
MHA = auto()
# Use absorbed multi-latent attention
MLA = auto()
# Use multi-head attention, but with KV cache chunked.
# This method can avoid OOM when prefix lengths are long.
MHA_CHUNKED_KV = auto()
class DeepseekV2MLP(nn.Module):
def __init__(
self,
......@@ -694,30 +708,54 @@ class DeepseekV2AttentionMLA(nn.Module):
self.flashinfer_mla_disable_ragged = global_server_args_dict[
"flashinfer_mla_disable_ragged"
]
self.disable_chunked_prefix_cache = global_server_args_dict[
"disable_chunked_prefix_cache"
]
self.attention_backend = global_server_args_dict["attention_backend"]
self.rocm_fused_decode_mla = os.getenv("SGLANG_ROCM_FUSED_DECODE_MLA") == "1"
def no_absorb(self, forward_batch: ForwardBatch) -> bool:
# TODO: Design a finer way to determine the threshold
self.chunked_prefix_cache_threshold = 8192
def dispatch_attn_forward_method(
self, forward_batch: ForwardBatch
) -> AttnForwardMethod:
if self.attention_backend == "flashinfer":
# Flashinfer MLA: Do not absorb when enabling ragged prefill
return (
if (
not self.flashinfer_mla_disable_ragged
and forward_batch.forward_mode.is_extend()
and not forward_batch.forward_mode.is_target_verify()
and not forward_batch.forward_mode.is_draft_extend()
and sum(forward_batch.extend_prefix_lens_cpu) == 0
)
):
return AttnForwardMethod.MHA
else:
return AttnForwardMethod.MLA
elif self.attention_backend == "fa3":
# Flash Attention: Keep absorbing for all extend/decode
return False
# Flash Attention: Use MHA with chunked KV cache when prefilling on long sequences.
if (
forward_batch.forward_mode.is_extend()
and not self.disable_chunked_prefix_cache
and not forward_batch.forward_mode.is_target_verify()
and not forward_batch.forward_mode.is_draft_extend()
and sum(forward_batch.extend_prefix_lens_cpu)
>= self.chunked_prefix_cache_threshold
):
return AttnForwardMethod.MHA_CHUNKED_KV
else:
return AttnForwardMethod.MLA
else:
# Triton: Use normal computation for prefill and use weight absorption for extend/decode
return (
if (
forward_batch.forward_mode.is_extend()
and not forward_batch.forward_mode.is_target_verify()
and not forward_batch.forward_mode.is_draft_extend()
and sum(forward_batch.extend_prefix_lens_cpu) == 0
)
):
return AttnForwardMethod.MHA
else:
return AttnForwardMethod.MLA
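# Dispatch summary of the method above, for reference:
#   flashinfer: MHA for ragged extend with no prefix and no speculative decoding, otherwise MLA
#   fa3:        MHA_CHUNKED_KV for extend when the total prefix length >= threshold
#               (unless chunked prefix cache is disabled or speculative decoding is active), otherwise MLA
#   triton:     MHA for extend with no prefix and no speculative decoding, otherwise MLA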
def forward(
self,
......@@ -731,8 +769,14 @@ class DeepseekV2AttentionMLA(nn.Module):
), "short-circuiting allreduce will lead to hangs"
return hidden_states
if self.no_absorb(forward_batch):
attn_forward_method = self.dispatch_attn_forward_method(forward_batch)
if attn_forward_method == AttnForwardMethod.MHA:
return self.forward_normal(positions, hidden_states, forward_batch)
elif attn_forward_method == AttnForwardMethod.MHA_CHUNKED_KV:
return self.forward_normal_chunked_kv(
positions, hidden_states, forward_batch
)
else:
if _is_hip:
if (
......@@ -1007,6 +1051,127 @@ class DeepseekV2AttentionMLA(nn.Module):
return output
def _chunked_prefix_attn_mha(
self,
q: torch.Tensor,
accum_output: torch.Tensor,
accum_lse: torch.Tensor,
forward_batch: ForwardBatch,
) -> torch.Tensor:
assert forward_batch.num_prefix_chunks is not None
for i in range(forward_batch.num_prefix_chunks):
forward_batch.set_prefix_chunk_idx(i)
# Fetch latent cache from memory pool with precomputed chunked kv indices
latent_cache_buf = forward_batch.token_to_kv_pool.get_key_buffer(
self.attn_mha.layer_id
)
latent_cache = latent_cache_buf[
forward_batch.prefix_chunk_kv_indices[i]
].contiguous()
kv_a_normed, k_pe = latent_cache.split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
)
kv_a_normed = kv_a_normed.squeeze(1).contiguous()
kv = self.kv_b_proj(kv_a_normed)[0]
kv = kv.view(
-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim
)
v = kv[..., self.qk_nope_head_dim :]
k_nope = kv[..., : self.qk_nope_head_dim]
k = torch.empty(
(
k_nope.shape[0],
self.num_local_heads,
self.qk_nope_head_dim + self.qk_rope_head_dim,
),
dtype=v.dtype,
device=v.device,
)
k[..., : self.qk_nope_head_dim] = k_nope
k[..., self.qk_nope_head_dim :] = k_pe
output, lse = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
lse = torch.transpose(lse, 0, 1).contiguous()
tmp_output = torch.empty_like(accum_output)
tmp_lse = torch.empty_like(accum_lse)
merge_state_v2(output, lse, accum_output, accum_lse, tmp_output, tmp_lse)
accum_output, accum_lse = tmp_output, tmp_lse
return accum_output
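# The merge performed by merge_state_v2 above is mathematically equivalent to
# the following (a reference sketch; the kernel fuses this on the GPU):
#   lse_new = logaddexp(accum_lse, lse)
#   out_new = accum_output * exp(accum_lse - lse_new)[..., None] + output * exp(lse - lse_new)[..., None]
# so attending the prefix chunk-by-chunk reproduces attending the whole prefix at once.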
def forward_normal_chunked_kv(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
) -> torch.Tensor:
# In normal MHA, the k and v tensors become overly large when the prefix length is long.
# To avoid this, we split the kv cache into chunks and process them one after another.
# Since MHA is compute-friendly, the for loop introduced here does not add significant overhead.
# The top comments in https://github.com/vllm-project/vllm/blob/main/vllm/v1/attention/backends/mla/common.py
# are helpful for understanding the purpose of this function.
# First, do a normal MHA forward pass to get the output for the extended part
if self.q_lora_rank is not None:
q = self.q_a_proj(hidden_states)[0]
q = self.q_a_layernorm(q)
q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
else:
q = self.q_proj(hidden_states)[0].view(
-1, self.num_local_heads, self.qk_head_dim
)
_, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
latent_cache = latent_cache.unsqueeze(1)
kv_a = self.kv_a_layernorm(kv_a.contiguous())
kv = self.kv_b_proj(kv_a)[0]
kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
k_nope = kv[..., : self.qk_nope_head_dim]
v = kv[..., self.qk_nope_head_dim :]
k_pe = latent_cache[:, :, self.kv_lora_rank :]
q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
q[..., self.qk_nope_head_dim :] = q_pe
k = torch.empty_like(q)
k[..., : self.qk_nope_head_dim] = k_nope
k[..., self.qk_nope_head_dim :] = k_pe
latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
latent_cache[:, :, self.kv_lora_rank :] = k_pe
# Save latent cache
forward_batch.token_to_kv_pool.set_kv_buffer(
self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
)
# Do MHA for the extended part without the prefix
forward_batch.set_attn_attend_prefix_cache(False)
attn_output, lse = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
lse = torch.transpose(lse, 0, 1).contiguous()
# Do MHA with chunked prefix cache if any sequence has a prefix
if any(forward_batch.extend_prefix_lens_cpu):
# Only initialize the info once
if forward_batch.num_prefix_chunks is None:
forward_batch.prepare_chunked_prefix_cache_info(q.device)
forward_batch.set_attn_attend_prefix_cache(True)
attn_output = self._chunked_prefix_attn_mha(
q=q,
accum_output=attn_output,
accum_lse=lse,
forward_batch=forward_batch,
)
attn_output = attn_output.reshape(-1, self.num_local_heads * self.v_head_dim)
output, _ = self.o_proj(attn_output)
return output
class DeepseekV2DecoderLayer(nn.Module):
......
......@@ -186,6 +186,7 @@ class ServerArgs:
warmups: Optional[str] = None
n_share_experts_fusion: int = 0
disable_shared_experts_fusion: bool = False
disable_chunked_prefix_cache: bool = False
# Debug tensor dumps
debug_tensor_dump_output_folder: Optional[str] = None
......@@ -1130,6 +1131,11 @@ class ServerArgs:
action="store_true",
help="Disable shared experts fusion by setting n_share_experts_fusion to 0.",
)
parser.add_argument(
"--disable-chunked-prefix-cache",
action="store_true",
help="Disable chunked prefix cache feature for deepseek, which should save overhead for short sequences.",
)
# Server warmups
parser.add_argument(
......
import unittest
import torch
from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.test.test_utils import CustomTestCase
TEST_CASES = [
# Sequence with same prefix lens
{
"batch_size": 3,
"prefix_lens": [64, 64, 64],
"max_chunk_capacity": 48,
"prefix_chunk_len": 16,
"num_prefix_chunks": 4,
"prefix_chunk_starts": torch.tensor(
[
[0, 0, 0],
[16, 16, 16],
[32, 32, 32],
[48, 48, 48],
],
dtype=torch.int32,
),
"prefix_chunk_seq_lens": torch.tensor(
[
[16, 16, 16],
[16, 16, 16],
[16, 16, 16],
[16, 16, 16],
],
dtype=torch.int32,
),
},
# Sequence with different prefix lens
{
"batch_size": 4,
"prefix_lens": [16, 32, 48, 64],
"max_chunk_capacity": 64,
"prefix_chunk_len": 16,
"num_prefix_chunks": 4,
"prefix_chunk_starts": torch.tensor(
[
[0, 0, 0, 0],
[16, 16, 16, 16],
[32, 32, 32, 32],
[48, 48, 48, 48],
],
dtype=torch.int32,
),
"prefix_chunk_seq_lens": torch.tensor(
[
[16, 16, 16, 16],
[0, 16, 16, 16],
[0, 0, 16, 16],
[0, 0, 0, 16],
],
dtype=torch.int32,
),
},
# Sequence with irregular shapes
{
"batch_size": 2,
"prefix_lens": [1, 64],
"max_chunk_capacity": 31,
"prefix_chunk_len": 15,
"num_prefix_chunks": 5,
"prefix_chunk_starts": torch.tensor(
[
[0, 0],
[15, 15],
[30, 30],
[45, 45],
[60, 60],
],
dtype=torch.int32,
),
"prefix_chunk_seq_lens": torch.tensor(
[
[1, 15],
[0, 15],
[0, 15],
[0, 15],
[0, 4],
],
dtype=torch.int32,
),
},
]
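# The expected values above follow the same formulas as
# ForwardBatch.prepare_chunked_prefix_cache_info:
#   prefix_chunk_len  = max_chunk_capacity // batch_size
#   num_prefix_chunks = ceil(max(prefix_lens) / prefix_chunk_len)
# e.g. for the irregular case: 31 // 2 = 15 and ceil(64 / 15) = 5.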
class MockForwardBatch(ForwardBatch):
def __init__(self, max_chunk_capacity: int, *args, **kwargs):
super().__init__(*args, **kwargs)
self.max_chunk_capacity = max_chunk_capacity
def get_max_chunk_capacity(self):
return self.max_chunk_capacity
class MockReqToTokenPool:
def __init__(self, batch_size, seq_len, device):
self.req_to_token = (
torch.arange(batch_size * seq_len, device=device)
.reshape(batch_size, seq_len)
.to(torch.int32)
)
# Test correctness of triton kernel for computing kv indices
def check_kv_indices(forward_batch):
for i in range(forward_batch.num_prefix_chunks):
computed_kv_indices = forward_batch.prefix_chunk_kv_indices[i]
req_to_token = forward_batch.req_to_token_pool.req_to_token[
: forward_batch.batch_size, :
]
ref_kv_indices = torch.empty(
forward_batch.prefix_chunk_num_tokens[i],
dtype=torch.int32,
device=computed_kv_indices.device,
)
running_ptr = 0
for j in range(forward_batch.batch_size):
seq_start = forward_batch.prefix_chunk_starts[i, j].item()
seq_len = forward_batch.prefix_chunk_seq_lens[i, j].item()
ref_kv_indices[running_ptr : running_ptr + seq_len].copy_(
req_to_token[j, seq_start : seq_start + seq_len]
)
running_ptr += seq_len
assert torch.allclose(computed_kv_indices, ref_kv_indices)
@unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA")
class TestPrefixChunkInfo(CustomTestCase):
def setUp(self):
# Common test parameters
self.num_local_heads = 128
self.kv_lora_rank = 512
self.qk_rope_head_dim = 64
self.device = torch.device("cuda")
self.dtype = torch.bfloat16
self.extend_len = 64
self.max_bs = 4
self.max_seq_len = 128
# req_to_token_pool
self.req_to_token_pool = MockReqToTokenPool(
self.max_bs,
self.max_seq_len,
self.device,
)
# token_to_kv_pool
self.token_to_kv_pool = MLATokenToKVPool(
size=self.max_bs * self.max_seq_len,
page_size=1, # only consider page=1 for unit test
dtype=self.dtype,
kv_lora_rank=self.kv_lora_rank,
qk_rope_head_dim=self.qk_rope_head_dim,
layer_num=1, # only consider layer=1 for unit test
device=self.device,
enable_memory_saver=False,
)
def test_prefix_chunk_info(self):
"""Test the standard extend operation."""
for test_case in TEST_CASES:
print(
f"Test case with batch_size={test_case['batch_size']}, prefix_lens={test_case['prefix_lens']}, max_chunk_capacity={test_case['max_chunk_capacity']}"
)
batch_size = test_case["batch_size"]
prefix_lens_cpu = test_case["prefix_lens"]
assert len(prefix_lens_cpu) == batch_size
prefix_lens = torch.tensor(prefix_lens_cpu, device=self.device)
max_chunk_capacity = test_case["max_chunk_capacity"]
seq_lens_cpu = [
self.extend_len + prefix_lens_cpu[i] for i in range(batch_size)
]
seq_lens = torch.tensor(seq_lens_cpu, device=self.device)
# Create forward batch
# input_ids and out_cache_loc are dummy tensors in this test
forward_batch = MockForwardBatch(
max_chunk_capacity=max_chunk_capacity,
batch_size=batch_size,
input_ids=torch.randint(
0, 100, (batch_size, self.extend_len), device=self.device
),
out_cache_loc=torch.arange(
self.max_bs * self.max_seq_len - batch_size * self.extend_len,
self.max_bs * self.max_seq_len,
device=self.device,
),
seq_lens_sum=sum(seq_lens_cpu),
forward_mode=ForwardMode.EXTEND,
req_pool_indices=torch.arange(batch_size, device=self.device),
seq_lens=seq_lens,
seq_lens_cpu=seq_lens_cpu,
extend_prefix_lens=prefix_lens,
extend_prefix_lens_cpu=prefix_lens_cpu,
)
forward_batch.req_to_token_pool = self.req_to_token_pool
forward_batch.token_to_kv_pool = self.token_to_kv_pool
forward_batch.prepare_chunked_prefix_cache_info(self.device)
assert forward_batch.get_max_chunk_capacity() == max_chunk_capacity
assert forward_batch.prefix_chunk_len == test_case["prefix_chunk_len"]
assert forward_batch.num_prefix_chunks == test_case["num_prefix_chunks"]
assert torch.allclose(
forward_batch.prefix_chunk_starts,
test_case["prefix_chunk_starts"].to(self.device),
)
assert torch.allclose(
forward_batch.prefix_chunk_seq_lens,
test_case["prefix_chunk_seq_lens"].to(self.device),
)
check_kv_indices(forward_batch)
if __name__ == "__main__":
unittest.main()
......@@ -7,7 +7,6 @@ import torch
from sglang.srt.utils import get_device_sm, kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.test_utils import (
DEFAULT_MLA_MODEL_NAME_FOR_TEST,
DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
......@@ -19,7 +18,7 @@ Integration test for python/sglang/srt/layers/attention/flashattention_backend.p
"""
# Change to your own model if testing model is not public.
MODEL_USED_FOR_TEST = DEFAULT_MODEL_NAME_FOR_TEST
MODEL_USED_FOR_TEST_MLA = DEFAULT_MLA_MODEL_NAME_FOR_TEST
MODEL_USED_FOR_TEST_MLA = "lmsys/sglang-ci-dsv3-test"
# Setting data path to None uses default data path in few_shot_gsm8k eval test.
DATA_PATH = None
......@@ -174,5 +173,57 @@ class TestFlashAttention3SpeculativeDecode(BaseFlashAttentionTest):
self.assertGreater(avg_spec_accept_length, 1.5)
class TestFlashAttention3MLASpeculativeDecode(BaseFlashAttentionTest):
"""Test FlashAttention3 with speculative decode enabled."""
model = MODEL_USED_FOR_TEST_MLA
@classmethod
def get_server_args(cls):
args = super().get_server_args()
args.extend(
[
"--cuda-graph-max-bs",
"2",
"--speculative-algorithm",
"EAGLE",
"--speculative-draft",
"lmsys/sglang-ci-dsv3-test-NextN",
"--speculative-num-steps",
"3",
"--speculative-eagle-topk",
"1",
"--speculative-num-draft-tokens",
"3",
]
)
return args
def test_gsm8k(self):
"""
Override the test_gsm8k to further test for average speculative accept length.
"""
requests.get(self.base_url + "/flush_cache")
args = SimpleNamespace(
num_shots=5,
data_path=DATA_PATH,
num_questions=200,
max_new_tokens=512,
parallel=128,
host="http://127.0.0.1",
port=int(self.base_url.split(":")[-1]),
)
metrics = run_eval_few_shot_gsm8k(args)
print(metrics)
self.assertGreater(metrics["accuracy"], 0.60)
server_info = requests.get(self.base_url + "/get_server_info")
avg_spec_accept_length = server_info.json()["avg_spec_accept_length"]
print(f"{avg_spec_accept_length=}")
self.assertGreater(avg_spec_accept_length, 1.5)
if __name__ == "__main__":
unittest.main()