"vscode:/vscode.git/clone" did not exist on "3a0f86a66389ec78620db657ad48a1cdc878943e"
Unverified commit 2e4babdb authored by quinnrong94, committed by GitHub

[Feat] Support FlashMLA backend with MTP and FP8 KV cache (#6109)


Co-authored-by: Yingyi <yingyihuang2000@outlook.com>
Co-authored-by: neiltian <neiltian@tencent.com>
Co-authored-by: lukec <118525388+sleepcoo@users.noreply.github.com>
Co-authored-by: kexueyu <kexueyu@tencent.com>
Co-authored-by: vincentmeng <vincentmeng@tencent.com>
Co-authored-by: pengmeng <pengmeng@tencent.com>
parent 44a3783d
......@@ -8,6 +8,7 @@
| **FA3** | ✅ | ✅ | ✅ | ✅ | ✅ |
| **Triton** | ❌ | ✅ | ✅ | ❌ | ❌ |
| **Torch Native** | ❌ | ❌ | ❌ | ❌ | ❌ |
| **FlashMLA** | ✅ | ✅ | ✅ | ❌ | ❌ |
## User guide
......@@ -30,10 +31,15 @@ python3 -m sglang.launch_server --tp 8 --model deepseek-ai/DeepSeek-V3 --trust-r
```bash
python3 -m sglang.launch_server --model meta-llama/Meta-Llama-3.1-8B-Instruct --attention-backend triton
python3 -m sglang.launch_server --tp 8 --model deepseek-ai/DeepSeek-V3 --attention-backend triton --trust-remote-code
```
- Torch Native
```bash
python3 -m sglang.launch_server --model meta-llama/Meta-Llama-3.1-8B-Instruct --attention-backend torch_native
```
- FlashMLA
```bash
python3 -m sglang.launch_server --tp 8 --model deepseek-ai/DeepSeek-R1 --attention-backend flashmla --trust-remote-code
python3 -m sglang.launch_server --tp 8 --model deepseek-ai/DeepSeek-R1 --attention-backend flashmla --kv-cache-dtype fp8_e4m3 --trust-remote-code
```
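Once any of the servers above is running, a quick request against the native `/generate` endpoint is an easy sanity check. This is a generic illustration (it assumes the default port `30000`; adjust if you launch with `--port`) and is not specific to FlashMLA:
```bash
# Minimal smoke test against a locally launched server (assumes the default port 30000)
curl -s http://127.0.0.1:30000/generate \
  -H "Content-Type: application/json" \
  -d '{"text": "The capital of France is", "sampling_params": {"max_new_tokens": 16, "temperature": 0}}'
```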
......@@ -158,7 +158,7 @@ Add arguments `--speculative-algorithm`, `--speculative-num-steps`, `--speculati
python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 --speculative-algorithm EAGLE --speculative-num-steps 1 --speculative-eagle-topk 1 --speculative-num-draft-tokens 2 --trust-remote-code --tp 8
```
- The best configuration for `--speculative-num-steps`, `--speculative-eagle-topk` and `--speculative-num-draft-tokens` can be searched with [bench_speculative.py](https://github.com/sgl-project/sglang/blob/main/scripts/playground/bench_speculative.py) script for given batch size. The minimum configuration is `--speculative-num-steps 1 --speculative-eagle-topk 1 --speculative-num-draft-tokens 2`, which can achieve speedup for larger batch sizes.
- FlashAttention3 and Triton backend fully supports MTP usage. For FlashInfer backend (`--attention-backend flashinfer`) with speculative decoding,`--speculative-eagle-topk` parameter should be set to `1`. MTP support for the FlashMLA backend and CutlassMLA backend is still under development.
- FlashAttention3, FlashMLA, and Triton backends fully support MTP usage. For the FlashInfer backend (`--attention-backend flashinfer`) with speculative decoding, the `--speculative-eagle-topk` parameter should be set to `1`. MTP support for the CutlassMLA backend is still under development.
- To enable DeepSeek MTP for large batch sizes (>32), some parameters should be changed (see the example command after this list; reference [this discussion](https://github.com/sgl-project/sglang/issues/4543#issuecomment-2737413756)):
- Adjust `--max-running-requests` to a larger number. The default value is `32` for MTP; increase it for larger batch sizes.
- Set `--cuda-graph-bs`. It's a list of batch sizes for cuda graph capture. The default captured batch sizes for speculative decoding are set [here](https://github.com/sgl-project/sglang/blob/49420741746c8f3e80e0eb17e7d012bfaf25793a/python/sglang/srt/model_executor/cuda_graph_runner.py#L126). You can add more batch sizes to it.
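As a concrete illustration of the two adjustments above (the values are placeholders to tune for your workload, assuming `--cuda-graph-bs` accepts a space-separated list):
```bash
# Illustrative only: MTP launch with a larger running-request cap and extra cuda-graph batch sizes
python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 \
  --speculative-algorithm EAGLE --speculative-num-steps 1 --speculative-eagle-topk 1 --speculative-num-draft-tokens 2 \
  --max-running-requests 128 --cuda-graph-bs 1 2 4 8 16 32 64 128 \
  --trust-remote-code --tp 8
```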
......
......@@ -346,7 +346,6 @@ class FlashInferMLAAttnBackend(AttentionBackend):
cache_loc = forward_batch.out_cache_loc
logits_soft_cap = layer.logit_cap
prefill_wrapper_paged = self.forward_metadata.prefill_wrapper
k_buf = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
# Save kv cache
if save_kv_cache and k is not None:
......@@ -381,6 +380,9 @@ class FlashInferMLAAttnBackend(AttentionBackend):
)
else:
# mla paged prefill
k_buf = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).to(
q.dtype
)
if q_rope is None:
qall = q.view(-1, layer.tp_q_head_num, layer.head_dim)
q, q_rope = (
......@@ -442,7 +444,9 @@ class FlashInferMLAAttnBackend(AttentionBackend):
q_nope = reshaped_q[:, :, : layer.v_head_dim]
q_rope = reshaped_q[:, :, layer.v_head_dim :]
k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).to(
q.dtype
)
o = q_nope.new_empty(q_nope.shape)
# Direct call to run without the wrapper
......@@ -467,7 +471,7 @@ class FlashInferMLAIndicesUpdaterDecode:
self.qk_nope_head_dim = model_runner.model_config.qk_nope_head_dim
self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
self.scaling = model_runner.model_config.scaling
self.data_type = model_runner.kv_cache_dtype
self.data_type = model_runner.dtype
self.attn_backend = attn_backend
# Buffers and wrappers
......@@ -577,7 +581,7 @@ class FlashInferMLAIndicesUpdaterPrefill:
self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
self.v_head_dim = model_runner.model_config.v_head_dim
self.scaling = model_runner.model_config.scaling
self.data_type = model_runner.kv_cache_dtype
self.data_type = model_runner.dtype
self.q_data_type = model_runner.dtype
self.attn_backend = attn_backend
......
......@@ -8,7 +8,7 @@ Enable speculative sampling in FlashMLA
"""
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional, Union
from typing import TYPE_CHECKING, Callable, Optional, Tuple, Union
import torch
import triton
......@@ -30,8 +30,8 @@ if TYPE_CHECKING:
# FlashMLA only supports pagesize=64
PAGE_SIZE = 64
# TODO The current setup is hard-coded and will be changed after integrating with MTP.
Q_LEN = 1
# FlashMLA FP8 issue: https://github.com/deepseek-ai/FlashMLA/issues/56
@dataclass
......@@ -52,7 +52,7 @@ class FlashMLADecodeMetadata:
class FlashMLABackend(FlashInferMLAAttnBackend):
"""Flashinfer attention kernels."""
"""Flashmla attention kernels."""
def __init__(
self,
......@@ -82,42 +82,72 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
self.q_data_type = model_runner.dtype
self.kv_cache_dim = self.kv_lora_rank + self.qk_rope_head_dim
self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
def init_forward_metadata(self, forward_batch: ForwardBatch):
bs = forward_batch.batch_size
spec_info = forward_batch.spec_info
if forward_batch.forward_mode.is_decode_or_idle():
if spec_info is None:
max_seqlen_pad = triton.cdiv(
forward_batch.seq_lens_cpu.max().item(), PAGE_SIZE
)
block_kv_indices = torch.full(
(bs, max_seqlen_pad),
-1,
dtype=torch.int32,
device=forward_batch.seq_lens.device,
)
create_flashmla_kv_indices_triton[(bs,)](
self.req_to_token,
forward_batch.req_pool_indices,
forward_batch.seq_lens,
None,
block_kv_indices,
self.req_to_token.stride(0),
max_seqlen_pad,
)
mla_metadata, num_splits = get_mla_metadata(
forward_batch.seq_lens.to(torch.int32),
Q_LEN * self.num_q_heads,
1,
)
self.forward_metadata = FlashMLADecodeMetadata(
mla_metadata,
num_splits,
block_kv_indices,
)
else:
super().init_forward_metadata(forward_batch)
max_seqlen_pad = triton.cdiv(
forward_batch.seq_lens_cpu.max().item(), PAGE_SIZE
)
block_kv_indices = torch.full(
(bs, max_seqlen_pad),
-1,
dtype=torch.int32,
device=forward_batch.seq_lens.device,
)
create_flashmla_kv_indices_triton[(bs,)](
self.req_to_token,
forward_batch.req_pool_indices,
forward_batch.seq_lens,
None,
block_kv_indices,
self.req_to_token.stride(0),
max_seqlen_pad,
)
mla_metadata, num_splits = get_mla_metadata(
forward_batch.seq_lens.to(torch.int32),
self.num_q_heads,
1,
)
self.forward_metadata = FlashMLADecodeMetadata(
mla_metadata,
num_splits,
block_kv_indices,
)
elif forward_batch.forward_mode.is_target_verify():
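            # Target verify (MTP): extend each request's KV length by the draft tokens being verified in this step.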
seq_lens_cpu = forward_batch.seq_lens_cpu + self.num_draft_tokens
seq_lens = forward_batch.seq_lens + self.num_draft_tokens
max_seqlen_pad = triton.cdiv(seq_lens_cpu.max().item(), PAGE_SIZE)
block_kv_indices = torch.full(
(bs, max_seqlen_pad),
-1,
dtype=torch.int32,
device=seq_lens.device,
)
create_flashmla_kv_indices_triton[(bs,)](
self.req_to_token,
forward_batch.req_pool_indices,
seq_lens,
None,
block_kv_indices,
self.req_to_token.stride(0),
max_seqlen_pad,
)
mla_metadata, num_splits = get_mla_metadata(
seq_lens.to(torch.int32),
self.num_draft_tokens * self.num_q_heads,
1,
)
# Use FlashMLADecodeMetadata which has the attributes forward_extend expects
self.forward_metadata = FlashMLADecodeMetadata(
mla_metadata,
num_splits,
block_kv_indices,
)
else:
super().init_forward_metadata(forward_batch)
......@@ -136,11 +166,22 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
else:
cuda_graph_kv_indices = block_kv_indices
self.cuda_graph_mla_metadata, self.cuda_graph_num_splits = get_mla_metadata(
torch.ones(max_bs, dtype=torch.int32, device=cuda_graph_kv_indices.device),
Q_LEN * self.num_q_heads,
1,
)
if self.num_draft_tokens:
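            # MTP capture: each request decodes num_draft_tokens query positions, so the tile scheduler
            # metadata is sized for num_draft_tokens * num_q_heads query heads per request.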
self.cuda_graph_mla_metadata, self.cuda_graph_num_splits = get_mla_metadata(
torch.ones(
max_bs, dtype=torch.int32, device=cuda_graph_kv_indices.device
),
self.num_draft_tokens * self.num_q_heads,
1,
)
else:
self.cuda_graph_mla_metadata, self.cuda_graph_num_splits = get_mla_metadata(
torch.ones(
max_bs, dtype=torch.int32, device=cuda_graph_kv_indices.device
),
self.num_q_heads,
1,
)
self.cuda_graph_kv_indices = cuda_graph_kv_indices
def init_forward_metadata_capture_cuda_graph(
......@@ -154,31 +195,54 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
spec_info: Optional[SpecInfo],
):
if forward_mode.is_decode_or_idle():
if spec_info is None:
max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE)
create_flashmla_kv_indices_triton[(bs,)](
self.req_to_token,
req_pool_indices,
seq_lens,
None,
self.cuda_graph_kv_indices,
self.req_to_token.stride(0),
self.cuda_graph_kv_indices.stride(0),
)
mla_metadata, num_splits = get_mla_metadata(
seq_lens.to(torch.int32),
Q_LEN * self.num_q_heads,
1,
)
self.cuda_graph_mla_metadata.copy_(mla_metadata)
self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
self.forward_metadata = FlashMLADecodeMetadata(
self.cuda_graph_mla_metadata,
self.cuda_graph_num_splits[: bs + 1],
self.cuda_graph_kv_indices[:bs, :max_seqlen_pad],
)
max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE)
create_flashmla_kv_indices_triton[(bs,)](
self.req_to_token,
req_pool_indices,
seq_lens,
None,
self.cuda_graph_kv_indices,
self.req_to_token.stride(0),
self.cuda_graph_kv_indices.stride(0),
)
mla_metadata, num_splits = get_mla_metadata(
seq_lens.to(torch.int32),
self.num_q_heads,
1,
)
self.cuda_graph_mla_metadata.copy_(mla_metadata)
self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
self.forward_metadata = FlashMLADecodeMetadata(
self.cuda_graph_mla_metadata,
self.cuda_graph_num_splits[: bs + 1],
self.cuda_graph_kv_indices[:bs, :max_seqlen_pad],
)
elif forward_mode.is_target_verify():
seq_lens = seq_lens + self.num_draft_tokens
max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE)
create_flashmla_kv_indices_triton[(bs,)](
self.req_to_token,
req_pool_indices,
seq_lens,
None,
self.cuda_graph_kv_indices,
self.req_to_token.stride(0),
self.cuda_graph_kv_indices.stride(0),
)
mla_metadata, num_splits = get_mla_metadata(
seq_lens.to(torch.int32),
self.num_draft_tokens * self.num_q_heads,
1,
)
self.cuda_graph_mla_metadata.copy_(mla_metadata)
self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
self.forward_metadata = FlashMLADecodeMetadata(
self.cuda_graph_mla_metadata,
self.cuda_graph_num_splits[: bs + 1],
self.cuda_graph_kv_indices[:bs, :max_seqlen_pad],
)
else:
super().init_forward_metadata_capture_cuda_graph(
bs,
......@@ -218,7 +282,32 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
)
mla_metadata, num_splits = get_mla_metadata(
seq_lens.to(torch.int32),
Q_LEN * self.num_q_heads,
self.num_q_heads,
1,
)
self.cuda_graph_mla_metadata.copy_(mla_metadata)
self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
self.forward_metadata.mla_metadata = self.cuda_graph_mla_metadata
self.forward_metadata.num_splits = self.cuda_graph_num_splits[: bs + 1]
self.forward_metadata.block_kv_indices = self.cuda_graph_kv_indices[
:bs, :max_seqlen_pad
]
elif forward_mode.is_target_verify():
seq_lens = seq_lens[:bs] + self.num_draft_tokens
seq_lens_cpu = seq_lens_cpu[:bs] + self.num_draft_tokens
max_seqlen_pad = triton.cdiv(seq_lens_cpu.max().item(), PAGE_SIZE)
create_flashmla_kv_indices_triton[(bs,)](
self.req_to_token,
req_pool_indices[:bs],
seq_lens,
None,
self.cuda_graph_kv_indices,
self.req_to_token.stride(0),
self.cuda_graph_kv_indices.stride(0),
)
mla_metadata, num_splits = get_mla_metadata(
seq_lens.to(torch.int32),
self.num_draft_tokens * self.num_q_heads,
1,
)
self.cuda_graph_mla_metadata.copy_(mla_metadata)
......@@ -228,7 +317,6 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
self.forward_metadata.block_kv_indices = self.cuda_graph_kv_indices[
:bs, :max_seqlen_pad
]
else:
super().init_forward_metadata_replay_cuda_graph(
bs,
......@@ -268,17 +356,191 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
reshape_q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
if self.data_type == torch.float8_e4m3fn:
reshape_q_fp8 = reshape_q.to(torch.float8_e4m3fn)
o, _ = flash_mla_with_kvcache(
q=reshape_q_fp8,
k_cache=k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim),
block_table=self.forward_metadata.block_kv_indices[:bs],
cache_seqlens=forward_batch.seq_lens.to(torch.int32),
head_dim_v=self.kv_lora_rank, # TODO Retrieve from config.
tile_scheduler_metadata=self.forward_metadata.flashmla_metadata,
num_splits=self.forward_metadata.num_splits,
softmax_scale=layer.scaling,
causal=True,
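                # Placeholder dequantization scales: descale_q / descale_k are fixed to 1.0 for now.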
descale_q=torch.ones((1), dtype=torch.float32, device=reshape_q.device),
descale_k=torch.ones((1), dtype=torch.float32, device=reshape_q.device),
)
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
else:
            # TODO: verify whether causal should be True in all code paths.
o, _ = flash_mla_with_kvcache(
q=reshape_q,
k_cache=k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim),
block_table=self.forward_metadata.block_kv_indices[:bs],
cache_seqlens=forward_batch.seq_lens.to(torch.int32),
head_dim_v=self.kv_lora_rank, # TODO Retrieve from config.
tile_scheduler_metadata=self.forward_metadata.flashmla_metadata,
num_splits=self.forward_metadata.num_splits,
softmax_scale=layer.scaling,
causal=True,
)
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
def forward_extend(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache: bool = True,
):
if (
forward_batch.forward_mode == ForwardMode.EXTEND
or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
):
return super().forward_extend(q, k, v, layer, forward_batch, save_kv_cache)
else:
cache_loc = forward_batch.out_cache_loc
if k is not None:
assert v is not None
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
bs = forward_batch.batch_size
k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
reshape_q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
if self.data_type == torch.float8_e4m3fn:
reshape_q_fp8 = reshape_q.to(torch.float8_e4m3fn)
o, _ = flash_mla_with_kvcache(
q=reshape_q_fp8,
k_cache=k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim),
block_table=self.forward_metadata.block_kv_indices[:bs],
cache_seqlens=forward_batch.seq_lens.to(torch.int32)
+ self.num_draft_tokens,
head_dim_v=self.kv_lora_rank,
tile_scheduler_metadata=self.forward_metadata.flashmla_metadata,
num_splits=self.forward_metadata.num_splits,
softmax_scale=layer.scaling,
causal=True,
descale_q=torch.ones(
(1), dtype=torch.float32, device=reshape_q.device
),
descale_k=torch.ones(
(1), dtype=torch.float32, device=reshape_q.device
),
)
else:
o, _ = flash_mla_with_kvcache(
q=reshape_q,
k_cache=k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim),
block_table=self.forward_metadata.block_kv_indices[:bs],
cache_seqlens=forward_batch.seq_lens.to(torch.int32)
+ self.num_draft_tokens,
head_dim_v=self.kv_lora_rank,
tile_scheduler_metadata=self.forward_metadata.flashmla_metadata,
num_splits=self.forward_metadata.num_splits,
softmax_scale=layer.scaling,
causal=True,
)
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
o, _ = flash_mla_with_kvcache(
q=reshape_q,
k_cache=k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim),
block_table=self.forward_metadata.block_kv_indices,
cache_seqlens=forward_batch.seq_lens.to(torch.int32),
head_dim_v=self.kv_lora_rank, # TODO Retrieve from config.
tile_scheduler_metadata=self.forward_metadata.flashmla_metadata,
num_splits=self.forward_metadata.num_splits,
softmax_scale=layer.scaling,
causal=False,
# TODO: multi step kv indices optimization
class FlashMLAMultiStepDraftBackend:
"""
Wrap multiple flashmla attention backends as one for multiple consecutive
draft decoding steps.
"""
def __init__(
self,
model_runner: ModelRunner,
topk: int,
speculative_num_steps: int,
):
from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices
if topk > 1:
raise ValueError(
f"Currently FlashMLA only supports topk=1 for speculative decoding"
)
self.topk = topk
self.speculative_num_steps = speculative_num_steps
max_bs = model_runner.req_to_token_pool.size * self.topk
self.kv_indptr = torch.zeros(
(
self.speculative_num_steps,
max_bs + 1,
),
dtype=torch.int32,
device=model_runner.device,
)
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
self.attn_backends = []
for i in range(self.speculative_num_steps):
self.attn_backends.append(
FlashMLABackend(
model_runner,
skip_prefill=True,
kv_indptr_buf=self.kv_indptr[i],
kv_last_page_len_buf=None,
)
)
def common_template(
self,
forward_batch: ForwardBatch,
call_fn: Callable,
):
assert forward_batch.spec_info is not None
for i in range(self.speculative_num_steps - 1):
call_fn(i, forward_batch)
def init_forward_metadata(self, forward_batch: ForwardBatch):
def call_fn(i, forward_batch):
assert forward_batch.spec_info is not None
self.attn_backends[i].init_forward_metadata(forward_batch)
self.common_template(forward_batch, call_fn)
def init_cuda_graph_state(self, max_bs: int):
for i in range(self.speculative_num_steps):
self.attn_backends[i].init_cuda_graph_state(max_bs, block_kv_indices=None)
def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
def call_fn(i, forward_batch):
self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
forward_batch.batch_size,
forward_batch.batch_size * self.topk,
forward_batch.req_pool_indices,
forward_batch.seq_lens,
encoder_lens=None,
forward_mode=ForwardMode.DECODE,
spec_info=forward_batch.spec_info,
)
self.common_template(forward_batch, call_fn)
def init_forward_metadata_replay_cuda_graph(
self, forward_batch: ForwardBatch, bs: int
):
def call_fn(i, forward_batch):
self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
bs,
forward_batch.req_pool_indices,
forward_batch.seq_lens,
seq_lens_sum=-1,
encoder_lens=None,
forward_mode=ForwardMode.DECODE,
spec_info=forward_batch.spec_info,
seq_lens_cpu=forward_batch.seq_lens_cpu,
)
self.common_template(forward_batch, call_fn)
......@@ -77,8 +77,8 @@ def create_flashmla_kv_indices_triton(
) * PAGED_SIZE
paged_offset_out = tl.arange(0, NUM_PAGE_PER_BLOCK) + i * NUM_PAGE_PER_BLOCK
mask = paged_offset <= num_paged * PAGED_SIZE
mask_out = paged_offset_out <= num_paged
mask = paged_offset < num_paged * PAGED_SIZE
mask_out = paged_offset_out < num_paged
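        # Strict <: a request with num_paged pages only owns output slots 0 .. num_paged - 1,
        # so <= would write one page index past the end.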
data = tl.load(
req_to_token_ptr
......
......@@ -30,6 +30,7 @@ from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_captur
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.moe.fused_moe_native import fused_moe_forward_native
from sglang.srt.layers.torchao_utils import save_gemlite_cache
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import (
CaptureHiddenMode,
ForwardBatch,
......@@ -210,7 +211,10 @@ class CudaGraphRunner:
# Attention backend
self.max_bs = max(self.capture_bs)
self.max_num_token = self.max_bs * self.num_tokens_per_bs
self.model_runner.attn_backend.init_cuda_graph_state(self.max_num_token)
if global_server_args_dict["attention_backend"] == "flashmla":
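            # FlashMLA keeps per-request block tables for cuda graphs, so its buffers are sized
            # by batch size rather than by token count.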
self.model_runner.attn_backend.init_cuda_graph_state(self.max_bs)
else:
self.model_runner.attn_backend.init_cuda_graph_state(self.max_num_token)
self.seq_len_fill_value = (
self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
)
......
......@@ -199,6 +199,19 @@ class EAGLEWorker(TpModelWorker):
self.draft_extend_attn_backend = None
self.padded_static_len = self.speculative_num_steps + 1
self.has_prefill_wrapper_verify = False
elif self.server_args.attention_backend == "flashmla":
from sglang.srt.layers.attention.flashmla_backend import (
FlashMLAMultiStepDraftBackend,
)
self.draft_attn_backend = FlashMLAMultiStepDraftBackend(
self.draft_model_runner,
self.topk,
self.speculative_num_steps,
)
self.draft_extend_attn_backend = None
self.padded_static_len = self.speculative_num_steps + 1
self.has_prefill_wrapper_verify = False
else:
raise ValueError(
f"EAGLE is not supported in attention backend {self.server_args.attention_backend}"
......
......@@ -6,6 +6,7 @@ python3 test/srt/test_flashmla.py
import unittest
from types import SimpleNamespace
import requests
import torch
from sglang.srt.utils import kill_process_tree
......@@ -14,6 +15,7 @@ from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST_MLA,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
is_in_ci,
popen_launch_server,
run_bench_one_batch,
......@@ -81,5 +83,71 @@ class TestFlashMLAAttnLatency(unittest.TestCase):
self.assertGreater(output_throughput, 100)
class TestFlashMLAMTP(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = "lmsys/sglang-ci-dsv3-test"
cls.base_url = DEFAULT_URL_FOR_TEST
other_args = ["--trust-remote-code"]
if torch.cuda.is_available() and torch.version.cuda:
other_args.extend(
[
"--cuda-graph-max-bs",
"4",
"--disable-radix",
"--enable-torch-compile",
"--torch-compile-max-bs",
"1",
"--speculative-algorithm",
"EAGLE",
"--speculative-draft",
"lmsys/sglang-ci-dsv3-test-NextN",
"--speculative-num-steps",
"1",
"--speculative-eagle-topk",
"1",
"--speculative-num-draft-tokens",
"2",
"--attention-backend",
"flashmla",
]
)
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=other_args,
)
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def test_gsm8k(self):
requests.get(self.base_url + "/flush_cache")
args = SimpleNamespace(
num_shots=5,
data_path=None,
num_questions=200,
max_new_tokens=512,
parallel=128,
host="http://127.0.0.1",
port=int(self.base_url.split(":")[-1]),
)
metrics = run_eval_few_shot_gsm8k(args)
print(metrics)
self.assertGreater(metrics["accuracy"], 0.60)
server_info = requests.get(self.base_url + "/get_server_info")
print(f"{server_info=}")
avg_spec_accept_length = server_info.json()["internal_states"][0][
"avg_spec_accept_length"
]
print(f"{avg_spec_accept_length=}")
self.assertGreater(avg_spec_accept_length, 1.8)
if __name__ == "__main__":
unittest.main()