Unverified Commit 90a4b7d9 authored by Baizhou Zhang, committed by GitHub

[Feature] Support ragged prefill in flashinfer mla backend (#3967)


Co-authored-by: Yineng Zhang <me@zhyncs.com>
Co-authored-by: pankajroark <pankajroark@users.noreply.github.com>
parent f3b99f73
......@@ -133,7 +133,6 @@ Please consult the documentation below to learn more about the parameters you ma
* `attention_backend`: The backend for attention computation and KV cache management.
* `sampling_backend`: The backend for sampling.
* `enable_flashinfer_mla`: Use the FlashInfer MLA attention backend, which accelerates DeepSeek models. (Experimental)
## Constrained Decoding
......@@ -186,3 +185,5 @@ Please consult the documentation below to learn more about the parameters you ma
* `cuda_graph_bs`: The batch sizes to capture by `CudaGraphRunner`. By default this is done for you.
* `torchao_config`: Experimental feature that optimizes the model with [torchao](https://github.com/pytorch/ao). Possible choices are: int8dq, int8wo, int4wo-<group_size>, fp8wo, fp8dq-per_tensor, fp8dq-per_row.
* `triton_attention_num_kv_splits`: Use to adjust the number of KV splits in triton kernels. Default is 8.
* `enable_flashinfer_mla`: Use the FlashInfer MLA attention backend, which accelerates DeepSeek models.
* `flashinfer_mla_disable_ragged`: Disable the ragged prefill wrapper of the FlashInfer MLA attention backend. Only takes effect when `enable_flashinfer_mla` is turned on; see the launch sketch below.
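The following is a minimal launch sketch showing the two flags used together. It borrows `popen_launch_server` and the CI checkpoint name from the test added later in this change; treat the checkpoint as a placeholder, not a recommendation.

```python
# Hedged sketch: launch a server with the FlashInfer MLA backend and, optionally,
# ragged prefill disabled. Helper names come from sglang.test.test_utils.
from sglang.test.test_utils import (
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    popen_launch_server,
)

process = popen_launch_server(
    "sgl-project/sglang-ci-dsv3-test",      # placeholder DeepSeek-style checkpoint
    DEFAULT_URL_FOR_TEST,
    timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    other_args=[
        "--trust-remote-code",
        "--enable-flashinfer-mla",           # turn the FlashInfer MLA backend on
        "--flashinfer-mla-disable-ragged",   # optional: skip the ragged prefill wrapper
    ],
)
```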
......@@ -2,13 +2,13 @@ from __future__ import annotations
"""
Support attention backend for flashinfer MLA.
When radix cache is enabled, the backend only uses BatchMLAPaged wrapper when forwarding.
When radix cache is disabled, the backend uses BatchPrefill wrappers for prefilling (with or without prefix cache),
The flashinfer_mla_disable_ragged flag controls whether to use ragged prefill wrapper and defaults to be false.
When it's set to false, all wrappers are BatchMLAPaged wrapper.
When it's set to true, the backend uses BatchRagged and BatchMLAPaged wrapper for prefilling,
and uses BatchMLAPaged wrapper for decoding.
More details can be found in https://docs.flashinfer.ai/api/mla.html
"""
import math
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional, Union
......@@ -18,7 +18,6 @@ from sglang.global_config import global_config
from sglang.srt.layers.attention import AttentionBackend
from sglang.srt.layers.attention.flashinfer_backend import (
create_flashinfer_kv_indices_triton,
should_use_tensor_core,
)
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.managers.schedule_batch import global_server_args_dict
......@@ -32,11 +31,10 @@ if TYPE_CHECKING:
if is_flashinfer_available():
from flashinfer import (
BatchPrefillWithPagedKVCacheWrapper,
BatchMLAPagedAttentionWrapper,
BatchPrefillWithRaggedKVCacheWrapper,
)
from flashinfer.cascade import merge_state
from flashinfer.mla import BatchMLAPagedAttentionWrapper
@dataclass
......@@ -46,9 +44,7 @@ class DecodeMetadata:
@dataclass
class PrefillMetadata:
prefill_wrapper: Union[
BatchPrefillWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper
]
prefill_wrapper: BatchMLAPagedAttentionWrapper
use_ragged: bool
......@@ -62,7 +58,6 @@ class FlashInferMLAAttnBackend(AttentionBackend):
def __init__(
self,
model_runner: ModelRunner,
kv_indptr_buf: Optional[torch.Tensor] = None,
):
super().__init__()
......@@ -82,12 +77,9 @@ class FlashInferMLAAttnBackend(AttentionBackend):
self.workspace_buffer = global_workspace_buffer
max_bs = model_runner.req_to_token_pool.size
if kv_indptr_buf is None:
self.kv_indptr = torch.zeros(
(max_bs + 1,), dtype=torch.int32, device=model_runner.device
)
else:
self.kv_indptr = kv_indptr_buf
self.kv_indptr = torch.zeros(
(max_bs + 1,), dtype=torch.int32, device=model_runner.device
)
self.qo_indptr = torch.zeros(
(max_bs + 1,), dtype=torch.int32, device=model_runner.device
......@@ -97,22 +89,19 @@ class FlashInferMLAAttnBackend(AttentionBackend):
(max_bs,), dtype=torch.int32, device=model_runner.device
)
self.q_indptr_decode = torch.arange(
0, max_bs + 1, dtype=torch.int32, device=model_runner.device
)
self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
self.workspace_buffer, "NHD"
)
if not global_server_args_dict["disable_radix_cache"]:
# use mla paged prefill
self.prefill_wrapper_paged = BatchMLAPagedAttentionWrapper(
self.workspace_buffer,
backend="auto",
)
else:
self.prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
self.workspace_buffer,
"NHD",
backend="auto",
)
self.prefill_wrapper_paged = BatchMLAPagedAttentionWrapper(
self.workspace_buffer,
backend="auto",
)
self.decode_wrapper = BatchMLAPagedAttentionWrapper(
self.workspace_buffer, backend="auto"
)
......@@ -141,7 +130,11 @@ class FlashInferMLAAttnBackend(AttentionBackend):
self.forward_metadata = DecodeMetadata(self.decode_wrapper)
else:
prefix_lens = forward_batch.extend_prefix_lens
use_ragged = global_server_args_dict["disable_radix_cache"]
extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
use_ragged = (
not global_server_args_dict["flashinfer_mla_disable_ragged"]
and extend_no_prefix
)
self.indices_updater_prefill.update(
forward_batch.req_pool_indices,
......@@ -241,45 +234,37 @@ class FlashInferMLAAttnBackend(AttentionBackend):
forward_batch: ForwardBatch,
save_kv_cache=True,
):
cache_loc = forward_batch.out_cache_loc
logits_soft_cap = layer.logit_cap
prefill_wrapper_paged = self.forward_metadata.prefill_wrapper
qall = q.view(-1, layer.tp_q_head_num, layer.head_dim)
k_buf = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
if not global_server_args_dict["disable_radix_cache"]:
# use mla paged prefill
prefill_wrapper_paged = self.forward_metadata.prefill_wrapper
if k is not None:
assert v is not None
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
qall = q.view(-1, layer.tp_q_head_num, layer.head_dim)
k_buf = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
# Save kv cache
if save_kv_cache and k is not None:
assert v is not None
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
o = prefill_wrapper_paged.run(
qall[:, :, : layer.v_head_dim],
qall[:, :, layer.v_head_dim :],
k_buf[:, :, : layer.v_head_dim],
k_buf[:, :, layer.v_head_dim :],
)
else:
# use mla ragged prefill
if self.forward_metadata.use_ragged:
# ragged prefill
o, _ = self.prefill_wrapper_ragged.forward_return_lse(
q.view(-1, layer.tp_q_head_num, layer.head_dim),
qall,
k.view(-1, layer.tp_k_head_num, layer.head_dim),
v.view(-1, layer.tp_v_head_num, layer.v_head_dim),
v.view(-1, layer.tp_k_head_num, layer.v_head_dim),
causal=True,
sm_scale=layer.scaling,
logits_soft_cap=logits_soft_cap,
)
# FIXME: Here should be another prefill_paged to call
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer,
cache_loc,
k,
v,
)
else:
# mla paged prefill
o = prefill_wrapper_paged.run(
qall[:, :, : layer.v_head_dim],
qall[:, :, layer.v_head_dim :],
k_buf[:, :, : layer.v_head_dim],
k_buf[:, :, layer.v_head_dim :],
)
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
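For reference, a small sketch of the head-dimension split that feeds `prefill_wrapper_paged.run` above. The sizes are assumptions based on DeepSeek-V2/V3-style MLA (kv_lora_rank=512, qk_rope_head_dim=64, with v_head_dim equal to kv_lora_rank in the absorbed path); they are not taken from this diff.

```python
import torch

# Assumed MLA sizes, for illustration only.
num_tokens, tp_q_head_num = 8, 16
kv_lora_rank, qk_rope_head_dim = 512, 64
head_dim = kv_lora_rank + qk_rope_head_dim   # per-head query dim in the absorbed path
v_head_dim = kv_lora_rank                    # per-head output dim in the absorbed path

qall = torch.randn(num_tokens, tp_q_head_num, head_dim)
q_nope = qall[:, :, :v_head_dim]             # compressed ("nope") component
q_rope = qall[:, :, v_head_dim:]             # rotary component
# The paged MLA wrapper consumes (q_nope, q_rope, ckv_cache, kpe_cache) split the
# same way and returns an output of shape (num_tokens, tp_q_head_num, v_head_dim).
```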
......@@ -334,6 +319,7 @@ class FlashInferMLAIndicesUpdaterDecode:
self.kv_indptr = attn_backend.kv_indptr
self.kv_last_page_len = attn_backend.kv_last_page_len
self.req_to_token = model_runner.req_to_token_pool.req_to_token
self.q_indptr = attn_backend.q_indptr_decode
def update(
self,
......@@ -342,12 +328,13 @@ class FlashInferMLAIndicesUpdaterDecode:
seq_lens_sum: int,
decode_wrapper: BatchMLAPagedAttentionWrapper,
):
decode_wrappers = decode_wrapper or self.decode_wrapper
decode_wrapper = decode_wrapper or self.decode_wrapper
self.call_begin_forward(
decode_wrapper,
req_pool_indices,
seq_lens,
seq_lens_sum,
self.q_indptr,
self.kv_indptr,
)
......@@ -357,14 +344,19 @@ class FlashInferMLAIndicesUpdaterDecode:
req_pool_indices: torch.Tensor,
paged_kernel_lens: torch.Tensor,
paged_kernel_lens_sum: int,
q_indptr: torch.Tensor,
kv_indptr: torch.Tensor,
):
bs = len(req_pool_indices)
q_indptr = q_indptr[: bs + 1]
kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
kv_indptr = kv_indptr[: bs + 1]
kv_indices = torch.empty(
paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
)
kv_lens = paged_kernel_lens.to(torch.int32)
sm_scale = self.scaling
create_flashinfer_kv_indices_triton[(bs,)](
self.req_to_token,
req_pool_indices,
......@@ -375,9 +367,6 @@ class FlashInferMLAIndicesUpdaterDecode:
self.req_to_token.shape[1],
)
sm_scale = self.scaling
q_indptr = torch.arange(0, bs + 1).to(0).int()
kv_lens = paged_kernel_lens.to(torch.int32)
wrapper.plan(
q_indptr,
kv_indptr,
......@@ -397,12 +386,9 @@ class FlashInferMLAIndicesUpdaterDecode:
class FlashInferMLAIndicesUpdaterPrefill:
def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
# Parse Constants
self.num_qo_heads = (
self.num_local_heads = (
model_runner.model_config.num_attention_heads // get_attention_tp_size()
)
self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
get_attention_tp_size()
)
self.kv_lora_rank = model_runner.model_config.kv_lora_rank
self.qk_nope_head_dim = model_runner.model_config.qk_nope_head_dim
self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
......@@ -425,9 +411,7 @@ class FlashInferMLAIndicesUpdaterPrefill:
seq_lens: torch.Tensor,
seq_lens_sum: int,
prefix_lens: torch.Tensor,
prefill_wrapper_paged: Union[
BatchPrefillWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper
],
prefill_wrapper_paged: BatchMLAPagedAttentionWrapper,
use_ragged: bool,
):
if use_ragged:
......@@ -453,9 +437,7 @@ class FlashInferMLAIndicesUpdaterPrefill:
def call_begin_forward(
self,
wrapper_ragged: BatchPrefillWithRaggedKVCacheWrapper,
wrapper_paged: Union[
BatchPrefillWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper
],
wrapper_paged: BatchMLAPagedAttentionWrapper,
req_pool_indices: torch.Tensor,
paged_kernel_lens: torch.Tensor,
paged_kernel_lens_sum: int,
......@@ -466,7 +448,6 @@ class FlashInferMLAIndicesUpdaterPrefill:
use_ragged: bool,
):
bs = len(req_pool_indices)
# Normal extend
kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
kv_indptr = kv_indptr[: bs + 1]
kv_indices = torch.empty(
......@@ -488,19 +469,18 @@ class FlashInferMLAIndicesUpdaterPrefill:
qo_indptr = qo_indptr[: bs + 1]
sm_scale = self.scaling
# extend part
if use_ragged:
# ragged prefill
wrapper_ragged.begin_forward(
qo_indptr=qo_indptr,
kv_indptr=qo_indptr,
num_qo_heads=self.num_qo_heads,
num_kv_heads=self.num_kv_heads,
num_qo_heads=self.num_local_heads,
num_kv_heads=self.num_local_heads,
head_dim_qk=self.qk_nope_head_dim + self.qk_rope_head_dim,
head_dim_vo=self.v_head_dim,
q_data_type=self.q_data_type,
)
if not global_server_args_dict["disable_radix_cache"]:
else:
# mla paged prefill
kv_len_arr = kv_indptr[1:] - kv_indptr[:-1]
wrapper_paged.plan(
......@@ -508,7 +488,7 @@ class FlashInferMLAIndicesUpdaterPrefill:
kv_indptr,
kv_indices,
kv_len_arr,
self.num_qo_heads,
self.num_local_heads,
self.kv_lora_rank,
self.qk_rope_head_dim,
1,
......@@ -517,5 +497,3 @@ class FlashInferMLAIndicesUpdaterPrefill:
self.q_data_type,
self.data_type,
)
# FIXME: Here should be some logic for prefill paged when not using radix cache?
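To make the planner inputs above concrete, here is a toy illustration of the `qo_indptr`/`kv_indptr` layout built from per-request lengths (values invented):

```python
import torch

extend_lens = torch.tensor([4, 1, 6], dtype=torch.int32)  # new tokens per request
kv_lens = torch.tensor([9, 1, 6], dtype=torch.int32)      # total KV tokens per request

qo_indptr = torch.zeros(len(extend_lens) + 1, dtype=torch.int32)
qo_indptr[1:] = torch.cumsum(extend_lens, dim=0)          # -> [0, 4, 5, 11]

kv_indptr = torch.zeros(len(kv_lens) + 1, dtype=torch.int32)
kv_indptr[1:] = torch.cumsum(kv_lens, dim=0)              # -> [0, 9, 10, 16]

kv_len_arr = kv_indptr[1:] - kv_indptr[:-1]               # back to per-request KV lengths
```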
......@@ -67,6 +67,7 @@ global_server_args_dict = {
"device": ServerArgs.device,
"enable_flashinfer_mla": ServerArgs.enable_flashinfer_mla,
"disable_radix_cache": ServerArgs.disable_radix_cache,
"flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
}
logger = logging.getLogger(__name__)
......
......@@ -182,6 +182,7 @@ class ModelRunner:
"device": server_args.device,
"enable_flashinfer_mla": server_args.enable_flashinfer_mla,
"disable_radix_cache": server_args.disable_radix_cache,
"flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
}
)
......
......@@ -520,10 +520,11 @@ class DeepseekV2AttentionMLA(nn.Module):
def no_absorb() -> bool:
if global_server_args_dict["enable_flashinfer_mla"]:
# Flashinfer MLA: Only do not use absorb when prefilling/extending without radix cache
# Flashinfer MLA: Do not absorb when enabling ragged prefill
return (
global_server_args_dict["disable_radix_cache"]
not global_server_args_dict["flashinfer_mla_disable_ragged"]
and forward_batch.forward_mode.is_extend()
and forward_batch.extend_prefix_lens.sum() == 0
)
else:
# Triton: Use normal computation for prefill and use weight absorption for extend/decode
......
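A standalone sketch of this predicate and what each outcome means for the attention path; the function name and comments are mine, but the condition matches the diff above:

```python
def no_absorb_sketch(disable_ragged: bool, is_extend: bool, prefix_len_sum: int) -> bool:
    # True  -> run plain multi-head attention over the full head dims (ragged prefill).
    # False -> use the weight-absorbed MLA path over the compressed KV (paged wrappers).
    return (not disable_ragged) and is_extend and prefix_len_sum == 0

assert no_absorb_sketch(False, True, 0) is True    # ragged prefill enabled, no cached prefix
assert no_absorb_sketch(True, True, 0) is False    # --flashinfer-mla-disable-ragged set
```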
......@@ -167,6 +167,7 @@ class ServerArgs:
tool_call_parser: str = None
enable_hierarchical_cache: bool = False
enable_flashinfer_mla: bool = False
flashinfer_mla_disable_ragged: bool = False
def __post_init__(self):
# Set missing default values
......@@ -713,6 +714,11 @@ class ServerArgs:
action="store_true",
help="Enable FlashInfer MLA optimization",
)
parser.add_argument(
"--flashinfer-mla-disable-ragged",
action="store_true",
help="Not using ragged prefill wrapper when running flashinfer mla",
)
# Speculative decoding
parser.add_argument(
......
......@@ -23,6 +23,7 @@ suites = {
"test_gguf.py",
"test_input_embeddings.py",
"test_mla.py",
"test_mla_flashinfer.py",
"test_mla_fp8.py",
"test_json_constrained.py",
"test_large_max_new_tokens.py",
......
import unittest
from types import SimpleNamespace
import torch
from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
DEFAULT_MLA_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
popen_launch_server,
)
class TestFlashinferMLA(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.model = "sgl-project/sglang-ci-dsv3-test"
cls.base_url = DEFAULT_URL_FOR_TEST
other_args = ["--trust-remote-code"]
if torch.cuda.is_available() and torch.version.cuda:
other_args.extend(
[
"--enable-torch-compile",
"--cuda-graph-max-bs",
"2",
"--enable-flashinfer-mla",
]
)
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=other_args,
)
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def test_gsm8k(self):
args = SimpleNamespace(
num_shots=5,
data_path=None,
num_questions=200,
max_new_tokens=512,
parallel=128,
host="http://127.0.0.1",
port=int(self.base_url.split(":")[-1]),
)
metrics = run_eval_few_shot_gsm8k(args)
print(metrics)
self.assertGreater(metrics["accuracy"], 0.62)
class TestFlashinferMLANoRagged(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.model = "sgl-project/sglang-ci-dsv3-test"
cls.base_url = DEFAULT_URL_FOR_TEST
other_args = ["--trust-remote-code"]
if torch.cuda.is_available() and torch.version.cuda:
other_args.extend(
[
"--enable-torch-compile",
"--disable-cuda-graph",
"--cuda-graph-max-bs",
"2",
"--enable-flashinfer-mla",
"--flashinfer-mla-disable-ragged",
]
)
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=other_args,
)
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def test_gsm8k(self):
args = SimpleNamespace(
num_shots=5,
data_path=None,
num_questions=200,
max_new_tokens=512,
parallel=128,
host="http://127.0.0.1",
port=int(self.base_url.split(":")[-1]),
)
metrics = run_eval_few_shot_gsm8k(args)
print(metrics)
self.assertGreater(metrics["accuracy"], 0.62)
if __name__ == "__main__":
unittest.main()