Unverified commit 8374a96e, authored by Xiaoyu Zhang and committed by GitHub

piecewise cuda graph support qwen3-moe (#11845)

parent 74de76c6
@@ -212,6 +212,10 @@ class LayerCommunicator:
             )
         )
+        self._speculative_algo = SpeculativeAlgorithm.from_string(
+            get_global_server_args().speculative_algorithm
+        )
+
     def prepare_attn(
         self,
         hidden_states: torch.Tensor,
@@ -315,13 +319,10 @@ class LayerCommunicator:
     def should_fuse_mlp_allreduce_with_next_layer(
         self, forward_batch: ForwardBatch
     ) -> bool:
-        speculative_algo = SpeculativeAlgorithm.from_string(
-            get_global_server_args().speculative_algorithm
-        )
         if (
             is_dp_attention_enabled()
-            and speculative_algo is not None
-            and speculative_algo.is_eagle()
+            and self._speculative_algo is not None
+            and self._speculative_algo.is_eagle()
         ):
             return False
......
@@ -1831,3 +1831,21 @@ def triton_scaled_mm(
     )
     return result.to(out_dtype)
+
+
+if _is_cuda:
+    if enable_sgl_per_token_group_quant_8bit:
+
+        @torch.library.register_fake("sgl_kernel::sgl_per_token_group_quant_8bit")
+        def _(
+            input, output_q, output_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
+        ):
+            return
+
+    else:
+
+        @torch.library.register_fake("sgl_kernel::sgl_per_token_group_quant_fp8")
+        def _(
+            input, output_q, output_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
+        ):
+            return
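A note on the fake-kernel registrations above: torch.compile traces through custom ops with fake (meta) tensors, so each sgl_kernel op it encounters needs a registered fake implementation that describes the op without executing it. The two quant kernels appear to write into the preallocated output_q/output_s buffers, which is presumably why their fake implementations simply return. Below is a minimal, self-contained sketch of the same mechanism on a recent PyTorch (2.4+); the demo::scale op is hypothetical and not part of sgl_kernel.

import torch


# Hypothetical custom op, used only to illustrate torch.library.register_fake.
@torch.library.custom_op("demo::scale", mutates_args=())
def scale(x: torch.Tensor, factor: float) -> torch.Tensor:
    return x * factor


# Fake (meta) implementation: produces only shapes/dtypes, no real compute.
# torch.compile uses this while tracing instead of running the kernel.
@torch.library.register_fake("demo::scale")
def _(x, factor):
    return torch.empty_like(x)


compiled = torch.compile(lambda t: torch.ops.demo.scale(t, 2.0), backend="eager")
print(compiled(torch.ones(4)))  # tensor([2., 2., 2., 2.])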
@@ -17,6 +17,7 @@
 """Inference-only Qwen2MoE model compatible with HuggingFace weights."""
 
 import logging
+from contextlib import nullcontext
 from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
 
 import torch
@@ -590,7 +591,12 @@ class Qwen2MoeModel(nn.Module):
                     if residual is not None
                     else hidden_states
                 )
-            with get_global_expert_distribution_recorder().with_current_layer(i):
+            ctx = (
+                nullcontext()
+                if get_global_server_args().enable_piecewise_cuda_graph
+                else get_global_expert_distribution_recorder().with_current_layer(i)
+            )
+            with ctx:
                 layer = self.layers[i]
                 hidden_states, residual = layer(
                     positions, hidden_states, forward_batch, residual
......
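The Qwen2MoeModel change above swaps the expert-distribution recorder scope for contextlib.nullcontext when piecewise CUDA graph is enabled, so a single with block covers both paths. A small standalone sketch of that conditional-context-manager pattern (the Recorder class and run_layers function here are hypothetical):

from contextlib import nullcontext


class Recorder:
    """Hypothetical stand-in for a per-layer recorder context manager."""

    def __init__(self, layer_id: int):
        self.layer_id = layer_id

    def __enter__(self):
        print(f"recording layer {self.layer_id}")
        return self

    def __exit__(self, *exc):
        print(f"finished layer {self.layer_id}")


def run_layers(num_layers: int, record: bool) -> None:
    for i in range(num_layers):
        # Pick a real context manager or a no-op one; the body stays identical.
        ctx = Recorder(i) if record else nullcontext()
        with ctx:
            pass  # the layer's forward pass would run here


run_layers(2, record=True)   # prints enter/exit messages per layer
run_layers(2, record=False)  # nullcontext: nothing printed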
@@ -55,5 +55,45 @@ class TestPiecewiseCudaGraphBenchmark(CustomTestCase):
         self.assertLess(prefill_latency, 0.015)
 
 
+class TestPiecewiseCudaGraphQwen3MoE(CustomTestCase):
+    """Test piecewise CUDA graph with Qwen3-Coder-30B-A3B-Instruct MoE model"""
+
+    @classmethod
+    def setUpClass(cls):
+        cls.model = "Qwen/Qwen3-Coder-30B-A3B-Instruct"
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=[
+                "--enable-piecewise-cuda-graph",
+                "--piecewise-cuda-graph-compiler",
+                "eager",
+            ],
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_process_tree(cls.process.pid)
+
+    def test_gsm8k_accuracy(self):
+        """Test GSM8K accuracy with 8-shot setting"""
+        num_examples = 2000
+        args = SimpleNamespace(
+            base_url=self.base_url,
+            model=self.model,
+            eval_name="mgsm_en",
+            num_examples=num_examples,
+            num_threads=min(num_examples, 1024),
+        )
+
+        metrics = run_eval(args)
+        print(f"GSM8K Accuracy: {metrics['score']:.3f}")
+        self.assertGreaterEqual(metrics["score"], 0.90)
+
+
 if __name__ == "__main__":
     unittest.main()
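For reference, the new case can also be run on its own with the standard unittest loader; the module name below is assumed, since the diff does not show the test file's path.

import unittest

# Hypothetical module name; substitute the actual test file from this commit.
from test_piecewise_cuda_graph import TestPiecewiseCudaGraphQwen3MoE

suite = unittest.TestLoader().loadTestsFromTestCase(TestPiecewiseCudaGraphQwen3MoE)
unittest.TextTestRunner(verbosity=2).run(suite)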