"vscode:/vscode.git/clone" did not exist on "2bf481fb0499370b2ef30f182a3d3a82f5382d11"
Unverified Commit efbae697 authored by Baizhou Zhang, committed by GitHub

[Revision] Replace enable_flashinfer_mla argument with attention_backend (#5052)

parent ca8d02ab
@@ -138,7 +138,7 @@ Please consult the documentation below to learn more about the parameters you ma
 ## Kernel backend
-* `attention_backend`: The backend for attention computation and KV cache management.
+* `attention_backend`: This argument specifies the backend for attention computation and KV cache management, which can be `fa3`, `flashinfer`, `triton`, or `torch_native`. When deploying DeepSeek models, use this argument to specify the MLA backend.
 * `sampling_backend`: The backend for sampling.
 ## Constrained Decoding
@@ -192,5 +192,5 @@ Please consult the documentation below to learn more about the parameters you ma
 * `cuda_graph_bs`: The batch sizes to capture by `CudaGraphRunner`. By default this is done for you.
 * `torchao_config`: Experimental feature that optimizes the model with [torchao](https://github.com/pytorch/ao). Possible choices are: int8dq, int8wo, int4wo-<group_size>, fp8wo, fp8dq-per_tensor, fp8dq-per_row.
 * `triton_attention_num_kv_splits`: Use to adjust the number of KV splits in triton kernels. Default is 8.
-* `enable_flashinfer_mla`: Use the attention backend with flashinfer MLA wrapper for deepseek models. When providing this argument, `attention_backend` argument is overridden.
+* `enable_flashinfer_mla`: Use the attention backend with the FlashInfer MLA wrapper for DeepSeek models. **This argument will be deprecated in the next release. Please use `--attention-backend flashinfer` instead to enable FlashInfer MLA** (see the example command after this list).
-* `flashinfer_mla_disable_ragged`: Disable usage of ragged prefill wrapper for flashinfer mla attention backend. Should be used when `enable_flashinfer_mla` is turned on.
+* `flashinfer_mla_disable_ragged`: Disable the use of the ragged prefill wrapper for the FlashInfer MLA attention backend. Only use it when FlashInfer is being used as the MLA backend.
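As a quick illustration of the migration described above (a sketch only; the model path is the DeepSeek example used elsewhere in these docs, adjust it to your deployment):

```bash
# Deprecated: dedicated flag for the FlashInfer MLA wrapper
python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 --enable-flashinfer-mla

# Preferred: select the MLA backend through the generic attention backend argument
python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 --attention-backend flashinfer
```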
@@ -86,7 +86,7 @@ Please refer to [the example](https://github.com/sgl-project/sglang/tree/main/be
 - **Weight Absorption**: By applying the associative law of matrix multiplication to reorder computation steps, this method balances computation and memory access and improves efficiency in the decoding phase.
-- **Flashinfer MLA Wrapper**: By providing `--enable-flashinfer-mla` argument, the server will use MLA kernels customized by Flashinfer. More details can be referred to [this document](https://docs.flashinfer.ai/api/mla.html). Under long input scenarios, flashinfer mla can improve performance significantly. Optimized triton kernels will be used when flashinfer mla is turned off.
+- **MLA Attention Backends**: SGLang currently supports several optimized MLA attention backends, including FlashAttention3, [FlashInfer](https://docs.flashinfer.ai/api/mla.html), and Triton. The backend can be selected with the `--attention-backend` argument (see the example below).
 - **FP8 Quantization**: W8A8 FP8 and KV Cache FP8 quantization enables efficient FP8 inference. Additionally, we have implemented a Batched Matrix Multiplication (BMM) operator to facilitate FP8 inference in MLA with weight absorption.
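A sketch of selecting among the supported MLA backends under the new argument (the model path is illustrative, taken from the launch example further below):

```bash
# Triton MLA kernels (also the default when --attention-backend is not set for an MLA model)
python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 --attention-backend triton

# FlashAttention3 MLA backend
python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 --attention-backend fa3
```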
@@ -149,7 +149,7 @@ python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 --spec
 ```
 - The draft models are available on Hugging Face: [lmsys/DeepSeek-V3-0324-NextN](https://huggingface.co/lmsys/DeepSeek-V3-0324-NextN), [lmsys/DeepSeek-R1-NextN](https://huggingface.co/lmsys/DeepSeek-R1-NextN). They can also be exported from the original DeepSeek-V3/R1 models with the [export_deepseek_nextn.py](https://github.com/sgl-project/sglang/blob/main/scripts/export_deepseek_nextn.py) script.
 - The best configuration for `--speculative-num-steps`, `--speculative-eagle-topk` and `--speculative-num-draft-tokens` can be searched with the [bench_speculative.py](https://github.com/sgl-project/sglang/blob/main/scripts/playground/bench_speculative.py) script for a given batch size. The minimum configuration is `--speculative-num-steps 1 --speculative-eagle-topk 1 --speculative-num-draft-tokens 2`, which can achieve speedup for larger batch sizes.
-- Currently when using flashinfer mla wrapper (`--enable-flashinfer-mla`) and speculative decoding together, the `--speculative-eagle-topk` parameter should be set to `1`.
+- When using the FlashInfer MLA wrapper (`--attention-backend flashinfer`) with speculative decoding, set the `--speculative-eagle-topk` parameter to `1`. The FlashAttention 3 backend also only supports `--speculative-eagle-topk 1`. See the example launch command after this list.
 - To enable DeepSeek MTP for large batch sizes (>32), some parameters should be changed (see [this discussion](https://github.com/sgl-project/sglang/issues/4543#issuecomment-2737413756)):
   - Adjust `--max-running-requests` to a larger number. The default value is `32` for MTP; for larger batch sizes, increase it beyond that default.
   - Set `--cuda-graph-bs`. It's a list of batch sizes for cuda graph capture. The default captured batch sizes for speculative decoding are set [here](https://github.com/sgl-project/sglang/blob/49420741746c8f3e80e0eb17e7d012bfaf25793a/python/sglang/srt/model_executor/cuda_graph_runner.py#L126). You can include more batch sizes in it.
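For concreteness, a minimal sketch combining the constraints above (it omits the speculative-algorithm and draft-model flags shown in the full launch example referenced earlier, and the model path and batch-size value are illustrative):

```bash
python3 -m sglang.launch_server \
  --model-path deepseek-ai/DeepSeek-V3-0324 \
  --attention-backend flashinfer \
  --speculative-num-steps 1 \
  --speculative-eagle-topk 1 \
  --speculative-num-draft-tokens 2 \
  --max-running-requests 64
# --speculative-eagle-topk must stay at 1 with the FlashInfer MLA (or FA3) backend;
# raise --max-running-requests beyond the MTP default of 32 for larger batches,
# and extend --cuda-graph-bs if you need additional captured batch sizes.
```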
......
@@ -71,8 +71,6 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         self.device = model_runner.device
         self.skip_prefill = skip_prefill

-        global_config.enable_flashinfer_mla = True
-
         # Allocate buffers
         global global_workspace_buffer
         if global_workspace_buffer is None:
......
@@ -76,7 +76,6 @@ global_server_args_dict = {
     "device": ServerArgs.device,
     "speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single,
     "speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc,
-    "enable_flashinfer_mla": ServerArgs.enable_flashinfer_mla,
     "enable_flashmla": ServerArgs.enable_flashmla,
     "disable_radix_cache": ServerArgs.disable_radix_cache,
     "flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
@@ -1437,7 +1436,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         # Create seq_lens_cpu when needed
         if (
-            global_server_args_dict["enable_flashinfer_mla"]
+            (
+                global_server_args_dict["use_mla_backend"]
+                and global_server_args_dict["attention_backend"] == "flashinfer"
+            )
             or global_server_args_dict["enable_flashmla"]
             or global_server_args_dict["attention_backend"] == "fa3"
         ):
......
@@ -75,6 +75,7 @@ from sglang.srt.utils import (
     get_available_gpu_memory,
     init_custom_process_group,
     is_cuda,
+    is_flashinfer_available,
     is_hip,
     monkey_patch_p2p_access_check,
     monkey_patch_vllm_gguf_config,
@@ -123,6 +124,10 @@ class ModelRunner:
         self.page_size = server_args.page_size
         self.req_to_token_pool = req_to_token_pool
         self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
+        self.use_mla_backend = (
+            self.model_config.attention_arch == AttentionArch.MLA
+            and not server_args.disable_mla
+        )

         # Model-specific adjustment
         self.model_specific_adjustment()
@@ -151,7 +156,6 @@ class ModelRunner:
                 "device": server_args.device,
                 "speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
                 "speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
-                "enable_flashinfer_mla": server_args.enable_flashinfer_mla,
                 "enable_flashmla": server_args.enable_flashmla,
                 "disable_radix_cache": server_args.disable_radix_cache,
                 "flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
@@ -159,6 +163,7 @@ class ModelRunner:
                 "debug_tensor_dump_inject": server_args.debug_tensor_dump_inject,
                 "n_share_experts_fusion": server_args.n_share_experts_fusion,
                 "disable_shared_experts_fusion": server_args.disable_shared_experts_fusion,
+                "use_mla_backend": self.use_mla_backend,
             }
         )
@@ -219,27 +224,38 @@ class ModelRunner:
     def model_specific_adjustment(self):
         server_args = self.server_args

-        if (
-            self.model_config.attention_arch == AttentionArch.MLA
-            and not server_args.disable_mla
-        ):
+        if server_args.enable_flashinfer_mla:
+            # TODO: remove this branch after enable_flashinfer_mla is deprecated
+            logger.info("MLA optimization is turned on. Use flashinfer backend.")
+            server_args.attention_backend = "flashinfer"
+        elif server_args.enable_flashmla:
+            # TODO: remove this branch after enable_flashmla is deprecated
+            logger.info("MLA optimization is turned on. Use flashmla decode.")
+            server_args.attention_backend = "flashmla"
+        elif server_args.attention_backend is None:
+            # By default, use flashinfer for non-mla attention and triton for mla attention
+            if not self.use_mla_backend:
+                server_args.attention_backend = (
+                    "flashinfer" if is_flashinfer_available() else "triton"
+                )
+            else:
+                server_args.attention_backend = "triton"
+            logger.info(
+                f"Attention backend not set. Use {server_args.attention_backend} backend by default."
+            )
+        elif self.use_mla_backend:
             # TODO: add MLA optimization on CPU
             if server_args.device != "cpu":
-                if server_args.enable_flashinfer_mla:
+                if server_args.attention_backend in ["flashinfer", "fa3", "triton"]:
                     logger.info(
-                        "MLA optimization is turned on. Use flashinfer mla backend."
+                        f"MLA optimization is turned on. Use {server_args.attention_backend} backend."
                     )
-                    server_args.attention_backend = "flashinfer_mla"
-                elif server_args.enable_flashmla:
-                    logger.info("MLA optimization is turned on. Use flashmla decode.")
-                    server_args.attention_backend = "flashmla"
-                elif server_args.attention_backend == "fa3":
-                    logger.info(
-                        f"MLA optimization is turned on. Use flash attention 3 backend."
-                    )
                 else:
-                    logger.info("MLA optimization is turned on. Use triton backend.")
-                    server_args.attention_backend = "triton"
+                    raise ValueError(
+                        f"Invalid attention backend for MLA: {server_args.attention_backend}"
+                    )
+            else:
+                raise ValueError(f"MLA optimization not supported on CPU.")

         if server_args.enable_double_sparsity:
             logger.info(
@@ -637,10 +653,7 @@ class ModelRunner:
         available_gpu_memory = get_available_gpu_memory(
             self.device, self.gpu_id, distributed=self.tp_size > 1
         )
-        if (
-            self.model_config.attention_arch == AttentionArch.MLA
-            and not self.server_args.disable_mla
-        ):
+        if self.use_mla_backend:
             cell_size = (
                 (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
                 * self.model_config.num_hidden_layers
@@ -751,10 +764,7 @@ class ModelRunner:
             # Draft worker shares req_to_token_pool with the target worker.
             assert self.is_draft_worker
-        if (
-            self.model_config.attention_arch == AttentionArch.MLA
-            and not self.server_args.disable_mla
-        ):
+        if self.use_mla_backend:
             self.token_to_kv_pool = MLATokenToKVPool(
                 self.max_total_num_tokens,
                 page_size=self.page_size,
@@ -825,14 +835,21 @@ class ModelRunner:
     def init_attention_backend(self):
         """Init attention kernel backend."""
         if self.server_args.attention_backend == "flashinfer":
-            from sglang.srt.layers.attention.flashinfer_backend import (
-                FlashInferAttnBackend,
-            )
+            if not self.use_mla_backend:
+                from sglang.srt.layers.attention.flashinfer_backend import (
+                    FlashInferAttnBackend,
+                )

-            # Init streams
-            if self.server_args.speculative_algorithm == "EAGLE":
-                self.plan_stream_for_flashinfer = torch.cuda.Stream()
-            self.attn_backend = FlashInferAttnBackend(self)
+                # Init streams
+                if self.server_args.speculative_algorithm == "EAGLE":
+                    self.plan_stream_for_flashinfer = torch.cuda.Stream()
+                self.attn_backend = FlashInferAttnBackend(self)
+            else:
+                from sglang.srt.layers.attention.flashinfer_mla_backend import (
+                    FlashInferMLAAttnBackend,
+                )
+
+                self.attn_backend = FlashInferMLAAttnBackend(self)
         elif self.server_args.attention_backend == "triton":
             assert self.sliding_window_size is None, (
                 "Window attention is not supported in the triton attention backend. "
@@ -858,12 +875,6 @@ class ModelRunner:
             )
             self.attn_backend = TorchNativeAttnBackend(self)
-        elif self.server_args.attention_backend == "flashinfer_mla":
-            from sglang.srt.layers.attention.flashinfer_mla_backend import (
-                FlashInferMLAAttnBackend,
-            )
-
-            self.attn_backend = FlashInferMLAAttnBackend(self)
         elif self.server_args.attention_backend == "flashmla":
             from sglang.srt.layers.attention.flashmla_backend import FlashMLABackend
......
@@ -686,7 +686,6 @@ class DeepseekV2AttentionMLA(nn.Module):
         self.w_vc = None
         self.w_scale = None
-        self.enable_flashinfer_mla = global_server_args_dict["enable_flashinfer_mla"]
         self.flashinfer_mla_disable_ragged = global_server_args_dict[
             "flashinfer_mla_disable_ragged"
         ]
@@ -694,7 +693,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         self.rocm_fused_decode_mla = os.getenv("SGLANG_ROCM_FUSED_DECODE_MLA") == "1"

     def no_absorb(self, forward_batch: ForwardBatch) -> bool:
-        if self.enable_flashinfer_mla:
+        if self.attention_backend == "flashinfer":
             # Flashinfer MLA: Do not absorb when enabling ragged prefill
             return (
                 not self.flashinfer_mla_disable_ragged
......
@@ -179,7 +179,7 @@ class ServerArgs:
     tool_call_parser: Optional[str] = None
     enable_hierarchical_cache: bool = False
     hicache_ratio: float = 2.0
-    enable_flashinfer_mla: bool = False
+    enable_flashinfer_mla: bool = False  # TODO: remove this argument
     enable_flashmla: bool = False
     flashinfer_mla_disable_ragged: bool = False
     warmups: Optional[str] = None
@@ -267,15 +267,11 @@ class ServerArgs:
         else:
             self.cuda_graph_max_bs = 160

-        # Choose kernel backends
+        # Set kernel backends for hpu device
         if self.device == "hpu":
             self.attention_backend = "torch_native"
             self.sampling_backend = "pytorch"

-        if self.attention_backend is None:
-            self.attention_backend = (
-                "flashinfer" if is_flashinfer_available() else "triton"
-            )
         if self.sampling_backend is None:
             self.sampling_backend = (
                 "flashinfer" if is_flashinfer_available() else "pytorch"
@@ -842,7 +838,7 @@ class ServerArgs:
         parser.add_argument(
             "--enable-flashinfer-mla",
             action="store_true",
-            help="Enable FlashInfer MLA optimization",
+            help="Enable FlashInfer MLA optimization. This argument will be deprecated soon! Please use '--attention-backend flashinfer' instead to switch on FlashInfer MLA!",
         )
         parser.add_argument(
             "--enable-flashmla",
......
@@ -11,7 +11,11 @@ from sglang.srt.distributed import GroupCoordinator, patch_tensor_parallel_group
 from sglang.srt.layers.dp_attention import disable_dp_size
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
-from sglang.srt.managers.schedule_batch import ScheduleBatch, get_last_loc
+from sglang.srt.managers.schedule_batch import (
+    ScheduleBatch,
+    get_last_loc,
+    global_server_args_dict,
+)
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.model_executor.forward_batch_info import (
     CaptureHiddenMode,
@@ -146,15 +150,26 @@ class EAGLEWorker(TpModelWorker):
     def init_attention_backend(self):
         # Create multi-step attn backends and cuda graph runners
         if self.server_args.attention_backend == "flashinfer":
-            from sglang.srt.layers.attention.flashinfer_backend import (
-                FlashInferMultiStepDraftBackend,
-            )
+            if not global_server_args_dict["use_mla_backend"]:
+                from sglang.srt.layers.attention.flashinfer_backend import (
+                    FlashInferMultiStepDraftBackend,
+                )

-            self.draft_attn_backend = FlashInferMultiStepDraftBackend(
-                self.draft_model_runner,
-                self.topk,
-                self.speculative_num_steps,
-            )
+                self.draft_attn_backend = FlashInferMultiStepDraftBackend(
+                    self.draft_model_runner,
+                    self.topk,
+                    self.speculative_num_steps,
+                )
+            else:
+                from sglang.srt.layers.attention.flashinfer_mla_backend import (
+                    FlashInferMLAMultiStepDraftBackend,
+                )
+
+                self.draft_attn_backend = FlashInferMLAMultiStepDraftBackend(
+                    self.draft_model_runner,
+                    self.topk,
+                    self.speculative_num_steps,
+                )
             self.draft_extend_attn_backend = None
             self.padded_static_len = self.speculative_num_steps + 1
             self.has_prefill_wrapper_verify = True
@@ -171,19 +186,6 @@ class EAGLEWorker(TpModelWorker):
             self.draft_extend_attn_backend = None
             self.padded_static_len = self.speculative_num_steps + 1
             self.has_prefill_wrapper_verify = False
-        elif self.server_args.attention_backend == "flashinfer_mla":
-            from sglang.srt.layers.attention.flashinfer_mla_backend import (
-                FlashInferMLAMultiStepDraftBackend,
-            )
-
-            self.draft_attn_backend = FlashInferMLAMultiStepDraftBackend(
-                self.draft_model_runner,
-                self.topk,
-                self.speculative_num_steps,
-            )
-            self.draft_extend_attn_backend = None
-            self.padded_static_len = self.speculative_num_steps + 1
-            self.has_prefill_wrapper_verify = True
         elif self.server_args.attention_backend == "fa3":
             from sglang.srt.layers.attention.flashattention_backend import (
                 FlashAttentionMultiStepBackend,
......
@@ -26,7 +26,8 @@ class TestFlashinferMLA(CustomTestCase):
                 "--enable-torch-compile",
                 "--cuda-graph-max-bs",
                 "2",
-                "--enable-flashinfer-mla",
+                "--attention-backend",
+                "flashinfer",
             ]
         )
         cls.process = popen_launch_server(
@@ -69,8 +70,8 @@ class TestFlashinferMLANoRagged(CustomTestCase):
                 "--disable-cuda-graph",
                 "--cuda-graph-max-bs",
                 "4",
-                "--enable-flashinfer-mla",
-                "--flashinfer-mla-disable-ragged",
+                "--attention-backend",
+                "flashinfer",
             ]
         )
         cls.process = popen_launch_server(
@@ -125,7 +126,8 @@ class TestFlashinferMLAMTP(CustomTestCase):
                 "1",
                 "--speculative-num-draft-tokens",
                 "4",
-                "--enable-flashinfer-mla",
+                "--attention-backend",
+                "flashinfer",
             ]
         )
         cls.process = popen_launch_server(
......