Unverified Commit 9fb48f95 authored by Baizhou Zhang, committed by GitHub

Support nextn for flashinfer mla attention backend (#4218)

parent 89ccb533
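For context, the sketch below shows roughly how the feature added in this commit is exercised end to end: FlashInfer MLA attention combined with EAGLE/NextN drafting. The flags mirror the CI test added further down in this diff; the `--model-path` flag and the bare `subprocess` invocation are my assumptions, and the model names are the CI test checkpoints rather than production weights.

```python
# Hedged sketch: launch an sglang server with FlashInfer MLA + EAGLE (NextN) drafting.
# The speculative flags are copied from the test added in this commit; --model-path is
# assumed to be the standard sglang launch flag.
import subprocess

cmd = [
    "python3", "-m", "sglang.launch_server",
    "--model-path", "lmsys/sglang-ci-dsv3-test",
    "--trust-remote-code",
    "--enable-flashinfer-mla",
    "--speculative-algorithm", "EAGLE",
    "--speculative-draft", "lmsys/sglang-ci-dsv3-test-NextN",
    "--speculative-num-steps", "4",
    "--speculative-eagle-topk", "1",  # the flashinfer mla backend currently requires topk=1
    "--speculative-num-draft-tokens", "4",
]
server = subprocess.Popen(cmd)  # call server.terminate() when done
```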
@@ -84,7 +84,7 @@ Please refer to [the example](https://github.com/sgl-project/sglang/tree/main/be
- **Weight Absorption**: By applying the associative law of matrix multiplication to reorder computation steps, this method balances computation and memory access and improves efficiency in the decoding phase (a short sketch follows this list).
- **Flashinfer MLA Wrapper**: By providing the `--enable-flashinfer-mla` argument, the server will use MLA kernels customized by Flashinfer. More details can be found in [this document](https://docs.flashinfer.ai/api/mla.html). Under long-input scenarios, flashinfer mla can improve performance significantly. Optimized triton kernels are used when flashinfer mla is turned off. Currently, when the flashinfer mla wrapper is used together with speculative decoding, the `speculative_eagle_topk` parameter must be set to 1.
- **FP8 Quantization**: W8A8 FP8 and KV Cache FP8 quantization enable efficient FP8 inference. Additionally, we have implemented a Batched Matrix Multiplication (BMM) operator to facilitate FP8 inference in MLA with weight absorption.
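To make the first bullet concrete, here is a minimal sketch of the associativity idea behind weight absorption: instead of up-projecting the compressed KV cache and then scoring against it, the up-projection matrix is folded ("absorbed") into the query once per step, so the per-token work stays in the low-rank latent space. The names and shapes below are illustrative, not DeepSeek's actual MLA dimensions or the kernel used by this backend.

```python
# Illustrative only: weight absorption via the associative law (x @ A) @ B == x @ (A @ B).
import torch

d_latent, d_head, n_tokens = 512, 128, 4096
torch.manual_seed(0)

c_kv = torch.randn(n_tokens, d_latent, dtype=torch.float64)  # compressed (latent) KV cache
w_uk = torch.randn(d_latent, d_head, dtype=torch.float64)    # key up-projection
q = torch.randn(1, d_head, dtype=torch.float64)              # one decoding query head

# Naive order: materialize full keys for every cached token, then score against them.
scores_naive = q @ (c_kv @ w_uk).T            # (1, n_tokens)

# Absorbed order: fold w_uk into the query once, then score directly against the latent cache.
q_absorbed = q @ w_uk.T                       # (1, d_latent), done once per decoding step
scores_absorbed = q_absorbed @ c_kv.T         # (1, n_tokens)

assert torch.allclose(scores_naive, scores_absorbed)
```

Folding `w_uk` into the query trades one small per-step matmul for never materializing the full key matrix, which is why decoding becomes more memory-friendly.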
......
@@ -11,9 +11,10 @@ More details can be found in https://docs.flashinfer.ai/api/mla.html
from dataclasses import dataclass
from functools import partial
from typing import TYPE_CHECKING, Callable, Optional, Union

import torch
import triton

from sglang.global_config import global_config
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
@@ -23,6 +24,7 @@ from sglang.srt.layers.attention.flashinfer_backend import (
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
from sglang.srt.utils import is_flashinfer_available

if TYPE_CHECKING:
@@ -58,12 +60,16 @@ class FlashInferMLAAttnBackend(AttentionBackend):
    def __init__(
        self,
        model_runner: ModelRunner,
        skip_prefill: bool = False,
        kv_indptr_buf: Optional[torch.Tensor] = None,
        q_indptr_decode_buf: Optional[torch.Tensor] = None,
    ):
        super().__init__()

        # Parse constants
        self.max_context_len = model_runner.model_config.context_len
        self.device = model_runner.device
        self.skip_prefill = skip_prefill

        global_config.enable_flashinfer_mla = True
@@ -78,35 +84,51 @@ class FlashInferMLAAttnBackend(AttentionBackend):
        self.workspace_buffer = global_workspace_buffer
        max_bs = model_runner.req_to_token_pool.size

        if kv_indptr_buf is None:
            self.kv_indptr = torch.zeros(
                (max_bs + 1,), dtype=torch.int32, device=model_runner.device
            )
        else:
            self.kv_indptr = kv_indptr_buf

        if not self.skip_prefill:
            self.qo_indptr = torch.zeros(
                (max_bs + 1,), dtype=torch.int32, device=model_runner.device
            )

        if q_indptr_decode_buf is None:
            self.q_indptr_decode = torch.arange(
                0, max_bs + 1, dtype=torch.int32, device=model_runner.device
            )
        else:
            self.q_indptr_decode = q_indptr_decode_buf

        self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
            self.workspace_buffer, "NHD"
        )

        if not self.skip_prefill:
            self.prefill_wrapper_paged = BatchMLAPagedAttentionWrapper(
                self.workspace_buffer,
                backend="auto",
            )
            # FlashinferMLA backend uses mla wrapper for target verify
            self.prefill_wrapper_verify = BatchMLAPagedAttentionWrapper(
                self.workspace_buffer,
                backend="auto",
            )

        self.decode_wrapper = BatchMLAPagedAttentionWrapper(
            self.workspace_buffer, backend="auto"
        )

        # Create indices updater
        if not skip_prefill:
            self.indices_updater_prefill = FlashInferMLAIndicesUpdaterPrefill(
                model_runner, self
            )
        self.indices_updater_decode = FlashInferMLAIndicesUpdaterDecode(
            model_runner, self
        )
@@ -114,7 +136,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
        # Other metadata
        self.forward_metadata: Union[PrefillMetadata, DecodeMetadata] = None
        self.decode_cuda_graph_metadata = {}
        self.prefill_cuda_graph_metadata = {}  # For verify

    def init_forward_metadata(self, forward_batch: ForwardBatch):
        if forward_batch.forward_mode.is_decode_or_idle():
@@ -126,6 +148,28 @@ class FlashInferMLAAttnBackend(AttentionBackend):
                init_metadata_replay=False,
            )
            self.forward_metadata = DecodeMetadata(self.decode_wrapper)
        elif forward_batch.forward_mode.is_draft_extend():
            self.indices_updater_prefill.update(
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                forward_batch.seq_lens_sum,
                prefix_lens=None,
                prefill_wrapper_paged=self.prefill_wrapper_paged,
                use_ragged=False,
                spec_info=forward_batch.spec_info,
            )
            self.forward_metadata = PrefillMetadata(self.prefill_wrapper_paged, False)
        elif forward_batch.forward_mode.is_target_verify():
            self.indices_updater_prefill.update(
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                forward_batch.seq_lens_sum,
                prefix_lens=None,
                prefill_wrapper_paged=self.prefill_wrapper_verify,
                use_ragged=False,
                spec_info=forward_batch.spec_info,
            )
            self.forward_metadata = PrefillMetadata(self.prefill_wrapper_verify, False)
        else:
            prefix_lens = forward_batch.extend_prefix_lens
            extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
@@ -202,10 +246,33 @@ class FlashInferMLAAttnBackend(AttentionBackend):
                seq_lens_sum,
                decode_wrapper=decode_wrapper,
                init_metadata_replay=False,
                spec_info=spec_info,
            )
            self.decode_cuda_graph_metadata[bs] = decode_wrapper
            self.forward_metadata = DecodeMetadata(decode_wrapper)
            decode_wrapper.plan = partial(fast_mla_decode_plan, decode_wrapper)
        elif forward_mode.is_target_verify():
            verify_wrapper = BatchMLAPagedAttentionWrapper(
                self.workspace_buffer,
                use_cuda_graph=True,
                qo_indptr=self.cuda_graph_qo_indptr[: bs + 1],
                kv_indptr=self.cuda_graph_kv_indptr[: bs + 1],
                kv_indices=self.cuda_graph_kv_indices,
                kv_len_arr=self.cuda_graph_kv_lens[:bs],
                backend="auto",
            )
            seq_lens_sum = seq_lens.sum().item()
            self.indices_updater_prefill.update(
                req_pool_indices,
                seq_lens,
                seq_lens_sum,
                prefix_lens=None,
                prefill_wrapper_paged=verify_wrapper,
                use_ragged=False,
                spec_info=spec_info,
            )
            self.prefill_cuda_graph_metadata[bs] = verify_wrapper
            self.forward_metadata = PrefillMetadata(verify_wrapper, False)
        else:
            raise ValueError(f"Invalid mode: {forward_mode=}")
@@ -221,6 +288,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
        seq_lens_cpu: Optional[torch.Tensor],
    ):
        if forward_mode.is_decode_or_idle():
            assert seq_lens_cpu is not None
            kv_len_arr_cpu = seq_lens_cpu[:bs]
            self.cuda_graph_kv_indptr_cpu[1 : bs + 1] = torch.cumsum(
                kv_len_arr_cpu, dim=0
@@ -239,8 +307,19 @@ class FlashInferMLAAttnBackend(AttentionBackend):
                seq_lens_sum,
                decode_wrapper=self.decode_cuda_graph_metadata[bs],
                init_metadata_replay=True,
                spec_info=spec_info,
                **self.fast_decode_kwargs,
            )
        elif forward_mode.is_target_verify():
            self.indices_updater_prefill.update(
                req_pool_indices[:bs],
                seq_lens[:bs],
                seq_lens_sum,
                prefix_lens=None,
                prefill_wrapper_paged=self.prefill_cuda_graph_metadata[bs],
                use_ragged=False,
                spec_info=spec_info,
            )
        else:
            raise ValueError(f"Invalid forward mode: {forward_mode=}")
@@ -254,7 +333,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache: bool = True,
    ):
        cache_loc = forward_batch.out_cache_loc
@@ -297,7 +376,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache: bool = True,
    ):
        decode_wrapper = self.forward_metadata.decode_wrapper
        cache_loc = forward_batch.out_cache_loc
@@ -349,6 +428,7 @@ class FlashInferMLAIndicesUpdaterDecode:
        seq_lens_sum: int,
        decode_wrapper: BatchMLAPagedAttentionWrapper,
        init_metadata_replay: bool = False,
        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None,
        **fast_decode_kwargs,
    ):
        decode_wrapper = decode_wrapper or self.decode_wrapper
@@ -360,6 +440,7 @@ class FlashInferMLAIndicesUpdaterDecode:
            self.q_indptr,
            self.kv_indptr,
            init_metadata_replay,
            spec_info,
            **fast_decode_kwargs,
        )
@@ -372,10 +453,14 @@ class FlashInferMLAIndicesUpdaterDecode:
        q_indptr: torch.Tensor,
        kv_indptr: torch.Tensor,
        init_metadata_replay: bool = False,
        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None,
        **fast_decode_kwargs,
    ):
        bs = len(req_pool_indices)
        q_indptr = q_indptr[: bs + 1]
        kv_lens = paged_kernel_lens.to(torch.int32)
        sm_scale = self.scaling
        if spec_info is None:
            kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
            kv_indptr = kv_indptr[: bs + 1]
            kv_indices = (
@@ -383,10 +468,6 @@ class FlashInferMLAIndicesUpdaterDecode:
                if not init_metadata_replay
                else fast_decode_kwargs["kv_indices"]
            )
            create_flashinfer_kv_indices_triton[(bs,)](
                self.req_to_token,
                req_pool_indices,
@@ -396,6 +477,9 @@ class FlashInferMLAIndicesUpdaterDecode:
                kv_indices,
                self.req_to_token.shape[1],
            )
        else:
            kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices

        if not init_metadata_replay:
            wrapper.plan(
                q_indptr,
@@ -457,6 +541,7 @@ class FlashInferMLAIndicesUpdaterPrefill:
        prefix_lens: torch.Tensor,
        prefill_wrapper_paged: BatchMLAPagedAttentionWrapper,
        use_ragged: bool,
        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None,
    ):
        if use_ragged:
            paged_kernel_lens = prefix_lens
@@ -476,6 +561,7 @@ class FlashInferMLAIndicesUpdaterPrefill:
            self.kv_indptr,
            self.qo_indptr,
            use_ragged,
            spec_info,
        )

    def call_begin_forward(
@@ -490,8 +576,13 @@ class FlashInferMLAIndicesUpdaterPrefill:
        kv_indptr: torch.Tensor,
        qo_indptr: torch.Tensor,
        use_ragged: bool,
        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None,
    ):
        bs = len(seq_lens)
        sm_scale = self.scaling

        if spec_info is None:
            assert len(seq_lens) == len(req_pool_indices)
            kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
            kv_indptr = kv_indptr[: bs + 1]
            kv_indices = torch.empty(
@@ -508,10 +599,22 @@ class FlashInferMLAIndicesUpdaterPrefill:
                kv_indices,
                self.req_to_token.shape[1],
            )
            qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
            qo_indptr = qo_indptr[: bs + 1]
            custom_mask = None
        else:
            assert isinstance(spec_info, EagleDraftInput) or isinstance(
                spec_info, EagleVerifyInput
            )
            # TODO: Support topk > 1 with custom mask
            kv_indices, kv_indptr, qo_indptr, custom_mask = (
                spec_info.generate_attn_arg_prefill(
                    req_pool_indices,
                    paged_kernel_lens,
                    paged_kernel_lens_sum,
                    self.req_to_token,
                )
            )

        if use_ragged:
            # ragged prefill
@@ -543,6 +646,163 @@ class FlashInferMLAIndicesUpdaterPrefill:
            )

class FlashInferMLAMultiStepDraftBackend:
    """
    Wrap multiple flashinfer mla attention backends as one for multiple consecutive
    draft decoding steps.
    """
    def __init__(
        self,
        model_runner: ModelRunner,
        topk: int,
        speculative_num_steps: int,
    ):
        from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices

        if topk > 1:
            raise ValueError(
                "Currently Flashinfer MLA only supports topk=1 for speculative decoding"
            )
        self.topk = topk
        self.speculative_num_steps = speculative_num_steps
        self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices

        max_bs = model_runner.req_to_token_pool.size * self.topk
        self.kv_indptr = torch.zeros(
            (
                self.speculative_num_steps,
                max_bs + 1,
            ),
            dtype=torch.int32,
            device=model_runner.device,
        )
        self.q_indptr_decode = torch.arange(
            0, max_bs + 1, dtype=torch.int32, device=model_runner.device
        )

        self.attn_backends = []
        for i in range(self.speculative_num_steps):
            self.attn_backends.append(
                FlashInferMLAAttnBackend(
                    model_runner,
                    skip_prefill=True,
                    kv_indptr_buf=self.kv_indptr[i],
                    q_indptr_decode_buf=self.q_indptr_decode,
                )
            )

        self.max_context_len = self.attn_backends[0].max_context_len

        # Cached variables for generate_draft_decode_kv_indices
        self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]

    def common_template(
        self,
        forward_batch: ForwardBatch,
        kv_indices_buffer: torch.Tensor,
        call_fn: Callable,
    ):
        num_seqs = forward_batch.batch_size
        bs = self.topk * num_seqs
        seq_lens_sum = forward_batch.seq_lens_sum

        self.generate_draft_decode_kv_indices[
            (self.speculative_num_steps, num_seqs, self.topk)
        ](
            forward_batch.req_pool_indices,
            forward_batch.req_to_token_pool.req_to_token,
            forward_batch.seq_lens,
            kv_indices_buffer,
            self.kv_indptr,
            forward_batch.positions,
            num_seqs,
            self.topk,
            self.pool_len,
            kv_indices_buffer.shape[1],
            self.kv_indptr.shape[1],
            triton.next_power_of_2(num_seqs),
            triton.next_power_of_2(self.speculative_num_steps),
            triton.next_power_of_2(bs),
        )

        assert forward_batch.spec_info is not None
        assert isinstance(forward_batch.spec_info, EagleDraftInput)

        for i in range(self.speculative_num_steps - 1):
            forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
            forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
                : seq_lens_sum * self.topk + bs * (i + 1)
            ]
            call_fn(i, forward_batch)

    def init_forward_metadata(self, forward_batch: ForwardBatch):
        kv_indices = torch.zeros(
            (
                self.speculative_num_steps,
                forward_batch.batch_size * self.topk * self.max_context_len,
            ),
            dtype=torch.int32,
            device="cuda",
        )

        def call_fn(i, forward_batch):
            assert forward_batch.spec_info is not None
            assert isinstance(forward_batch.spec_info, EagleDraftInput)
            forward_batch.spec_info.kv_indptr = (
                forward_batch.spec_info.kv_indptr.clone()
            )
            forward_batch.spec_info.kv_indices = (
                forward_batch.spec_info.kv_indices.clone()
            )
            self.attn_backends[i].init_forward_metadata(forward_batch)

        self.common_template(forward_batch, kv_indices, call_fn)

    def init_cuda_graph_state(self, max_bs: int):
        self.cuda_graph_kv_indices = torch.zeros(
            (self.speculative_num_steps, max_bs * self.max_context_len),
            dtype=torch.int32,
            device="cuda",
        )
        for i in range(self.speculative_num_steps):
            self.attn_backends[i].init_cuda_graph_state(
                max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
            )

    def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
        def call_fn(i, forward_batch):
            self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
                forward_batch.batch_size,
                forward_batch.batch_size * self.topk,
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                encoder_lens=None,
                forward_mode=ForwardMode.DECODE,
                spec_info=forward_batch.spec_info,
            )

        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)

    def init_forward_metadata_replay_cuda_graph(
        self, forward_batch: ForwardBatch, bs: int
    ):
        def call_fn(i, forward_batch):
            self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
                bs,
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                seq_lens_sum=-1,
                encoder_lens=None,
                forward_mode=ForwardMode.DECODE,
                spec_info=forward_batch.spec_info,
                seq_lens_cpu=forward_batch.decode_seq_lens_cpu,
            )

        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)


def fast_mla_decode_plan(
    self,
    qo_indptr_cpu: torch.Tensor,
......
@@ -555,6 +555,8 @@ class DeepseekV2AttentionMLA(nn.Module):
            return (
                not global_server_args_dict["flashinfer_mla_disable_ragged"]
                and forward_batch.forward_mode.is_extend()
                and not forward_batch.forward_mode.is_target_verify()
                and not forward_batch.forward_mode.is_draft_extend()
                and forward_batch.extend_prefix_lens.sum() == 0
            )
        else:
......
@@ -123,6 +123,16 @@ class EAGLEWorker(TpModelWorker):
                self.topk,
                self.speculative_num_steps,
            )
        elif self.server_args.attention_backend == "flashinfer_mla":
            from sglang.srt.layers.attention.flashinfer_mla_backend import (
                FlashInferMLAMultiStepDraftBackend,
            )

            self.draft_attn_backend = FlashInferMLAMultiStepDraftBackend(
                self.model_runner,
                self.topk,
                self.speculative_num_steps,
            )
        else:
            raise ValueError(
                f"EAGLE is not supported in attention backend {self.server_args.attention_backend}"
......
import unittest
from types import SimpleNamespace

import requests
import torch

from sglang.srt.utils import kill_process_tree
@@ -100,5 +101,67 @@ class TestFlashinferMLANoRagged(unittest.TestCase):
        self.assertGreater(metrics["accuracy"], 0.62)

class TestFlashinferMLAMTP(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.model = "lmsys/sglang-ci-dsv3-test"
        cls.base_url = DEFAULT_URL_FOR_TEST
        other_args = ["--trust-remote-code"]
        if torch.cuda.is_available() and torch.version.cuda:
            other_args.extend(
                [
                    "--cuda-graph-max-bs",
                    "2",
                    "--disable-radix",
                    "--enable-torch-compile",
                    "--torch-compile-max-bs",
                    "1",
                    "--speculative-algorithm",
                    "EAGLE",
                    "--speculative-draft",
                    "lmsys/sglang-ci-dsv3-test-NextN",
                    "--speculative-num-steps",
                    "4",
                    "--speculative-eagle-topk",
                    "1",
                    "--speculative-num-draft-tokens",
                    "4",
                    "--enable-flashinfer-mla",
                ]
            )
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=other_args,
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def test_gsm8k(self):
        requests.get(self.base_url + "/flush_cache")

        args = SimpleNamespace(
            num_shots=5,
            data_path=None,
            num_questions=200,
            max_new_tokens=512,
            parallel=128,
            host="http://127.0.0.1",
            port=int(self.base_url.split(":")[-1]),
        )
        metrics = run_eval_few_shot_gsm8k(args)
        print(metrics)

        self.assertGreater(metrics["accuracy"], 0.60)

        server_info = requests.get(self.base_url + "/get_server_info")
        avg_spec_accept_length = server_info.json()["avg_spec_accept_length"]
        print(f"{avg_spec_accept_length=}")
        self.assertGreater(avg_spec_accept_length, 2.5)


if __name__ == "__main__":
    unittest.main()