Unverified Commit fa561067 authored by Baizhou Zhang, committed by GitHub

Add fast decode plan for flashinfer mla (#3987)

parent 7fbab730
......@@ -186,5 +186,5 @@ Please consult the documentation below to learn more about the parameters you ma
* `cuda_graph_bs`: The batch sizes to be captured by `CudaGraphRunner`. By default this is done for you.
* `torchao_config`: Experimental feature that optimizes the model with [torchao](https://github.com/pytorch/ao). Possible choices are: int8dq, int8wo, int4wo-<group_size>, fp8wo, fp8dq-per_tensor, fp8dq-per_row.
* `triton_attention_num_kv_splits`: Use this to adjust the number of KV splits in the Triton kernels. The default is 8.
* `enable_flashinfer_mla`: The backend for flashinfer MLA wrapper that accelerates deepseek models.
* `enable_flashinfer_mla`: Use the attention backend with the flashinfer MLA wrapper for DeepSeek models. When this argument is provided, the `attention_backend` argument is overridden.
* `flashinfer_mla_disable_ragged`: Disable the ragged prefill wrapper of the flashinfer MLA attention backend. Use this only when `enable_flashinfer_mla` is turned on.
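For example, a minimal launch sketch (hedged: the model path is illustrative, and `subprocess` is only one way to invoke the standard `sglang.launch_server` entry point; the two MLA flags are the ones documented above):

```python
import subprocess

# Hedged sketch: start the server with the flashinfer MLA wrapper enabled.
# Per the note above, --enable-flashinfer-mla overrides whatever
# --attention-backend would otherwise select.
cmd = [
    "python", "-m", "sglang.launch_server",
    "--model-path", "deepseek-ai/DeepSeek-V2-Lite",  # illustrative model
    "--enable-flashinfer-mla",
    # "--flashinfer-mla-disable-ragged",  # optionally skip the ragged prefill wrapper
]
subprocess.run(cmd, check=True)
```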
......@@ -83,7 +83,7 @@ Please refer to [the example](https://github.com/sgl-project/sglang/tree/main/be
- **Weight Absorption**: By applying the associative law of matrix multiplication to reorder computation steps, this method balances computation and memory access and improves efficiency in the decoding phase (a minimal sketch follows this list).
- **Flashinfer MLA Wrapper**: By providing the `--enable-flashinfer-mla` argument, the server will use MLA kernels customized by Flashinfer. More details can be found in [this document](https://docs.flashinfer.ai/api/mla.html). (In Experiment Stage)
- **Flashinfer MLA Wrapper**: By providing the `--enable-flashinfer-mla` argument, the server will use MLA kernels customized by Flashinfer. More details can be found in [this document](https://docs.flashinfer.ai/api/mla.html). For long-input scenarios, flashinfer MLA can improve performance significantly. Optimized Triton kernels are used when flashinfer MLA is turned off.
- **FP8 Quantization**: W8A8 FP8 and KV Cache FP8 quantization enables efficient FP8 inference. Additionally, we have implemented Batched Matrix Multiplication (BMM) operator to facilitate FP8 inference in MLA with weight absorption.
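A minimal sketch of the weight-absorption reordering mentioned in the first bullet above (toy dimensions and random weights, not DeepSeek's real shapes): because matrix multiplication is associative, the key up-projection can be folded into the query side instead of being materialized for every cached token.

```python
import torch

torch.manual_seed(0)
d_q, d_c, d_head = 32, 16, 64       # hypothetical latent / head dimensions
q = torch.randn(d_q)                # query-side input for one head
c = torch.randn(d_c)                # compressed (latent) KV cache entry
W_uq = torch.randn(d_head, d_q)     # query up-projection
W_uk = torch.randn(d_head, d_c)     # key up-projection

naive = (W_uq @ q) @ (W_uk @ c)     # materializes a head-dim key per token
absorbed = q @ (W_uq.T @ W_uk) @ c  # same score, reordered matmuls
assert torch.allclose(naive, absorbed, rtol=1e-4, atol=1e-4)
```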
......
......@@ -29,9 +29,8 @@ class AttentionBackend(ABC):
num_tokens: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[SpecInfo],
**kwargs,
):
"""Init the metadata for a forward pass for capturing a cuda graph."""
raise NotImplementedError()
......@@ -42,9 +41,8 @@ class AttentionBackend(ABC):
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_sum: int,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[SpecInfo],
**kwargs,
):
"""Init the metadata for a forward pass for replying a cuda graph."""
raise NotImplementedError()
......
......@@ -269,9 +269,10 @@ class FlashInferAttnBackend(AttentionBackend):
num_tokens: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
encoder_lens: Optional[torch.Tensor],
spec_info: Optional[SpecInfo],
**kwargs,
):
if forward_mode.is_decode_or_idle():
decode_wrappers = []
......@@ -339,9 +340,10 @@ class FlashInferAttnBackend(AttentionBackend):
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_sum: int,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
encoder_lens: Optional[torch.Tensor],
spec_info: Optional[SpecInfo],
**kwargs,
):
if forward_mode.is_decode_or_idle():
self.indices_updater_decode.update(
......
......@@ -10,6 +10,7 @@ More details can be found in https://docs.flashinfer.ai/api/mla.html
"""
from dataclasses import dataclass
from functools import partial
from typing import TYPE_CHECKING, Optional, Union
import torch
......@@ -27,14 +28,12 @@ from sglang.srt.utils import is_flashinfer_available
if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.speculative.spec_info import SpecInfo
if is_flashinfer_available():
from flashinfer import (
BatchMLAPagedAttentionWrapper,
BatchPrefillWithRaggedKVCacheWrapper,
)
from flashinfer.cascade import merge_state
@dataclass
......@@ -63,6 +62,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
# Parse constants
self.max_context_len = model_runner.model_config.context_len
self.device = model_runner.device
global_config.enable_flashinfer_mla = True
......@@ -85,10 +85,6 @@ class FlashInferMLAAttnBackend(AttentionBackend):
(max_bs + 1,), dtype=torch.int32, device=model_runner.device
)
self.kv_last_page_len = torch.ones(
(max_bs,), dtype=torch.int32, device=model_runner.device
)
self.q_indptr_decode = torch.arange(
0, max_bs + 1, dtype=torch.int32, device=model_runner.device
)
......@@ -126,6 +122,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
forward_batch.seq_lens,
forward_batch.seq_lens_sum,
decode_wrapper=self.decode_wrapper,
init_metadata_replay=False,
)
self.forward_metadata = DecodeMetadata(self.decode_wrapper)
else:
......@@ -161,13 +158,20 @@ class FlashInferMLAAttnBackend(AttentionBackend):
cuda_graph_kv_indices = kv_indices_buf
self.cuda_graph_kv_indices = cuda_graph_kv_indices
self.cuda_graph_custom_mask = torch.zeros(
(max_bs * self.max_context_len),
dtype=torch.uint8,
device="cuda",
self.cuda_graph_qo_indptr = self.q_indptr_decode.clone()
self.cuda_graph_kv_indptr = self.kv_indptr.clone()
self.cuda_graph_kv_lens = torch.ones(
(max_bs,), dtype=torch.int32, device=self.device
)
self.cuda_graph_qk_indptr = self.kv_indptr.clone()
self.cuda_graph_qo_indptr = self.kv_indptr.clone()
# For the fast decode plan used during cuda graph replay
self.cuda_graph_qo_indptr_cpu = self.cuda_graph_qo_indptr.to("cpu")
self.cuda_graph_kv_indptr_cpu = self.cuda_graph_kv_indptr.to("cpu")
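# Host-side copies of the indptr buffers, created once here and reused on every
# graph replay so the fast decode plan can run without copying indptr data back
# from the GPU.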
self.fast_decode_kwargs = {
"qo_indptr_cpu": self.cuda_graph_qo_indptr_cpu,
"kv_indptr_cpu": self.cuda_graph_kv_indptr_cpu,
"kv_indices": self.cuda_graph_kv_indices,
}
def init_forward_metadata_capture_cuda_graph(
self,
......@@ -175,18 +179,17 @@ class FlashInferMLAAttnBackend(AttentionBackend):
num_tokens: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[SpecInfo],
**kwargs,
):
if forward_mode.is_decode_or_idle():
decode_wrapper = BatchMLAPagedAttentionWrapper(
self.workspace_buffer,
use_cuda_graph=True,
qo_indptr=self.qo_indptr[: num_tokens + 1],
kv_indptr=self.kv_indptr[: num_tokens + 1],
qo_indptr=self.cuda_graph_qo_indptr[: num_tokens + 1],
kv_indptr=self.cuda_graph_kv_indptr[: num_tokens + 1],
kv_indices=self.cuda_graph_kv_indices,
kv_len_arr=self.kv_last_page_len[:num_tokens],
kv_len_arr=self.cuda_graph_kv_lens[:num_tokens],
backend="auto",
)
......@@ -196,9 +199,11 @@ class FlashInferMLAAttnBackend(AttentionBackend):
seq_lens,
seq_lens_sum,
decode_wrapper=decode_wrapper,
init_metadata_replay=False,
)
self.decode_cuda_graph_metadata[bs] = decode_wrapper
self.forward_metadata = DecodeMetadata(decode_wrapper)
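# Rebind plan() so that later calls on this wrapper go through
# fast_mla_decode_plan (with the wrapper bound as `self`), skipping the stream
# synchronization of the original plan during graph replay.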
decode_wrapper.plan = partial(fast_mla_decode_plan, decode_wrapper)
else:
raise ValueError(f"Invalid mode: {forward_mode=}")
......@@ -208,16 +213,30 @@ class FlashInferMLAAttnBackend(AttentionBackend):
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_sum: int,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[SpecInfo],
seq_lens_cpu: torch.Tensor,
**kwargs,
):
if forward_mode.is_decode_or_idle():
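# Build the kv indptr on the CPU copy of the sequence lengths so the fast
# decode plan below never has to synchronize with the device during replay.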
kv_len_arr_cpu = seq_lens_cpu[:bs]
self.cuda_graph_kv_indptr_cpu[1 : bs + 1] = torch.cumsum(
kv_len_arr_cpu, dim=0
)
self.fast_decode_kwargs.update(
{
"qo_indptr_cpu": self.cuda_graph_qo_indptr_cpu[: bs + 1],
"kv_indptr_cpu": self.cuda_graph_kv_indptr_cpu[: bs + 1],
"kv_len_arr_cpu": kv_len_arr_cpu,
}
)
self.indices_updater_decode.update(
req_pool_indices[:bs],
seq_lens[:bs],
seq_lens_sum,
decode_wrapper=self.decode_cuda_graph_metadata[bs],
init_metadata_replay=True,
**self.fast_decode_kwargs,
)
else:
raise ValueError(f"Invalid forward mode: {forward_mode=}")
......@@ -317,7 +336,6 @@ class FlashInferMLAIndicesUpdaterDecode:
# Buffers and wrappers
self.kv_indptr = attn_backend.kv_indptr
self.kv_last_page_len = attn_backend.kv_last_page_len
self.req_to_token = model_runner.req_to_token_pool.req_to_token
self.q_indptr = attn_backend.q_indptr_decode
......@@ -327,6 +345,8 @@ class FlashInferMLAIndicesUpdaterDecode:
seq_lens: torch.Tensor,
seq_lens_sum: int,
decode_wrapper: BatchMLAPagedAttentionWrapper,
init_metadata_replay: bool = False,
**fast_decode_kwargs,
):
decode_wrapper = decode_wrapper or self.decode_wrapper
self.call_begin_forward(
......@@ -336,6 +356,8 @@ class FlashInferMLAIndicesUpdaterDecode:
seq_lens_sum,
self.q_indptr,
self.kv_indptr,
init_metadata_replay,
**fast_decode_kwargs,
)
def call_begin_forward(
......@@ -346,14 +368,19 @@ class FlashInferMLAIndicesUpdaterDecode:
paged_kernel_lens_sum: int,
q_indptr: torch.Tensor,
kv_indptr: torch.Tensor,
init_metadata_replay: bool = False,
**fast_decode_kwargs,
):
bs = len(req_pool_indices)
q_indptr = q_indptr[: bs + 1]
kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
kv_indptr = kv_indptr[: bs + 1]
kv_indices = torch.empty(
paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
kv_indices = (
torch.empty(paged_kernel_lens_sum, dtype=torch.int32, device="cuda")
if not init_metadata_replay
else fast_decode_kwargs["kv_indices"]
)
kv_lens = paged_kernel_lens.to(torch.int32)
sm_scale = self.scaling
......@@ -366,21 +393,36 @@ class FlashInferMLAIndicesUpdaterDecode:
kv_indices,
self.req_to_token.shape[1],
)
wrapper.plan(
q_indptr,
kv_indptr,
kv_indices,
kv_lens,
self.num_local_heads,
self.kv_lora_rank,
self.qk_rope_head_dim,
1,
False,
sm_scale,
self.data_type,
self.data_type,
)
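# Outside of graph replay, call the original plan(); during replay, reuse the
# host-side buffers passed in via fast_decode_kwargs.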
if not init_metadata_replay:
wrapper.plan(
q_indptr,
kv_indptr,
kv_indices,
kv_lens,
self.num_local_heads,
self.kv_lora_rank,
self.qk_rope_head_dim,
1,
False,
sm_scale,
self.data_type,
self.data_type,
)
else:
wrapper.plan(
fast_decode_kwargs["qo_indptr_cpu"],
fast_decode_kwargs["kv_indptr_cpu"],
kv_indices,
fast_decode_kwargs["kv_len_arr_cpu"],
self.num_local_heads,
self.kv_lora_rank,
self.qk_rope_head_dim,
1,
False,
sm_scale,
self.data_type,
self.data_type,
)
class FlashInferMLAIndicesUpdaterPrefill:
......@@ -400,7 +442,6 @@ class FlashInferMLAIndicesUpdaterPrefill:
# Buffers and wrappers
self.kv_indptr = attn_backend.kv_indptr
self.kv_last_page_len = attn_backend.kv_last_page_len
self.qo_indptr = attn_backend.qo_indptr
self.req_to_token = model_runner.req_to_token_pool.req_to_token
self.prefill_wrapper_ragged = attn_backend.prefill_wrapper_ragged
......@@ -497,3 +538,42 @@ class FlashInferMLAIndicesUpdaterPrefill:
self.q_data_type,
self.data_type,
)
def fast_mla_decode_plan(
self,
qo_indptr_cpu: torch.Tensor,
kv_indptr_cpu: torch.Tensor,
kv_indices: torch.Tensor,
kv_len_arr_cpu: torch.Tensor,
num_heads: int,
head_dim_ckv: int,
head_dim_kpe: int,
page_size: int,
causal: bool,
sm_scale: float,
q_data_type: torch.dtype,
kv_data_type: torch.dtype,
) -> None:
"""A faster version of BatchMLAPagedAttentionWrapper::plan,
for skipping the stream synchronization in original plan function during
cuda graph replaying.
"""
self._causal = causal
self._page_size = page_size
self._sm_scale = sm_scale
with self.device as device:
stream = torch.cuda.current_stream(device).cuda_stream
self._cached_module.plan(
self._float_workspace_buffer,
self._int_workspace_buffer,
self._pin_memory_int_workspace_buffer,
qo_indptr_cpu,
kv_indptr_cpu,
kv_len_arr_cpu,
num_heads,
head_dim_ckv,
causal,
stream,
)
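As a side note on the rebinding used in the capture path above: `functools.partial` pins the wrapper instance as the first argument, so a later `decode_wrapper.plan(...)` call dispatches to `fast_mla_decode_plan` with the wrapper bound as `self`. A toy, self-contained sketch of that pattern (class and function names here are illustrative, not part of the diff):

```python
from functools import partial

class Wrapper:
    def plan(self, x):
        return ("slow", x)   # stands in for the original plan()

def fast_plan(self, x):
    return ("fast", x)       # stands in for fast_mla_decode_plan

w = Wrapper()
w.plan = partial(fast_plan, w)   # same trick as decode_wrapper.plan above
assert w.plan(3) == ("fast", 3)  # the instance attribute shadows the class method
```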
......@@ -230,9 +230,10 @@ class TritonAttnBackend(AttentionBackend):
num_tokens: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
encoder_lens: Optional[torch.Tensor],
spec_info: Optional[SpecInfo],
**kwargs,
):
assert encoder_lens is None, "Not supported"
......@@ -308,9 +309,10 @@ class TritonAttnBackend(AttentionBackend):
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_sum: int,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
encoder_lens: Optional[torch.Tensor],
spec_info: Optional[SpecInfo],
**kwargs,
):
# NOTE: encoder_lens expected to be zeros or None
if forward_mode.is_decode_or_idle():
......
......@@ -582,6 +582,9 @@ class ScheduleBatch:
return_logprob: bool = False
top_logprobs_nums: Optional[List[int]] = None
# For decode
decode_seq_lens: List[int] = None
# For extend and mixed chunked prefill
prefix_lens: List[int] = None
extend_lens: List[int] = None
......@@ -1168,8 +1171,10 @@ class ScheduleBatch:
def get_model_worker_batch(self):
if self.forward_mode.is_decode_or_idle():
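# Keep a CPU copy of the decode sequence lengths; attention backends that plan
# metadata on the host (e.g. the flashinfer MLA fast decode plan) read it during
# cuda graph replay.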
decode_seq_lens = self.seq_lens.cpu()
extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
else:
decode_seq_lens = None
extend_seq_lens = self.extend_lens
extend_prefix_lens = self.prefix_lens
extend_logprob_start_lens = self.extend_logprob_start_lens
......@@ -1194,6 +1199,7 @@ class ScheduleBatch:
top_logprobs_nums=self.top_logprobs_nums,
global_num_tokens=self.global_num_tokens,
can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
decode_seq_lens=decode_seq_lens,
extend_num_tokens=self.extend_num_tokens,
extend_seq_lens=extend_seq_lens,
extend_prefix_lens=extend_prefix_lens,
......@@ -1267,6 +1273,9 @@ class ModelWorkerBatch:
global_num_tokens: Optional[List[int]]
can_run_dp_cuda_graph: bool
# For decode
decode_seq_lens: Optional[torch.Tensor]
# For extend
extend_num_tokens: Optional[int]
extend_seq_lens: Optional[List[int]]
......
......@@ -199,6 +199,10 @@ class CudaGraphRunner:
if self.enable_torch_compile:
set_torch_compile_config()
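# Host-side mirror of seq_lens, refreshed before each replay and forwarded to
# the attention backend as seq_lens_cpu.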
self.seq_lens_cpu = torch.full(
(self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
)
# Graph inputs
with torch.device("cuda"):
self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64)
......@@ -369,9 +373,9 @@ class CudaGraphRunner:
num_tokens,
req_pool_indices,
seq_lens,
encoder_lens,
forward_batch.forward_mode,
forward_batch.spec_info,
encoder_lens=encoder_lens,
spec_info=forward_batch.spec_info,
)
# Run and capture
......@@ -434,6 +438,7 @@ class CudaGraphRunner:
if bs != raw_bs:
self.seq_lens.fill_(1)
self.out_cache_loc.zero_()
self.seq_lens_cpu.fill_(1)
# Common inputs
self.input_ids[:raw_num_token].copy_(forward_batch.input_ids)
......@@ -441,6 +446,8 @@ class CudaGraphRunner:
self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc)
self.positions[:raw_num_token].copy_(forward_batch.positions)
if forward_batch.decode_seq_lens_cpu is not None:
self.seq_lens_cpu[:raw_bs].copy_(forward_batch.decode_seq_lens_cpu)
if self.is_encoder_decoder:
self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
......@@ -456,9 +463,10 @@ class CudaGraphRunner:
self.req_pool_indices,
self.seq_lens,
forward_batch.seq_lens_sum + (bs - raw_bs),
self.encoder_lens,
forward_batch.forward_mode,
forward_batch.spec_info,
encoder_lens=self.encoder_lens,
spec_info=forward_batch.spec_info,
seq_lens_cpu=self.seq_lens_cpu,
)
# Replay
......
......@@ -152,6 +152,9 @@ class ForwardBatch:
# Position information
positions: torch.Tensor = None
# For decode
decode_seq_lens_cpu: Optional[torch.Tensor] = None
# For extend
extend_num_tokens: Optional[int] = None
extend_seq_lens: Optional[torch.Tensor] = None
......@@ -256,6 +259,8 @@ class ForwardBatch:
if ret.forward_mode.is_decode():
if ret.positions is None:
ret.positions = clamp_position(batch.seq_lens)
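# Carry over the CPU copy of the decode sequence lengths prepared by the scheduler.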
if ret.decode_seq_lens_cpu is None:
ret.decode_seq_lens_cpu = batch.decode_seq_lens
else:
ret.extend_seq_lens = torch.tensor(
batch.extend_seq_lens, dtype=torch.int32
......