Unverified Commit 8374a96e authored by Xiaoyu Zhang's avatar Xiaoyu Zhang Committed by GitHub
Browse files

piecewise cuda graph support qwen3-moe (#11845)

parent 74de76c6
......@@ -212,6 +212,10 @@ class LayerCommunicator:
)
)
self._speculative_algo = SpeculativeAlgorithm.from_string(
get_global_server_args().speculative_algorithm
)
def prepare_attn(
self,
hidden_states: torch.Tensor,
......@@ -315,13 +319,10 @@ class LayerCommunicator:
def should_fuse_mlp_allreduce_with_next_layer(
self, forward_batch: ForwardBatch
) -> bool:
speculative_algo = SpeculativeAlgorithm.from_string(
get_global_server_args().speculative_algorithm
)
if (
is_dp_attention_enabled()
and speculative_algo is not None
and speculative_algo.is_eagle()
and self._speculative_algo is not None
and self._speculative_algo.is_eagle()
):
return False
......
......@@ -1831,3 +1831,21 @@ def triton_scaled_mm(
)
return result.to(out_dtype)
# Register no-op "fake" (meta) implementations for the custom sgl_kernel
# quantization op so torch.compile / FakeTensor tracing can pass through it
# without executing the real CUDA kernel.
# NOTE(review): registration happens only when _is_cuda is true — presumably
# the op is only loaded with the CUDA build of sgl_kernel; confirm.
if _is_cuda:
    if enable_sgl_per_token_group_quant_8bit:
        # Flag selects which op name the installed sgl_kernel exposes:
        # the generic 8-bit variant ...
        @torch.library.register_fake("sgl_kernel::sgl_per_token_group_quant_8bit")
        def _(
            input, output_q, output_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
        ):
            # Presumably the real op writes output_q/output_s in place, so the
            # fake returns nothing — TODO confirm against sgl_kernel's schema.
            return

    else:
        # ... or the FP8-only variant with the same signature.
        @torch.library.register_fake("sgl_kernel::sgl_per_token_group_quant_fp8")
        def _(
            input, output_q, output_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
        ):
            # Same in-place contract assumed; fake returns None.
            return
......@@ -17,6 +17,7 @@
"""Inference-only Qwen2MoE model compatible with HuggingFace weights."""
import logging
from contextlib import nullcontext
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
import torch
......@@ -590,7 +591,12 @@ class Qwen2MoeModel(nn.Module):
if residual is not None
else hidden_states
)
with get_global_expert_distribution_recorder().with_current_layer(i):
ctx = (
nullcontext()
if get_global_server_args().enable_piecewise_cuda_graph
else get_global_expert_distribution_recorder().with_current_layer(i)
)
with ctx:
layer = self.layers[i]
hidden_states, residual = layer(
positions, hidden_states, forward_batch, residual
......
......@@ -55,5 +55,45 @@ class TestPiecewiseCudaGraphBenchmark(CustomTestCase):
self.assertLess(prefill_latency, 0.015)
class TestPiecewiseCudaGraphQwen3MoE(CustomTestCase):
    """Test piecewise CUDA graph with Qwen3-Coder-30B-A3B-Instruct MoE model"""

    @classmethod
    def setUpClass(cls):
        # Launch one server for the whole test class with piecewise CUDA graph
        # enabled; the "eager" compiler choice keeps the piecewise path active
        # without a heavier compile backend.
        cls.model = "Qwen/Qwen3-Coder-30B-A3B-Instruct"
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--enable-piecewise-cuda-graph",
                "--piecewise-cuda-graph-compiler",
                "eager",
            ],
        )

    @classmethod
    def tearDownClass(cls):
        # Tear down the server (and any children) started in setUpClass.
        kill_process_tree(cls.process.pid)

    def test_gsm8k_accuracy(self):
        """Test GSM8K accuracy with 8-shot setting"""
        # NOTE(review): the eval actually requested is "mgsm_en" (the English
        # MGSM split, GSM8K-derived); the "8-shot" claim is not configured
        # anywhere in this method — confirm run_eval's defaults.
        num_examples = 2000
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mgsm_en",
            num_examples=num_examples,
            # Cap client-side concurrency at 1024 threads.
            num_threads=min(num_examples, 1024),
        )
        metrics = run_eval(args)
        print(f"GSM8K Accuracy: {metrics['score']:.3f}")
        # Accuracy gate: require at least 0.90 on this eval.
        self.assertGreaterEqual(metrics["score"], 0.90)
# Standard unittest entry point when the file is run directly.
if __name__ == "__main__":
    unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment