Unverified Commit 6fb29ffd authored by Baizhou Zhang's avatar Baizhou Zhang Committed by GitHub
Browse files

Deprecate enable-flashinfer-mla and enable-flashmla (#5480)

parent 4fb05583
......@@ -192,6 +192,5 @@ Please consult the documentation below to learn more about the parameters you ma
* `cuda_graph_bs`: The batch sizes to capture by `CudaGraphRunner`. By default this is done for you.
* `torchao_config`: Experimental feature that optimizes the model with [torchao](https://github.com/pytorch/ao). Possible choices are: int8dq, int8wo, int4wo-<group_size>, fp8wo, fp8dq-per_tensor, fp8dq-per_row.
* `triton_attention_num_kv_splits`: Use to adjust the number of KV splits in triton kernels. Default is 8.
* `enable_flashinfer_mla`: Use the attention backend with FlashInfer MLA wrapper for DeepSeek models. **This argument will be deprecated in the next release. Please use `--attention_backend flashinfer` instead to enable FlashfIner MLA.**
* `flashinfer_mla_disable_ragged`: Disable the use of the ragged prefill wrapper for the FlashInfer MLA attention backend. Only use it when FlashInfer is being used as the MLA backend.
* `disable_chunked_prefix_cache`: Disable the use of chunked prefix cache for DeepSeek models. Only use it when FA3 is attention backend.
......@@ -86,7 +86,7 @@ Please refer to [the example](https://github.com/sgl-project/sglang/tree/main/be
- **Weight Absorption**: By applying the associative law of matrix multiplication to reorder computation steps, this method balances computation and memory access and improves efficiency in the decoding phase.
- **MLA Attention Backends**: Currently SGLang supports different optimized MLA attention backends, including FlashAttention3, [Flashinfer](https://docs.flashinfer.ai/api/mla.html) and Triton backends. It can be set with `--attention-backend` argument.
- **MLA Attention Backends**: Currently SGLang supports different optimized MLA attention backends, including [FlashAttention3](https://github.com/Dao-AILab/flash-attention), [Flashinfer](https://docs.flashinfer.ai/api/mla.html), and [Triton](https://github.com/triton-lang/triton) backends. It can be set with `--attention-backend` argument.
- **FP8 Quantization**: W8A8 FP8 and KV Cache FP8 quantization enables efficient FP8 inference. Additionally, we have implemented Batched Matrix Multiplication (BMM) operator to facilitate FP8 inference in MLA with weight absorption.
......
......@@ -76,7 +76,6 @@ global_server_args_dict = {
"device": ServerArgs.device,
"speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single,
"speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc,
"enable_flashmla": ServerArgs.enable_flashmla,
"disable_radix_cache": ServerArgs.disable_radix_cache,
"flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
"chunked_prefill_size": ServerArgs.chunked_prefill_size,
......@@ -1480,7 +1479,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
global_server_args_dict["use_mla_backend"]
and global_server_args_dict["attention_backend"] == "flashinfer"
)
or global_server_args_dict["enable_flashmla"]
or global_server_args_dict["attention_backend"] == "flashmla"
or global_server_args_dict["attention_backend"] == "fa3"
):
seq_lens_cpu = self.seq_lens.cpu()
......
......@@ -156,7 +156,6 @@ class ModelRunner:
"device": server_args.device,
"speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
"speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
"enable_flashmla": server_args.enable_flashmla,
"disable_radix_cache": server_args.disable_radix_cache,
"flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
"debug_tensor_dump_output_folder": server_args.debug_tensor_dump_output_folder,
......@@ -225,15 +224,7 @@ class ModelRunner:
def model_specific_adjustment(self):
server_args = self.server_args
if server_args.enable_flashinfer_mla:
# TODO: remove this branch after enable_flashinfer_mla is deprecated
logger.info("MLA optimization is turned on. Use flashinfer backend.")
server_args.attention_backend = "flashinfer"
elif server_args.enable_flashmla:
# TODO: remove this branch after enable_flashmla is deprecated
logger.info("MLA optimization is turned on. Use flashmla decode.")
server_args.attention_backend = "flashmla"
elif server_args.attention_backend is None:
if server_args.attention_backend is None:
# By default, use flashinfer for non-mla attention and triton for mla attention
if not self.use_mla_backend:
if (
......@@ -259,7 +250,12 @@ class ModelRunner:
elif self.use_mla_backend:
# TODO: add MLA optimization on CPU
if server_args.device != "cpu":
if server_args.attention_backend in ["flashinfer", "fa3", "triton"]:
if server_args.attention_backend in [
"flashinfer",
"fa3",
"triton",
"flashmla",
]:
logger.info(
f"MLA optimization is turned on. Use {server_args.attention_backend} backend."
)
......
......@@ -179,8 +179,6 @@ class ServerArgs:
tool_call_parser: Optional[str] = None
enable_hierarchical_cache: bool = False
hicache_ratio: float = 2.0
enable_flashinfer_mla: bool = False # TODO: remove this argument
enable_flashmla: bool = False
flashinfer_mla_disable_ragged: bool = False
warmups: Optional[str] = None
n_share_experts_fusion: int = 0
......@@ -254,7 +252,7 @@ class ServerArgs:
assert self.chunked_prefill_size % self.page_size == 0
if self.enable_flashmla is True:
if self.attention_backend == "flashmla":
logger.warning(
"FlashMLA only supports a page_size of 64, change page_size to 64."
)
......@@ -823,7 +821,7 @@ class ServerArgs:
parser.add_argument(
"--attention-backend",
type=str,
choices=["flashinfer", "triton", "torch_native", "fa3"],
choices=["flashinfer", "triton", "torch_native", "fa3", "flashmla"],
default=ServerArgs.attention_backend,
help="Choose the kernels for attention layers.",
)
......@@ -843,13 +841,13 @@ class ServerArgs:
)
parser.add_argument(
"--enable-flashinfer-mla",
action="store_true",
help="Enable FlashInfer MLA optimization. This argument will be deprecated soon! Please use '--attention-backend flashinfer' instead for switching on flashfiner mla!",
action=DeprecatedAction,
help="--enable-flashinfer-mla is deprecated. Please use '--attention-backend flashinfer' instead.",
)
parser.add_argument(
"--enable-flashmla",
action="store_true",
help="Enable FlashMLA decode optimization",
action=DeprecatedAction,
help="--enable-flashmla is deprecated. Please use '--attention-backend flashmla' instead.",
)
parser.add_argument(
"--flashinfer-mla-disable-ragged",
......
......@@ -176,16 +176,11 @@ def main(args, server_args):
]
)
if server_args.enable_flashinfer_mla:
if server_args.attention_backend:
other_args.extend(
[
"--enable-flashinfer-mla",
]
)
if server_args.enable_flashmla:
other_args.extend(
[
"--enable-flashmla",
"--attention-backend",
server_args.attention_backend,
]
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment