Unverified Commit 9fb48f95 authored by Baizhou Zhang's avatar Baizhou Zhang Committed by GitHub
Browse files

Support nextn for flashinfer mla attention backend (#4218)

parent 89ccb533
......@@ -84,7 +84,7 @@ Please refer to [the example](https://github.com/sgl-project/sglang/tree/main/be
- **Weight Absorption**: By applying the associative law of matrix multiplication to reorder computation steps, this method balances computation and memory access and improves efficiency in the decoding phase.
- **Flashinfer MLA Wrapper**: By providing `--enable-flashinfer-mla` argument, the server will use MLA kernels customized by Flashinfer. More details can be referred to [this document](https://docs.flashinfer.ai/api/mla.html). Under long input scenarios, flashinfer mla can improve performance significantly. Optimized triton kernels will be used when flashinfer mla is turned off.
- **Flashinfer MLA Wrapper**: By providing `--enable-flashinfer-mla` argument, the server will use MLA kernels customized by Flashinfer. More details can be referred to [this document](https://docs.flashinfer.ai/api/mla.html). Under long input scenarios, flashinfer mla can improve performance significantly. Optimized triton kernels will be used when flashinfer mla is turned off. Currently when using flashinfer mla wrapper and speculative decoding together, the `speculative_eagle_topk` parameter should be set to 1.
- **FP8 Quantization**: W8A8 FP8 and KV Cache FP8 quantization enables efficient FP8 inference. Additionally, we have implemented Batched Matrix Multiplication (BMM) operator to facilitate FP8 inference in MLA with weight absorption.
......
......@@ -11,9 +11,10 @@ More details can be found in https://docs.flashinfer.ai/api/mla.html
from dataclasses import dataclass
from functools import partial
from typing import TYPE_CHECKING, Optional, Union
from typing import TYPE_CHECKING, Callable, Optional, Union
import torch
import triton
from sglang.global_config import global_config
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
......@@ -23,6 +24,7 @@ from sglang.srt.layers.attention.flashinfer_backend import (
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
from sglang.srt.utils import is_flashinfer_available
if TYPE_CHECKING:
......@@ -58,12 +60,16 @@ class FlashInferMLAAttnBackend(AttentionBackend):
def __init__(
self,
model_runner: ModelRunner,
skip_prefill: bool = False,
kv_indptr_buf: Optional[torch.Tensor] = None,
q_indptr_decode_buf: Optional[torch.Tensor] = None,
):
super().__init__()
# Parse constants
self.max_context_len = model_runner.model_config.context_len
self.device = model_runner.device
self.skip_prefill = skip_prefill
global_config.enable_flashinfer_mla = True
......@@ -78,35 +84,51 @@ class FlashInferMLAAttnBackend(AttentionBackend):
self.workspace_buffer = global_workspace_buffer
max_bs = model_runner.req_to_token_pool.size
self.kv_indptr = torch.zeros(
(max_bs + 1,), dtype=torch.int32, device=model_runner.device
)
if kv_indptr_buf is None:
self.kv_indptr = torch.zeros(
(max_bs + 1,), dtype=torch.int32, device=model_runner.device
)
else:
self.kv_indptr = kv_indptr_buf
self.qo_indptr = torch.zeros(
(max_bs + 1,), dtype=torch.int32, device=model_runner.device
)
if not self.skip_prefill:
self.qo_indptr = torch.zeros(
(max_bs + 1,), dtype=torch.int32, device=model_runner.device
)
self.q_indptr_decode = torch.arange(
0, max_bs + 1, dtype=torch.int32, device=model_runner.device
)
if q_indptr_decode_buf is None:
self.q_indptr_decode = torch.arange(
0, max_bs + 1, dtype=torch.int32, device=model_runner.device
)
else:
self.q_indptr_decode = q_indptr_decode_buf
self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
self.workspace_buffer, "NHD"
)
self.prefill_wrapper_paged = BatchMLAPagedAttentionWrapper(
self.workspace_buffer,
backend="auto",
)
if not self.skip_prefill:
self.prefill_wrapper_paged = BatchMLAPagedAttentionWrapper(
self.workspace_buffer,
backend="auto",
)
# FlashinferMLA backend uses mla wrapper for target verify
self.prefill_wrapper_verify = BatchMLAPagedAttentionWrapper(
self.workspace_buffer,
backend="auto",
)
self.decode_wrapper = BatchMLAPagedAttentionWrapper(
self.workspace_buffer, backend="auto"
)
# Create indices updater
self.indices_updater_prefill = FlashInferMLAIndicesUpdaterPrefill(
model_runner, self
)
if not skip_prefill:
self.indices_updater_prefill = FlashInferMLAIndicesUpdaterPrefill(
model_runner, self
)
self.indices_updater_decode = FlashInferMLAIndicesUpdaterDecode(
model_runner, self
)
......@@ -114,7 +136,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
# Other metadata
self.forward_metadata: Union[PrefillMetadata, DecodeMetadata] = None
self.decode_cuda_graph_metadata = {}
self.prefill_cuda_graph_metadata = {}
self.prefill_cuda_graph_metadata = {} # For verify
def init_forward_metadata(self, forward_batch: ForwardBatch):
if forward_batch.forward_mode.is_decode_or_idle():
......@@ -126,6 +148,28 @@ class FlashInferMLAAttnBackend(AttentionBackend):
init_metadata_replay=False,
)
self.forward_metadata = DecodeMetadata(self.decode_wrapper)
elif forward_batch.forward_mode.is_draft_extend():
self.indices_updater_prefill.update(
forward_batch.req_pool_indices,
forward_batch.seq_lens,
forward_batch.seq_lens_sum,
prefix_lens=None,
prefill_wrapper_paged=self.prefill_wrapper_paged,
use_ragged=False,
spec_info=forward_batch.spec_info,
)
self.forward_metadata = PrefillMetadata(self.prefill_wrapper_paged, False)
elif forward_batch.forward_mode.is_target_verify():
self.indices_updater_prefill.update(
forward_batch.req_pool_indices,
forward_batch.seq_lens,
forward_batch.seq_lens_sum,
prefix_lens=None,
prefill_wrapper_paged=self.prefill_wrapper_verify,
use_ragged=False,
spec_info=forward_batch.spec_info,
)
self.forward_metadata = PrefillMetadata(self.prefill_wrapper_verify, False)
else:
prefix_lens = forward_batch.extend_prefix_lens
extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
......@@ -202,10 +246,33 @@ class FlashInferMLAAttnBackend(AttentionBackend):
seq_lens_sum,
decode_wrapper=decode_wrapper,
init_metadata_replay=False,
spec_info=spec_info,
)
self.decode_cuda_graph_metadata[bs] = decode_wrapper
self.forward_metadata = DecodeMetadata(decode_wrapper)
decode_wrapper.plan = partial(fast_mla_decode_plan, decode_wrapper)
elif forward_mode.is_target_verify():
verify_wrapper = BatchMLAPagedAttentionWrapper(
self.workspace_buffer,
use_cuda_graph=True,
qo_indptr=self.cuda_graph_qo_indptr[: bs + 1],
kv_indptr=self.cuda_graph_kv_indptr[: bs + 1],
kv_indices=self.cuda_graph_kv_indices,
kv_len_arr=self.cuda_graph_kv_lens[:bs],
backend="auto",
)
seq_lens_sum = seq_lens.sum().item()
self.indices_updater_prefill.update(
req_pool_indices,
seq_lens,
seq_lens_sum,
prefix_lens=None,
prefill_wrapper_paged=verify_wrapper,
use_ragged=False,
spec_info=spec_info,
)
self.prefill_cuda_graph_metadata[bs] = verify_wrapper
self.forward_metadata = PrefillMetadata(verify_wrapper, False)
else:
raise ValueError(f"Invalid mode: {forward_mode=}")
......@@ -221,6 +288,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
seq_lens_cpu: Optional[torch.Tensor],
):
if forward_mode.is_decode_or_idle():
assert seq_lens_cpu is not None
kv_len_arr_cpu = seq_lens_cpu[:bs]
self.cuda_graph_kv_indptr_cpu[1 : bs + 1] = torch.cumsum(
kv_len_arr_cpu, dim=0
......@@ -239,8 +307,19 @@ class FlashInferMLAAttnBackend(AttentionBackend):
seq_lens_sum,
decode_wrapper=self.decode_cuda_graph_metadata[bs],
init_metadata_replay=True,
spec_info=spec_info,
**self.fast_decode_kwargs,
)
elif forward_mode.is_target_verify():
self.indices_updater_prefill.update(
req_pool_indices[:bs],
seq_lens[:bs],
seq_lens_sum,
prefix_lens=None,
prefill_wrapper_paged=self.prefill_cuda_graph_metadata[bs],
use_ragged=False,
spec_info=spec_info,
)
else:
raise ValueError(f"Invalid forward mode: {forward_mode=}")
......@@ -254,7 +333,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
v: torch.Tensor,
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache=True,
save_kv_cache: bool = True,
):
cache_loc = forward_batch.out_cache_loc
......@@ -297,7 +376,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
v: torch.Tensor,
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache=True,
save_kv_cache: bool = True,
):
decode_wrapper = self.forward_metadata.decode_wrapper
cache_loc = forward_batch.out_cache_loc
......@@ -349,6 +428,7 @@ class FlashInferMLAIndicesUpdaterDecode:
seq_lens_sum: int,
decode_wrapper: BatchMLAPagedAttentionWrapper,
init_metadata_replay: bool = False,
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None,
**fast_decode_kwargs,
):
decode_wrapper = decode_wrapper or self.decode_wrapper
......@@ -360,6 +440,7 @@ class FlashInferMLAIndicesUpdaterDecode:
self.q_indptr,
self.kv_indptr,
init_metadata_replay,
spec_info,
**fast_decode_kwargs,
)
......@@ -372,30 +453,33 @@ class FlashInferMLAIndicesUpdaterDecode:
q_indptr: torch.Tensor,
kv_indptr: torch.Tensor,
init_metadata_replay: bool = False,
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None,
**fast_decode_kwargs,
):
bs = len(req_pool_indices)
q_indptr = q_indptr[: bs + 1]
kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
kv_indptr = kv_indptr[: bs + 1]
kv_indices = (
torch.empty(paged_kernel_lens_sum, dtype=torch.int32, device="cuda")
if not init_metadata_replay
else fast_decode_kwargs["kv_indices"]
)
kv_lens = paged_kernel_lens.to(torch.int32)
sm_scale = self.scaling
if spec_info is None:
kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
kv_indptr = kv_indptr[: bs + 1]
kv_indices = (
torch.empty(paged_kernel_lens_sum, dtype=torch.int32, device="cuda")
if not init_metadata_replay
else fast_decode_kwargs["kv_indices"]
)
create_flashinfer_kv_indices_triton[(bs,)](
self.req_to_token,
req_pool_indices,
paged_kernel_lens,
kv_indptr,
None,
kv_indices,
self.req_to_token.shape[1],
)
else:
kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
create_flashinfer_kv_indices_triton[(bs,)](
self.req_to_token,
req_pool_indices,
paged_kernel_lens,
kv_indptr,
None,
kv_indices,
self.req_to_token.shape[1],
)
if not init_metadata_replay:
wrapper.plan(
q_indptr,
......@@ -457,6 +541,7 @@ class FlashInferMLAIndicesUpdaterPrefill:
prefix_lens: torch.Tensor,
prefill_wrapper_paged: BatchMLAPagedAttentionWrapper,
use_ragged: bool,
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None,
):
if use_ragged:
paged_kernel_lens = prefix_lens
......@@ -476,6 +561,7 @@ class FlashInferMLAIndicesUpdaterPrefill:
self.kv_indptr,
self.qo_indptr,
use_ragged,
spec_info,
)
def call_begin_forward(
......@@ -490,29 +576,46 @@ class FlashInferMLAIndicesUpdaterPrefill:
kv_indptr: torch.Tensor,
qo_indptr: torch.Tensor,
use_ragged: bool,
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None,
):
bs = len(req_pool_indices)
kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
kv_indptr = kv_indptr[: bs + 1]
kv_indices = torch.empty(
paged_kernel_lens_sum,
dtype=torch.int32,
device=req_pool_indices.device,
)
create_flashinfer_kv_indices_triton[(bs,)](
self.req_to_token,
req_pool_indices,
paged_kernel_lens,
kv_indptr,
None,
kv_indices,
self.req_to_token.shape[1],
)
qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
qo_indptr = qo_indptr[: bs + 1]
bs = len(seq_lens)
sm_scale = self.scaling
if spec_info is None:
assert len(seq_lens) == len(req_pool_indices)
kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
kv_indptr = kv_indptr[: bs + 1]
kv_indices = torch.empty(
paged_kernel_lens_sum,
dtype=torch.int32,
device=req_pool_indices.device,
)
create_flashinfer_kv_indices_triton[(bs,)](
self.req_to_token,
req_pool_indices,
paged_kernel_lens,
kv_indptr,
None,
kv_indices,
self.req_to_token.shape[1],
)
qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
qo_indptr = qo_indptr[: bs + 1]
custom_mask = None
else:
assert isinstance(spec_info, EagleDraftInput) or isinstance(
spec_info, EagleVerifyInput
)
# TODO: Support topk > 1 with custom mask
kv_indices, kv_indptr, qo_indptr, custom_mask = (
spec_info.generate_attn_arg_prefill(
req_pool_indices,
paged_kernel_lens,
paged_kernel_lens_sum,
self.req_to_token,
)
)
if use_ragged:
# ragged prefill
wrapper_ragged.begin_forward(
......@@ -543,6 +646,163 @@ class FlashInferMLAIndicesUpdaterPrefill:
)
class FlashInferMLAMultiStepDraftBackend:
"""
Wrap multiple flashinfer mla attention backends as one for multiple consecutive
draft decoding steps.
"""
def __init__(
self,
model_runner: ModelRunner,
topk: int,
speculative_num_steps: int,
):
from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices
if topk > 1:
raise ValueError(
f"Currently Flashinfer MLA only supports topk=1 for speculative decoding"
)
self.topk = topk
self.speculative_num_steps = speculative_num_steps
self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices
max_bs = model_runner.req_to_token_pool.size * self.topk
self.kv_indptr = torch.zeros(
(
self.speculative_num_steps,
max_bs + 1,
),
dtype=torch.int32,
device=model_runner.device,
)
self.q_indptr_decode = torch.arange(
0, max_bs + 1, dtype=torch.int32, device=model_runner.device
)
self.attn_backends = []
for i in range(self.speculative_num_steps):
self.attn_backends.append(
FlashInferMLAAttnBackend(
model_runner,
skip_prefill=True,
kv_indptr_buf=self.kv_indptr[i],
q_indptr_decode_buf=self.q_indptr_decode,
)
)
self.max_context_len = self.attn_backends[0].max_context_len
# Cached variables for generate_draft_decode_kv_indices
self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
def common_template(
self,
forward_batch: ForwardBatch,
kv_indices_buffer: torch.Tensor,
call_fn: Callable,
):
num_seqs = forward_batch.batch_size
bs = self.topk * num_seqs
seq_lens_sum = forward_batch.seq_lens_sum
self.generate_draft_decode_kv_indices[
(self.speculative_num_steps, num_seqs, self.topk)
](
forward_batch.req_pool_indices,
forward_batch.req_to_token_pool.req_to_token,
forward_batch.seq_lens,
kv_indices_buffer,
self.kv_indptr,
forward_batch.positions,
num_seqs,
self.topk,
self.pool_len,
kv_indices_buffer.shape[1],
self.kv_indptr.shape[1],
triton.next_power_of_2(num_seqs),
triton.next_power_of_2(self.speculative_num_steps),
triton.next_power_of_2(bs),
)
assert forward_batch.spec_info is not None
assert isinstance(forward_batch.spec_info, EagleDraftInput)
for i in range(self.speculative_num_steps - 1):
forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
: seq_lens_sum * self.topk + bs * (i + 1)
]
call_fn(i, forward_batch)
def init_forward_metadata(self, forward_batch: ForwardBatch):
kv_indices = torch.zeros(
(
self.speculative_num_steps,
forward_batch.batch_size * self.topk * self.max_context_len,
),
dtype=torch.int32,
device="cuda",
)
def call_fn(i, forward_batch):
assert forward_batch.spec_info is not None
assert isinstance(forward_batch.spec_info, EagleDraftInput)
forward_batch.spec_info.kv_indptr = (
forward_batch.spec_info.kv_indptr.clone()
)
forward_batch.spec_info.kv_indices = (
forward_batch.spec_info.kv_indices.clone()
)
self.attn_backends[i].init_forward_metadata(forward_batch)
self.common_template(forward_batch, kv_indices, call_fn)
def init_cuda_graph_state(self, max_bs: int):
self.cuda_graph_kv_indices = torch.zeros(
(self.speculative_num_steps, max_bs * self.max_context_len),
dtype=torch.int32,
device="cuda",
)
for i in range(self.speculative_num_steps):
self.attn_backends[i].init_cuda_graph_state(
max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
)
def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
def call_fn(i, forward_batch):
self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
forward_batch.batch_size,
forward_batch.batch_size * self.topk,
forward_batch.req_pool_indices,
forward_batch.seq_lens,
encoder_lens=None,
forward_mode=ForwardMode.DECODE,
spec_info=forward_batch.spec_info,
)
self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
def init_forward_metadata_replay_cuda_graph(
self, forward_batch: ForwardBatch, bs: int
):
def call_fn(i, forward_batch):
self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
bs,
forward_batch.req_pool_indices,
forward_batch.seq_lens,
seq_lens_sum=-1,
encoder_lens=None,
forward_mode=ForwardMode.DECODE,
spec_info=forward_batch.spec_info,
seq_lens_cpu=forward_batch.decode_seq_lens_cpu,
)
self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
def fast_mla_decode_plan(
self,
qo_indptr_cpu: torch.Tensor,
......
......@@ -555,6 +555,8 @@ class DeepseekV2AttentionMLA(nn.Module):
return (
not global_server_args_dict["flashinfer_mla_disable_ragged"]
and forward_batch.forward_mode.is_extend()
and not forward_batch.forward_mode.is_target_verify()
and not forward_batch.forward_mode.is_draft_extend()
and forward_batch.extend_prefix_lens.sum() == 0
)
else:
......
......@@ -123,6 +123,16 @@ class EAGLEWorker(TpModelWorker):
self.topk,
self.speculative_num_steps,
)
elif self.server_args.attention_backend == "flashinfer_mla":
from sglang.srt.layers.attention.flashinfer_mla_backend import (
FlashInferMLAMultiStepDraftBackend,
)
self.draft_attn_backend = FlashInferMLAMultiStepDraftBackend(
self.model_runner,
self.topk,
self.speculative_num_steps,
)
else:
raise ValueError(
f"EAGLE is not supportted in attention backend {self.server_args.attention_backend}"
......
import unittest
from types import SimpleNamespace
import requests
import torch
from sglang.srt.utils import kill_process_tree
......@@ -100,5 +101,67 @@ class TestFlashinferMLANoRagged(unittest.TestCase):
self.assertGreater(metrics["accuracy"], 0.62)
class TestFlashinferMLAMTP(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.model = "lmsys/sglang-ci-dsv3-test"
cls.base_url = DEFAULT_URL_FOR_TEST
other_args = ["--trust-remote-code"]
if torch.cuda.is_available() and torch.version.cuda:
other_args.extend(
[
"--cuda-graph-max-bs",
"2",
"--disable-radix",
"--enable-torch-compile",
"--torch-compile-max-bs",
"1",
"--speculative-algorithm",
"EAGLE",
"--speculative-draft",
"lmsys/sglang-ci-dsv3-test-NextN",
"--speculative-num-steps",
"4",
"--speculative-eagle-topk",
"1",
"--speculative-num-draft-tokens",
"4",
"--enable-flashinfer-mla",
]
)
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=other_args,
)
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def test_gsm8k(self):
requests.get(self.base_url + "/flush_cache")
args = SimpleNamespace(
num_shots=5,
data_path=None,
num_questions=200,
max_new_tokens=512,
parallel=128,
host="http://127.0.0.1",
port=int(self.base_url.split(":")[-1]),
)
metrics = run_eval_few_shot_gsm8k(args)
print(metrics)
self.assertGreater(metrics["accuracy"], 0.60)
server_info = requests.get(self.base_url + "/get_server_info")
avg_spec_accept_length = server_info.json()["avg_spec_accept_length"]
print(f"{avg_spec_accept_length=}")
self.assertGreater(avg_spec_accept_length, 2.5)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment