Unverified Commit 74885a84 authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

Revert "Replace enable_flashinfer_mla argument with attention_backend" (#5048)

parent b8b6008f
...@@ -138,7 +138,7 @@ Please consult the documentation below to learn more about the parameters you ma ...@@ -138,7 +138,7 @@ Please consult the documentation below to learn more about the parameters you ma
## Kernel backend ## Kernel backend
* `attention_backend`: The backend for attention computation and KV cache management, and can be one of `fa3`, `flashinfer`, `triton` or `torch_native`. When deploying deepseek models, this argument is for specifying the MLA backend it uses. * `attention_backend`: The backend for attention computation and KV cache management.
* `sampling_backend`: The backend for sampling. * `sampling_backend`: The backend for sampling.
## Constrained Decoding ## Constrained Decoding
...@@ -192,5 +192,5 @@ Please consult the documentation below to learn more about the parameters you ma ...@@ -192,5 +192,5 @@ Please consult the documentation below to learn more about the parameters you ma
* `cuda_graph_bs`: The batch sizes to capture by `CudaGraphRunner`. By default this is done for you. * `cuda_graph_bs`: The batch sizes to capture by `CudaGraphRunner`. By default this is done for you.
* `torchao_config`: Experimental feature that optimizes the model with [torchao](https://github.com/pytorch/ao). Possible choices are: int8dq, int8wo, int4wo-<group_size>, fp8wo, fp8dq-per_tensor, fp8dq-per_row. * `torchao_config`: Experimental feature that optimizes the model with [torchao](https://github.com/pytorch/ao). Possible choices are: int8dq, int8wo, int4wo-<group_size>, fp8wo, fp8dq-per_tensor, fp8dq-per_row.
* `triton_attention_num_kv_splits`: Use to adjust the number of KV splits in triton kernels. Default is 8. * `triton_attention_num_kv_splits`: Use to adjust the number of KV splits in triton kernels. Default is 8.
* `enable_flashinfer_mla`: Use the attention backend with flashinfer MLA wrapper for deepseek models. **This argument will be deprecated soon! Please use `--attention_backend flashinfer` instead for switching on flashfiner mla!** * `enable_flashinfer_mla`: Use the attention backend with flashinfer MLA wrapper for deepseek models. When providing this argument, `attention_backend` argument is overridden.
* `flashinfer_mla_disable_ragged`: Disable usage of ragged prefill wrapper for flashinfer mla attention backend. Should be used when flashinfer is used as mla backend turned on. * `flashinfer_mla_disable_ragged`: Disable usage of ragged prefill wrapper for flashinfer mla attention backend. Should be used when `enable_flashinfer_mla` is turned on.
...@@ -86,7 +86,7 @@ Please refer to [the example](https://github.com/sgl-project/sglang/tree/main/be ...@@ -86,7 +86,7 @@ Please refer to [the example](https://github.com/sgl-project/sglang/tree/main/be
- **Weight Absorption**: By applying the associative law of matrix multiplication to reorder computation steps, this method balances computation and memory access and improves efficiency in the decoding phase. - **Weight Absorption**: By applying the associative law of matrix multiplication to reorder computation steps, this method balances computation and memory access and improves efficiency in the decoding phase.
- **MLA Attention Backends**: Currently SGLang supports different optimized MLA attention backends, including FlashAttention3, [Flashinfer](https://docs.flashinfer.ai/api/mla.html) and Triton backends. It can be set with `--attention-backend` argument. - **Flashinfer MLA Wrapper**: By providing `--enable-flashinfer-mla` argument, the server will use MLA kernels customized by Flashinfer. More details can be referred to [this document](https://docs.flashinfer.ai/api/mla.html). Under long input scenarios, flashinfer mla can improve performance significantly. Optimized triton kernels will be used when flashinfer mla is turned off.
- **FP8 Quantization**: W8A8 FP8 and KV Cache FP8 quantization enables efficient FP8 inference. Additionally, we have implemented Batched Matrix Multiplication (BMM) operator to facilitate FP8 inference in MLA with weight absorption. - **FP8 Quantization**: W8A8 FP8 and KV Cache FP8 quantization enables efficient FP8 inference. Additionally, we have implemented Batched Matrix Multiplication (BMM) operator to facilitate FP8 inference in MLA with weight absorption.
...@@ -149,7 +149,7 @@ python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 --spec ...@@ -149,7 +149,7 @@ python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 --spec
``` ```
- The draft model are available at huggingface: [lmsys/DeepSeek-V3-0324-NextN](https://huggingface.co/lmsys/DeepSeek-V3-0324-NextN), [lmsys/DeepSeek-R1-NextN](https://huggingface.co/lmsys/DeepSeek-R1-NextN). It can also be exported from original DeepSeek-V3/R1 model with [export_deepseek_nextn.py](https://github.com/sgl-project/sglang/blob/main/scripts/export_deepseek_nextn.py) script. - The draft model are available at huggingface: [lmsys/DeepSeek-V3-0324-NextN](https://huggingface.co/lmsys/DeepSeek-V3-0324-NextN), [lmsys/DeepSeek-R1-NextN](https://huggingface.co/lmsys/DeepSeek-R1-NextN). It can also be exported from original DeepSeek-V3/R1 model with [export_deepseek_nextn.py](https://github.com/sgl-project/sglang/blob/main/scripts/export_deepseek_nextn.py) script.
- The best configuratin for `--speculative-num-steps`, `--speculative-eagle-topk` and `--speculative-num-draft-tokens` can be searched with [bench_speculative.py](https://github.com/sgl-project/sglang/blob/main/scripts/playground/bench_speculative.py) script for given batch size. The minimum configuration is `--speculative-num-steps 1 --speculative-eagle-topk 1 --speculative-num-draft-tokens 2`, which can achieve speedup for larger batch sizes. - The best configuratin for `--speculative-num-steps`, `--speculative-eagle-topk` and `--speculative-num-draft-tokens` can be searched with [bench_speculative.py](https://github.com/sgl-project/sglang/blob/main/scripts/playground/bench_speculative.py) script for given batch size. The minimum configuration is `--speculative-num-steps 1 --speculative-eagle-topk 1 --speculative-num-draft-tokens 2`, which can achieve speedup for larger batch sizes.
- Currently when using flashinfer mla wrapper (`--attention-backend flashinfer`) and speculative decoding together, the `--speculative-eagle-topk` parameter should be set to `1`. The MTP feature on FlashAttention 3 backend is still under beta. - Currently when using flashinfer mla wrapper (`--enable-flashinfer-mla`) and speculative decoding together, the `--speculative-eagle-topk` parameter should be set to `1`.
- To enable DeepSeek MTP for large batch sizes (>32), there are some parameters should be changed (Reference [this discussion](https://github.com/sgl-project/sglang/issues/4543#issuecomment-2737413756)): - To enable DeepSeek MTP for large batch sizes (>32), there are some parameters should be changed (Reference [this discussion](https://github.com/sgl-project/sglang/issues/4543#issuecomment-2737413756)):
- Adjust `--max-running-requests` to a larger number. The default value is `32` for MTP. For larger batch sizes, you should increase this value beyond the default value. - Adjust `--max-running-requests` to a larger number. The default value is `32` for MTP. For larger batch sizes, you should increase this value beyond the default value.
- Set `--cuda-graph-bs`. It's a list of batch sizes for cuda graph capture. The default captured batch sizes for speculative decoding is set [here](https://github.com/sgl-project/sglang/blob/49420741746c8f3e80e0eb17e7d012bfaf25793a/python/sglang/srt/model_executor/cuda_graph_runner.py#L126). You can include more batch sizes into it. - Set `--cuda-graph-bs`. It's a list of batch sizes for cuda graph capture. The default captured batch sizes for speculative decoding is set [here](https://github.com/sgl-project/sglang/blob/49420741746c8f3e80e0eb17e7d012bfaf25793a/python/sglang/srt/model_executor/cuda_graph_runner.py#L126). You can include more batch sizes into it.
......
...@@ -71,6 +71,8 @@ class FlashInferMLAAttnBackend(AttentionBackend): ...@@ -71,6 +71,8 @@ class FlashInferMLAAttnBackend(AttentionBackend):
self.device = model_runner.device self.device = model_runner.device
self.skip_prefill = skip_prefill self.skip_prefill = skip_prefill
global_config.enable_flashinfer_mla = True
# Allocate buffers # Allocate buffers
global global_workspace_buffer global global_workspace_buffer
if global_workspace_buffer is None: if global_workspace_buffer is None:
......
...@@ -76,6 +76,7 @@ global_server_args_dict = { ...@@ -76,6 +76,7 @@ global_server_args_dict = {
"device": ServerArgs.device, "device": ServerArgs.device,
"speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single, "speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single,
"speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc, "speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc,
"enable_flashinfer_mla": ServerArgs.enable_flashinfer_mla,
"enable_flashmla": ServerArgs.enable_flashmla, "enable_flashmla": ServerArgs.enable_flashmla,
"disable_radix_cache": ServerArgs.disable_radix_cache, "disable_radix_cache": ServerArgs.disable_radix_cache,
"flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged, "flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
...@@ -1434,7 +1435,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): ...@@ -1434,7 +1435,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
# Create seq_lens_cpu when needed # Create seq_lens_cpu when needed
if ( if (
global_server_args_dict["attention_backend"] == "flashinfer_mla" global_server_args_dict["enable_flashinfer_mla"]
or global_server_args_dict["enable_flashmla"] or global_server_args_dict["enable_flashmla"]
or global_server_args_dict["attention_backend"] == "fa3" or global_server_args_dict["attention_backend"] == "fa3"
): ):
......
...@@ -151,6 +151,7 @@ class ModelRunner: ...@@ -151,6 +151,7 @@ class ModelRunner:
"device": server_args.device, "device": server_args.device,
"speculative_accept_threshold_single": server_args.speculative_accept_threshold_single, "speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
"speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc, "speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
"enable_flashinfer_mla": server_args.enable_flashinfer_mla,
"enable_flashmla": server_args.enable_flashmla, "enable_flashmla": server_args.enable_flashmla,
"disable_radix_cache": server_args.disable_radix_cache, "disable_radix_cache": server_args.disable_radix_cache,
"flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged, "flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
...@@ -222,14 +223,10 @@ class ModelRunner: ...@@ -222,14 +223,10 @@ class ModelRunner:
): ):
# TODO: add MLA optimization on CPU # TODO: add MLA optimization on CPU
if server_args.device != "cpu": if server_args.device != "cpu":
if ( if server_args.enable_flashinfer_mla:
server_args.attention_backend == "flashinfer"
or server_args.enable_flashinfer_mla
):
logger.info( logger.info(
"MLA optimization is turned on. Use flashinfer backend." "MLA optimization is turned on. Use flashinfer mla backend."
) )
# Here we use a special flashinfer_mla tag to differentiate it from normal flashinfer backend
server_args.attention_backend = "flashinfer_mla" server_args.attention_backend = "flashinfer_mla"
elif server_args.enable_flashmla: elif server_args.enable_flashmla:
logger.info("MLA optimization is turned on. Use flashmla decode.") logger.info("MLA optimization is turned on. Use flashmla decode.")
......
...@@ -684,6 +684,7 @@ class DeepseekV2AttentionMLA(nn.Module): ...@@ -684,6 +684,7 @@ class DeepseekV2AttentionMLA(nn.Module):
self.w_vc = None self.w_vc = None
self.w_scale = None self.w_scale = None
self.enable_flashinfer_mla = global_server_args_dict["enable_flashinfer_mla"]
self.flashinfer_mla_disable_ragged = global_server_args_dict[ self.flashinfer_mla_disable_ragged = global_server_args_dict[
"flashinfer_mla_disable_ragged" "flashinfer_mla_disable_ragged"
] ]
...@@ -691,7 +692,7 @@ class DeepseekV2AttentionMLA(nn.Module): ...@@ -691,7 +692,7 @@ class DeepseekV2AttentionMLA(nn.Module):
self.rocm_fused_decode_mla = os.getenv("SGLANG_ROCM_FUSED_DECODE_MLA") == "1" self.rocm_fused_decode_mla = os.getenv("SGLANG_ROCM_FUSED_DECODE_MLA") == "1"
def no_absorb(self, forward_batch: ForwardBatch) -> bool: def no_absorb(self, forward_batch: ForwardBatch) -> bool:
if self.attention_backend == "flashinfer_mla": if self.enable_flashinfer_mla:
# Flashinfer MLA: Do not absorb when enabling ragged prefill # Flashinfer MLA: Do not absorb when enabling ragged prefill
return ( return (
not self.flashinfer_mla_disable_ragged not self.flashinfer_mla_disable_ragged
......
...@@ -179,7 +179,7 @@ class ServerArgs: ...@@ -179,7 +179,7 @@ class ServerArgs:
tool_call_parser: Optional[str] = None tool_call_parser: Optional[str] = None
enable_hierarchical_cache: bool = False enable_hierarchical_cache: bool = False
hicache_ratio: float = 2.0 hicache_ratio: float = 2.0
enable_flashinfer_mla: bool = False # TODO: remove this argument enable_flashinfer_mla: bool = False
enable_flashmla: bool = False enable_flashmla: bool = False
flashinfer_mla_disable_ragged: bool = False flashinfer_mla_disable_ragged: bool = False
warmups: Optional[str] = None warmups: Optional[str] = None
...@@ -836,7 +836,7 @@ class ServerArgs: ...@@ -836,7 +836,7 @@ class ServerArgs:
parser.add_argument( parser.add_argument(
"--enable-flashinfer-mla", "--enable-flashinfer-mla",
action="store_true", action="store_true",
help="Enable FlashInfer MLA optimization. This argument will be deprecated soon! Please use '--attention-backend flashinfer' instead for switching on flashfiner mla!", help="Enable FlashInfer MLA optimization",
) )
parser.add_argument( parser.add_argument(
"--enable-flashmla", "--enable-flashmla",
......
...@@ -26,8 +26,7 @@ class TestFlashinferMLA(CustomTestCase): ...@@ -26,8 +26,7 @@ class TestFlashinferMLA(CustomTestCase):
"--enable-torch-compile", "--enable-torch-compile",
"--cuda-graph-max-bs", "--cuda-graph-max-bs",
"2", "2",
"--attention-backend", "--enable-flashinfer-mla",
"flashinfer",
] ]
) )
cls.process = popen_launch_server( cls.process = popen_launch_server(
...@@ -70,8 +69,8 @@ class TestFlashinferMLANoRagged(CustomTestCase): ...@@ -70,8 +69,8 @@ class TestFlashinferMLANoRagged(CustomTestCase):
"--disable-cuda-graph", "--disable-cuda-graph",
"--cuda-graph-max-bs", "--cuda-graph-max-bs",
"4", "4",
"--attention-backend", "--enable-flashinfer-mla",
"flashinfer", "--flashinfer-mla-disable-ragged",
] ]
) )
cls.process = popen_launch_server( cls.process = popen_launch_server(
...@@ -126,8 +125,7 @@ class TestFlashinferMLAMTP(CustomTestCase): ...@@ -126,8 +125,7 @@ class TestFlashinferMLAMTP(CustomTestCase):
"1", "1",
"--speculative-num-draft-tokens", "--speculative-num-draft-tokens",
"4", "4",
"--attention-backend", "--enable-flashinfer-mla",
"flashinfer",
] ]
) )
cls.process = popen_launch_server( cls.process = popen_launch_server(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment