Unverified commit 9fb48f95, authored by Baizhou Zhang, committed by GitHub

Support nextn for flashinfer mla attention backend (#4218)

parent 89ccb533
@@ -84,7 +84,7 @@ Please refer to [the example](https://github.com/sgl-project/sglang/tree/main/be
- **Weight Absorption**: By applying the associative law of matrix multiplication to reorder computation steps, this method balances computation and memory access and improves efficiency in the decoding phase (see the sketch below).
- **Flashinfer MLA Wrapper**: By providing the `--enable-flashinfer-mla` argument, the server will use MLA kernels customized by Flashinfer. More details can be found in [this document](https://docs.flashinfer.ai/api/mla.html). Under long-input scenarios, flashinfer mla can improve performance significantly. Optimized triton kernels will be used when flashinfer mla is turned off. Currently, when the flashinfer mla wrapper and speculative decoding are used together, the `speculative_eagle_topk` parameter should be set to 1.
- **FP8 Quantization**: W8A8 FP8 and KV Cache FP8 quantization enable efficient FP8 inference. Additionally, we have implemented a Batched Matrix Multiplication (BMM) operator to facilitate FP8 inference in MLA with weight absorption.
...
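The **Weight Absorption** bullet above rests on the associativity of matrix multiplication. A minimal, self-contained sketch of the idea (shapes and names are generic illustrations, not the actual MLA projection parameters):

```python
import torch

# Toy low-rank projection standing in for MLA's down/up projections
# (illustrative shapes, not the real DeepSeek weights).
torch.manual_seed(0)
batch, d_model, d_rank = 4, 1024, 128
x = torch.randn(batch, d_model, dtype=torch.float64)
w_down = torch.randn(d_model, d_rank, dtype=torch.float64)
w_up = torch.randn(d_rank, d_model, dtype=torch.float64)

# Associativity: (x @ w_down) @ w_up == x @ (w_down @ w_up)
y_two_step = (x @ w_down) @ w_up   # two thin GEMMs through the rank-128 bottleneck
y_absorbed = x @ (w_down @ w_up)   # one GEMM against a pre-multiplied ("absorbed") weight

torch.testing.assert_close(y_two_step, y_absorbed)

# The MLA decoding path exploits the same freedom to reorder which products are
# formed first, trading FLOPs against memory traffic as described in the bullet.
```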
@@ -555,6 +555,8 @@ class DeepseekV2AttentionMLA(nn.Module):
            return (
                not global_server_args_dict["flashinfer_mla_disable_ragged"]
                and forward_batch.forward_mode.is_extend()
                and not forward_batch.forward_mode.is_target_verify()
                and not forward_batch.forward_mode.is_draft_extend()
                and forward_batch.extend_prefix_lens.sum() == 0
            )
        else:
...
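The two new `forward_mode` guards above keep the ragged prefill path of the Flashinfer MLA wrapper away from speculative-decoding batches. A standalone restatement of the condition, with the intent of each guard spelled out (the reasoning in the comments is inferred from the surrounding code, not stated in the commit):

```python
def use_flashinfer_mla_ragged_prefill(
    forward_mode, extend_prefix_lens_sum: int, ragged_disabled: bool
) -> bool:
    """Illustrative restatement of the check changed in this commit; the real
    condition lives inside DeepseekV2AttentionMLA."""
    return (
        not ragged_disabled                      # flashinfer_mla_disable_ragged is off
        and forward_mode.is_extend()             # prefill-style batch
        and not forward_mode.is_target_verify()  # new: exclude NextN/EAGLE verify batches
        and not forward_mode.is_draft_extend()   # new: exclude draft-model extend batches,
                                                 # which (presumably) must go through the
                                                 # paged KV-cache path instead
        and extend_prefix_lens_sum == 0          # no cached prefix to merge with
    )
```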
@@ -123,6 +123,16 @@ class EAGLEWorker(TpModelWorker):
                self.topk,
                self.speculative_num_steps,
            )
        elif self.server_args.attention_backend == "flashinfer_mla":
            from sglang.srt.layers.attention.flashinfer_mla_backend import (
                FlashInferMLAMultiStepDraftBackend,
            )

            self.draft_attn_backend = FlashInferMLAMultiStepDraftBackend(
                self.model_runner,
                self.topk,
                self.speculative_num_steps,
            )
        else:
            raise ValueError(
                f"EAGLE is not supported in attention backend {self.server_args.attention_backend}"
            )
...
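The new `flashinfer_mla` branch in `EAGLEWorker` is reached by enabling both the Flashinfer MLA wrapper and EAGLE/NextN speculative decoding at launch. A sketch of such a launch, mirroring the flags used by the CI test below (model paths are the CI test models; `--speculative-eagle-topk` stays at 1 as required by the documentation change above):

```python
import subprocess

# Flags mirror TestFlashinferMLAMTP below; adjust model paths and batch-size
# limits for your own hardware.
cmd = [
    "python", "-m", "sglang.launch_server",
    "--model-path", "lmsys/sglang-ci-dsv3-test",
    "--trust-remote-code",
    "--enable-flashinfer-mla",            # Flashinfer MLA attention backend
    "--speculative-algorithm", "EAGLE",   # NextN draft model runs through the EAGLE worker
    "--speculative-draft", "lmsys/sglang-ci-dsv3-test-NextN",
    "--speculative-num-steps", "4",
    "--speculative-eagle-topk", "1",      # must be 1 with the flashinfer MLA wrapper
    "--speculative-num-draft-tokens", "4",
]
server = subprocess.Popen(cmd)  # call server.terminate() when done
```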
import unittest
from types import SimpleNamespace

import requests
import torch

from sglang.srt.utils import kill_process_tree
...
@@ -100,5 +101,67 @@ class TestFlashinferMLANoRagged(unittest.TestCase):
        self.assertGreater(metrics["accuracy"], 0.62)

class TestFlashinferMLAMTP(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.model = "lmsys/sglang-ci-dsv3-test"
        cls.base_url = DEFAULT_URL_FOR_TEST
        other_args = ["--trust-remote-code"]
        if torch.cuda.is_available() and torch.version.cuda:
            other_args.extend(
                [
                    "--cuda-graph-max-bs",
                    "2",
                    "--disable-radix",
                    "--enable-torch-compile",
                    "--torch-compile-max-bs",
                    "1",
                    "--speculative-algorithm",
                    "EAGLE",
                    "--speculative-draft",
                    "lmsys/sglang-ci-dsv3-test-NextN",
                    "--speculative-num-steps",
                    "4",
                    "--speculative-eagle-topk",
                    "1",
                    "--speculative-num-draft-tokens",
                    "4",
                    "--enable-flashinfer-mla",
                ]
            )
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=other_args,
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def test_gsm8k(self):
        requests.get(self.base_url + "/flush_cache")

        args = SimpleNamespace(
            num_shots=5,
            data_path=None,
            num_questions=200,
            max_new_tokens=512,
            parallel=128,
            host="http://127.0.0.1",
            port=int(self.base_url.split(":")[-1]),
        )
        metrics = run_eval_few_shot_gsm8k(args)
        print(metrics)
        self.assertGreater(metrics["accuracy"], 0.60)

        server_info = requests.get(self.base_url + "/get_server_info")
        avg_spec_accept_length = server_info.json()["avg_spec_accept_length"]
        print(f"{avg_spec_accept_length=}")
        self.assertGreater(avg_spec_accept_length, 2.5)

if __name__ == "__main__":
    unittest.main()
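Outside the unittest harness, the acceptance statistic asserted at the end of `test_gsm8k` can be read directly from a running server; a minimal sketch using the same endpoint (the base URL is an assumption matching sglang's default local port):

```python
import requests

base_url = "http://127.0.0.1:30000"  # adjust to wherever the server was launched

# /get_server_info exposes runtime statistics; avg_spec_accept_length is the
# average number of tokens accepted per speculative decoding step.
info = requests.get(base_url + "/get_server_info").json()
print("avg_spec_accept_length:", info["avg_spec_accept_length"])
```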