Unverified Commit 90a4b7d9 authored by Baizhou Zhang, committed by GitHub

[Feature] Support ragged prefill in flashinfer mla backend (#3967)


Co-authored-by: Yineng Zhang <me@zhyncs.com>
Co-authored-by: pankajroark <pankajroark@users.noreply.github.com>
parent f3b99f73
@@ -133,7 +133,6 @@ Please consult the documentation below to learn more about the parameters you ma
 * `attention_backend`: The backend for attention computation and KV cache management.
 * `sampling_backend`: The backend for sampling.
-* `enable_flashinfer_mla`: The backend for flashinfer MLA wrapper that accelerates deepseek models. (In Experiment Stage)

 ## Constrained Decoding

@@ -186,3 +185,5 @@ Please consult the documentation below to learn more about the parameters you ma
 * `cuda_graph_bs`: The batch sizes to capture by `CudaGraphRunner`. By default this is done for you.
 * `torchao_config`: Experimental feature that optimizes the model with [torchao](https://github.com/pytorch/ao). Possible choices are: int8dq, int8wo, int4wo-<group_size>, fp8wo, fp8dq-per_tensor, fp8dq-per_row.
 * `triton_attention_num_kv_splits`: Use to adjust the number of KV splits in triton kernels. Default is 8.
+* `enable_flashinfer_mla`: Use the FlashInfer MLA wrapper as the attention backend to accelerate DeepSeek models.
+* `flashinfer_mla_disable_ragged`: Disable the ragged prefill wrapper of the FlashInfer MLA attention backend. Only meaningful when `enable_flashinfer_mla` is turned on.
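The two flags above can be exercised directly at launch time. A minimal, hedged sketch, assuming sglang with FlashInfer support is installed and a GPU is available; the model path is the CI DeepSeek checkpoint used by the test file added in this commit:

```python
import subprocess

# Launch command sketch. Replace the model path with your own checkpoint.
cmd = [
    "python", "-m", "sglang.launch_server",
    "--model-path", "sgl-project/sglang-ci-dsv3-test",
    "--trust-remote-code",
    "--enable-flashinfer-mla",            # use the FlashInfer MLA attention backend
    # "--flashinfer-mla-disable-ragged",  # uncomment to force paged-only prefill
]
server = subprocess.Popen(cmd)  # runs until terminated
```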
@@ -2,13 +2,13 @@ from __future__ import annotations
 """
 Support attention backend for flashinfer MLA.
-When radix cache is enabled, the backend only uses BatchMLAPaged wrapper when forwarding.
-When radix cache is disabled, the backend uses BatchPrefill wrappers for prefilling (with or without prefix cache),
+The flashinfer_mla_disable_ragged flag controls whether the ragged prefill wrapper is used and defaults to false.
+When it is set to true, all wrappers are the BatchMLAPaged wrapper.
+When it is set to false, the backend uses the BatchRagged and BatchMLAPaged wrappers for prefilling,
 and uses BatchMLAPaged wrapper for decoding.
 More details can be found in https://docs.flashinfer.ai/api/mla.html
 """

-import math
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Optional, Union
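The selection rule described in the docstring, together with the `use_ragged` computation further down in this diff, can be restated as a small stand-alone predicate. The names below are illustrative stand-ins, not part of the backend's API:

```python
# Illustrative restatement of the wrapper-selection rule.
# `disable_ragged` mirrors server_args.flashinfer_mla_disable_ragged;
# `extend_prefix_lens` mirrors forward_batch.extend_prefix_lens_cpu.
from typing import List


def pick_prefill_wrapper(disable_ragged: bool, extend_prefix_lens: List[int]) -> str:
    extend_no_prefix = not any(extend_prefix_lens)  # no request reuses a cached prefix
    use_ragged = (not disable_ragged) and extend_no_prefix
    return (
        "BatchPrefillWithRaggedKVCacheWrapper"
        if use_ragged
        else "BatchMLAPagedAttentionWrapper"
    )


assert pick_prefill_wrapper(False, [0, 0]) == "BatchPrefillWithRaggedKVCacheWrapper"
assert pick_prefill_wrapper(False, [0, 5]) == "BatchMLAPagedAttentionWrapper"  # prefix hit -> paged
assert pick_prefill_wrapper(True, [0, 0]) == "BatchMLAPagedAttentionWrapper"   # flag forces paged
```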
@@ -18,7 +18,6 @@ from sglang.global_config import global_config
 from sglang.srt.layers.attention import AttentionBackend
 from sglang.srt.layers.attention.flashinfer_backend import (
     create_flashinfer_kv_indices_triton,
-    should_use_tensor_core,
 )
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.managers.schedule_batch import global_server_args_dict

@@ -32,11 +31,10 @@ if TYPE_CHECKING:
 if is_flashinfer_available():
     from flashinfer import (
-        BatchPrefillWithPagedKVCacheWrapper,
+        BatchMLAPagedAttentionWrapper,
         BatchPrefillWithRaggedKVCacheWrapper,
     )
     from flashinfer.cascade import merge_state
-    from flashinfer.mla import BatchMLAPagedAttentionWrapper


 @dataclass
@@ -46,9 +44,7 @@ class DecodeMetadata:
 @dataclass
 class PrefillMetadata:
-    prefill_wrapper: Union[
-        BatchPrefillWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper
-    ]
+    prefill_wrapper: BatchMLAPagedAttentionWrapper
     use_ragged: bool

@@ -62,7 +58,6 @@ class FlashInferMLAAttnBackend(AttentionBackend):
     def __init__(
         self,
         model_runner: ModelRunner,
-        kv_indptr_buf: Optional[torch.Tensor] = None,
     ):
         super().__init__()
@@ -82,12 +77,9 @@ class FlashInferMLAAttnBackend(AttentionBackend):
             self.workspace_buffer = global_workspace_buffer
         max_bs = model_runner.req_to_token_pool.size
-        if kv_indptr_buf is None:
-            self.kv_indptr = torch.zeros(
-                (max_bs + 1,), dtype=torch.int32, device=model_runner.device
-            )
-        else:
-            self.kv_indptr = kv_indptr_buf
+        self.kv_indptr = torch.zeros(
+            (max_bs + 1,), dtype=torch.int32, device=model_runner.device
+        )

         self.qo_indptr = torch.zeros(
             (max_bs + 1,), dtype=torch.int32, device=model_runner.device

@@ -97,22 +89,19 @@ class FlashInferMLAAttnBackend(AttentionBackend):
             (max_bs,), dtype=torch.int32, device=model_runner.device
         )
+        self.q_indptr_decode = torch.arange(
+            0, max_bs + 1, dtype=torch.int32, device=model_runner.device
+        )

         self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
             self.workspace_buffer, "NHD"
         )
-        if not global_server_args_dict["disable_radix_cache"]:
-            # use mla paged prefill
-            self.prefill_wrapper_paged = BatchMLAPagedAttentionWrapper(
-                self.workspace_buffer,
-                backend="auto",
-            )
-        else:
-            self.prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
-                self.workspace_buffer,
-                "NHD",
-                backend="auto",
-            )
+        self.prefill_wrapper_paged = BatchMLAPagedAttentionWrapper(
+            self.workspace_buffer,
+            backend="auto",
+        )
         self.decode_wrapper = BatchMLAPagedAttentionWrapper(
             self.workspace_buffer, backend="auto"
         )
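For reference, a sketch of how the three wrappers above share one workspace buffer. The constructor calls mirror the diff; the 128 MB workspace size is an assumption (the real size comes from sglang's global config), and flashinfer plus a CUDA device are required:

```python
import torch
from flashinfer import (
    BatchMLAPagedAttentionWrapper,
    BatchPrefillWithRaggedKVCacheWrapper,
)

# Shared workspace buffer, as in the backend above. 128 MB is an assumed, commonly
# used size; sglang takes the real value from its global config.
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")

prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(workspace_buffer, "NHD")
prefill_wrapper_paged = BatchMLAPagedAttentionWrapper(workspace_buffer, backend="auto")
decode_wrapper = BatchMLAPagedAttentionWrapper(workspace_buffer, backend="auto")
```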
@@ -141,7 +130,11 @@ class FlashInferMLAAttnBackend(AttentionBackend):
             self.forward_metadata = DecodeMetadata(self.decode_wrapper)
         else:
             prefix_lens = forward_batch.extend_prefix_lens
-            use_ragged = global_server_args_dict["disable_radix_cache"]
+            extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
+            use_ragged = (
+                not global_server_args_dict["flashinfer_mla_disable_ragged"]
+                and extend_no_prefix
+            )

             self.indices_updater_prefill.update(
                 forward_batch.req_pool_indices,
@@ -241,45 +234,37 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         forward_batch: ForwardBatch,
         save_kv_cache=True,
     ):
         cache_loc = forward_batch.out_cache_loc
         logits_soft_cap = layer.logit_cap
+        prefill_wrapper_paged = self.forward_metadata.prefill_wrapper
+        qall = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+        k_buf = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)

-        if not global_server_args_dict["disable_radix_cache"]:
-            # use mla paged prefill
-            prefill_wrapper_paged = self.forward_metadata.prefill_wrapper
-            if k is not None:
-                assert v is not None
-                if save_kv_cache:
-                    forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
-            qall = q.view(-1, layer.tp_q_head_num, layer.head_dim)
-            k_buf = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
-            o = prefill_wrapper_paged.run(
-                qall[:, :, : layer.v_head_dim],
-                qall[:, :, layer.v_head_dim :],
-                k_buf[:, :, : layer.v_head_dim],
-                k_buf[:, :, layer.v_head_dim :],
-            )
-        else:
-            # use mla ragged prefill
+        # Save kv cache
+        if save_kv_cache and k is not None:
+            assert v is not None
+            if save_kv_cache:
+                forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)

+        if self.forward_metadata.use_ragged:
+            # ragged prefill
             o, _ = self.prefill_wrapper_ragged.forward_return_lse(
-                q.view(-1, layer.tp_q_head_num, layer.head_dim),
+                qall,
                 k.view(-1, layer.tp_k_head_num, layer.head_dim),
-                v.view(-1, layer.tp_v_head_num, layer.v_head_dim),
+                v.view(-1, layer.tp_k_head_num, layer.v_head_dim),
                 causal=True,
                 sm_scale=layer.scaling,
                 logits_soft_cap=logits_soft_cap,
             )
-            # FIXME: Here should be another prefill_paged to call
-            if save_kv_cache:
-                forward_batch.token_to_kv_pool.set_kv_buffer(
-                    layer,
-                    cache_loc,
-                    k,
-                    v,
-                )
+        else:
+            # mla paged prefill
+            o = prefill_wrapper_paged.run(
+                qall[:, :, : layer.v_head_dim],
+                qall[:, :, layer.v_head_dim :],
+                k_buf[:, :, : layer.v_head_dim],
+                k_buf[:, :, layer.v_head_dim :],
+            )

         return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
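A toy illustration of the slicing used by the paged branch above, which splits the query and the cached KV buffer into a compressed ("nope") part and a rotary part at `v_head_dim`. The sizes are illustrative DeepSeek-style values (kv_lora_rank=512, qk_rope_head_dim=64); only the slicing pattern matters:

```python
import torch

num_tokens, num_heads = 4, 2
kv_lora_rank, qk_rope_head_dim = 512, 64
head_dim = kv_lora_rank + qk_rope_head_dim  # plays the role of layer.head_dim
v_head_dim = kv_lora_rank                   # plays the role of layer.v_head_dim

qall = torch.randn(num_tokens, num_heads, head_dim)
q_nope = qall[:, :, :v_head_dim]            # first argument passed to run()
q_rope = qall[:, :, v_head_dim:]            # second argument passed to run()
assert q_nope.shape[-1] == kv_lora_rank and q_rope.shape[-1] == qk_rope_head_dim
```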
@@ -334,6 +319,7 @@ class FlashInferMLAIndicesUpdaterDecode:
         self.kv_indptr = attn_backend.kv_indptr
         self.kv_last_page_len = attn_backend.kv_last_page_len
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
+        self.q_indptr = attn_backend.q_indptr_decode

     def update(
         self,

@@ -342,12 +328,13 @@ class FlashInferMLAIndicesUpdaterDecode:
         seq_lens_sum: int,
         decode_wrapper: BatchMLAPagedAttentionWrapper,
     ):
-        decode_wrappers = decode_wrapper or self.decode_wrapper
+        decode_wrapper = decode_wrapper or self.decode_wrapper
         self.call_begin_forward(
             decode_wrapper,
             req_pool_indices,
             seq_lens,
             seq_lens_sum,
+            self.q_indptr,
             self.kv_indptr,
         )

@@ -357,14 +344,19 @@ class FlashInferMLAIndicesUpdaterDecode:
         req_pool_indices: torch.Tensor,
         paged_kernel_lens: torch.Tensor,
         paged_kernel_lens_sum: int,
+        q_indptr: torch.Tensor,
         kv_indptr: torch.Tensor,
     ):
         bs = len(req_pool_indices)
+        q_indptr = q_indptr[: bs + 1]
         kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
         kv_indptr = kv_indptr[: bs + 1]
         kv_indices = torch.empty(
             paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
         )
+        kv_lens = paged_kernel_lens.to(torch.int32)
+        sm_scale = self.scaling
         create_flashinfer_kv_indices_triton[(bs,)](
             self.req_to_token,
             req_pool_indices,

@@ -375,9 +367,6 @@ class FlashInferMLAIndicesUpdaterDecode:
             self.req_to_token.shape[1],
         )
-        sm_scale = self.scaling
-        q_indptr = torch.arange(0, bs + 1).to(0).int()
-        kv_lens = paged_kernel_lens.to(torch.int32)
         wrapper.plan(
             q_indptr,
             kv_indptr,
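The decode-side plan inputs built above follow a simple layout: one query token per request, and `kv_indptr` as the running sum of sequence lengths. A CPU-only sketch (the backend writes into preallocated CUDA buffers instead of allocating here):

```python
import torch

seq_lens = torch.tensor([7, 3, 12], dtype=torch.int32)  # current KV length per request
bs = seq_lens.numel()

# One query token per request during decode, so q_indptr is simply 0..bs.
q_indptr = torch.arange(0, bs + 1, dtype=torch.int32)

# kv_indptr[i] marks where request i's KV indices start in the flattened kv_indices.
kv_indptr = torch.zeros(bs + 1, dtype=torch.int32)
kv_indptr[1:] = torch.cumsum(seq_lens, dim=0)

kv_lens = seq_lens.to(torch.int32)
print(q_indptr.tolist())   # [0, 1, 2, 3]
print(kv_indptr.tolist())  # [0, 7, 10, 22]
```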
@@ -397,12 +386,9 @@ class FlashInferMLAIndicesUpdaterDecode:
 class FlashInferMLAIndicesUpdaterPrefill:
     def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
         # Parse Constants
-        self.num_qo_heads = (
+        self.num_local_heads = (
             model_runner.model_config.num_attention_heads // get_attention_tp_size()
         )
-        self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
-            get_attention_tp_size()
-        )
         self.kv_lora_rank = model_runner.model_config.kv_lora_rank
         self.qk_nope_head_dim = model_runner.model_config.qk_nope_head_dim
         self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim

@@ -425,9 +411,7 @@ class FlashInferMLAIndicesUpdaterPrefill:
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
-        prefill_wrapper_paged: Union[
-            BatchPrefillWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper
-        ],
+        prefill_wrapper_paged: BatchMLAPagedAttentionWrapper,
         use_ragged: bool,
     ):
         if use_ragged:

@@ -453,9 +437,7 @@ class FlashInferMLAIndicesUpdaterPrefill:
     def call_begin_forward(
         self,
         wrapper_ragged: BatchPrefillWithRaggedKVCacheWrapper,
-        wrapper_paged: Union[
-            BatchPrefillWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper
-        ],
+        wrapper_paged: BatchMLAPagedAttentionWrapper,
         req_pool_indices: torch.Tensor,
         paged_kernel_lens: torch.Tensor,
         paged_kernel_lens_sum: int,

@@ -466,7 +448,6 @@ class FlashInferMLAIndicesUpdaterPrefill:
         use_ragged: bool,
     ):
         bs = len(req_pool_indices)
-        # Normal extend
         kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
         kv_indptr = kv_indptr[: bs + 1]
         kv_indices = torch.empty(
@@ -488,19 +469,18 @@ class FlashInferMLAIndicesUpdaterPrefill:
         qo_indptr = qo_indptr[: bs + 1]
         sm_scale = self.scaling

-        # extend part
         if use_ragged:
+            # ragged prefill
             wrapper_ragged.begin_forward(
                 qo_indptr=qo_indptr,
                 kv_indptr=qo_indptr,
-                num_qo_heads=self.num_qo_heads,
-                num_kv_heads=self.num_kv_heads,
+                num_qo_heads=self.num_local_heads,
+                num_kv_heads=self.num_local_heads,
                 head_dim_qk=self.qk_nope_head_dim + self.qk_rope_head_dim,
                 head_dim_vo=self.v_head_dim,
                 q_data_type=self.q_data_type,
             )
-        else:
-            if not global_server_args_dict["disable_radix_cache"]:

         # mla paged prefill
         kv_len_arr = kv_indptr[1:] - kv_indptr[:-1]
         wrapper_paged.plan(

@@ -508,7 +488,7 @@ class FlashInferMLAIndicesUpdaterPrefill:
             kv_indptr,
             kv_indices,
             kv_len_arr,
-            self.num_qo_heads,
+            self.num_local_heads,
             self.kv_lora_rank,
             self.qk_rope_head_dim,
             1,

@@ -517,5 +497,3 @@ class FlashInferMLAIndicesUpdaterPrefill:
             self.q_data_type,
             self.data_type,
         )
-        # FIXME: Here should be some logic for prefill paged when not using radix cache?
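The ragged plan above passes `qo_indptr` for both the query and key index pointers: with no cached prefix, each request attends only to its own newly extended tokens, so the two extents coincide. A CPU-only sketch of that layout:

```python
import torch

extend_seq_lens = torch.tensor([5, 2, 9], dtype=torch.int32)  # new tokens per request
bs = extend_seq_lens.numel()

qo_indptr = torch.zeros(bs + 1, dtype=torch.int32)
qo_indptr[1:] = torch.cumsum(extend_seq_lens, dim=0)

kv_indptr_for_ragged = qo_indptr  # same boundaries for K/V as for Q
print(qo_indptr.tolist())  # [0, 5, 7, 16]
```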
@@ -67,6 +67,7 @@ global_server_args_dict = {
     "device": ServerArgs.device,
     "enable_flashinfer_mla": ServerArgs.enable_flashinfer_mla,
     "disable_radix_cache": ServerArgs.disable_radix_cache,
+    "flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
 }

 logger = logging.getLogger(__name__)

@@ -182,6 +182,7 @@ class ModelRunner:
                 "device": server_args.device,
                 "enable_flashinfer_mla": server_args.enable_flashinfer_mla,
                 "disable_radix_cache": server_args.disable_radix_cache,
+                "flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
             }
         )
@@ -520,10 +520,11 @@ class DeepseekV2AttentionMLA(nn.Module):
         def no_absorb() -> bool:
             if global_server_args_dict["enable_flashinfer_mla"]:
-                # Flashinfer MLA: Only do not use absorb when prefilling/extending without radix cache
+                # Flashinfer MLA: Do not absorb when enabling ragged prefill
                 return (
-                    global_server_args_dict["disable_radix_cache"]
+                    not global_server_args_dict["flashinfer_mla_disable_ragged"]
                     and forward_batch.forward_mode.is_extend()
+                    and forward_batch.extend_prefix_lens.sum() == 0
                 )
             else:
                 # Triton: Use normal computation for prefill and use weight absorption for extend/decode
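The updated predicate can be restated as plain Python, with the arguments as local stand-ins for the values read from `global_server_args_dict` and `forward_batch`:

```python
# Weight absorption is skipped only when ragged prefill is allowed, the batch is an
# extend batch, and no request carries a cached prefix.
def no_absorb(flashinfer_mla_disable_ragged: bool,
              is_extend: bool,
              extend_prefix_lens_sum: int) -> bool:
    return (
        not flashinfer_mla_disable_ragged
        and is_extend
        and extend_prefix_lens_sum == 0
    )


assert no_absorb(False, True, 0) is True     # ragged prefill path, no prefix
assert no_absorb(False, True, 10) is False   # prefix hit -> use absorption
assert no_absorb(True, True, 0) is False     # ragged disabled -> use absorption
```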
@@ -167,6 +167,7 @@ class ServerArgs:
     tool_call_parser: str = None
     enable_hierarchical_cache: bool = False
     enable_flashinfer_mla: bool = False
+    flashinfer_mla_disable_ragged: bool = False

     def __post_init__(self):
         # Set missing default values

@@ -713,6 +714,11 @@ class ServerArgs:
             action="store_true",
             help="Enable FlashInfer MLA optimization",
         )
+        parser.add_argument(
+            "--flashinfer-mla-disable-ragged",
+            action="store_true",
+            help="Do not use the ragged prefill wrapper when running FlashInfer MLA",
+        )

         # Speculative decoding
         parser.add_argument(
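A self-contained argparse mirror of the two flags, showing how the new switch defaults and composes on the command line. This is an illustrative sketch, not the real `ServerArgs` parser:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--enable-flashinfer-mla", action="store_true",
                    help="Enable FlashInfer MLA optimization")
parser.add_argument("--flashinfer-mla-disable-ragged", action="store_true",
                    help="Do not use the ragged prefill wrapper when running FlashInfer MLA")

args = parser.parse_args(["--enable-flashinfer-mla"])
print(args.enable_flashinfer_mla)          # True
print(args.flashinfer_mla_disable_ragged)  # False (ragged prefill stays enabled)
```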
@@ -23,6 +23,7 @@ suites = {
         "test_gguf.py",
         "test_input_embeddings.py",
         "test_mla.py",
+        "test_mla_flashinfer.py",
         "test_mla_fp8.py",
         "test_json_constrained.py",
         "test_large_max_new_tokens.py",
import unittest
from types import SimpleNamespace

import torch

from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
    DEFAULT_MLA_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    popen_launch_server,
)


class TestFlashinferMLA(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.model = "sgl-project/sglang-ci-dsv3-test"
        cls.base_url = DEFAULT_URL_FOR_TEST
        other_args = ["--trust-remote-code"]
        if torch.cuda.is_available() and torch.version.cuda:
            other_args.extend(
                [
                    "--enable-torch-compile",
                    "--cuda-graph-max-bs",
                    "2",
                    "--enable-flashinfer-mla",
                ]
            )
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=other_args,
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def test_gsm8k(self):
        args = SimpleNamespace(
            num_shots=5,
            data_path=None,
            num_questions=200,
            max_new_tokens=512,
            parallel=128,
            host="http://127.0.0.1",
            port=int(self.base_url.split(":")[-1]),
        )
        metrics = run_eval_few_shot_gsm8k(args)
        print(metrics)

        self.assertGreater(metrics["accuracy"], 0.62)


class TestFlashinferMLANoRagged(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.model = "sgl-project/sglang-ci-dsv3-test"
        cls.base_url = DEFAULT_URL_FOR_TEST
        other_args = ["--trust-remote-code"]
        if torch.cuda.is_available() and torch.version.cuda:
            other_args.extend(
                [
                    "--enable-torch-compile",
                    "--disable-cuda-graph",
                    "--cuda-graph-max-bs",
                    "2",
                    "--enable-flashinfer-mla",
                    "--flashinfer-mla-disable-ragged",
                ]
            )
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=other_args,
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def test_gsm8k(self):
        args = SimpleNamespace(
            num_shots=5,
            data_path=None,
            num_questions=200,
            max_new_tokens=512,
            parallel=128,
            host="http://127.0.0.1",
            port=int(self.base_url.split(":")[-1]),
        )
        metrics = run_eval_few_shot_gsm8k(args)
        print(metrics)

        self.assertGreater(metrics["accuracy"], 0.62)


if __name__ == "__main__":
    unittest.main()