"examples/vscode:/vscode.git/clone" did not exist on "fd02aad4029e7bbe4f49d06847ad1cded34d9eb2"
Unverified Commit f80371ff authored by b8zhong, committed by GitHub

Use the flashinfer_trtllm MoE runner backend to gain around 10% perf on B200 FP8 DeepSeek (#11816)

parent 62eff37b
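
For orientation, below is a minimal sketch of the default-selection behavior this commit introduces for DeepseekV3ForCausalLM on sm100. The standalone helper is hypothetical; in the actual change the logic lives inline in ServerArgs (see the last hunk below).

def resolve_deepseek_sm100_defaults(moe_runner_backend, quantization):
    # Illustrative sketch only; mirrors the ServerArgs defaults added in this commit.
    if moe_runner_backend == "auto":
        # New default on sm100: the flashinfer_trtllm MoE runner backend
        # (roughly 10% perf gain on B200 FP8 DeepSeek, per the commit title).
        moe_runner_backend = "flashinfer_trtllm"
    if quantization is None:
        # DeepSeek V3/R1 ships native FP8 weights, and the flashinfer_trtllm
        # backend asserts on the quantization config, so default it to fp8.
        quantization = "fp8"
    return moe_runner_backend, quantization

# Example: resolve_deepseek_sm100_defaults("auto", None) -> ("flashinfer_trtllm", "fp8")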
@@ -232,7 +232,7 @@ class FusedMoE(torch.nn.Module):
self.quant_method, ModelOptNvFp4FusedMoEMethod
) or (
isinstance(self.quant_method, Fp8MoEMethod)
and self.quant_method.use_cutlass_fused_experts_fp8
and self.quant_method._should_use_cutlass_fused_experts()
)
def _load_per_tensor_weight_scale(
......
@@ -33,6 +33,7 @@ from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
from sglang.srt.layers.moe.moe_runner.deep_gemm import DeepGemmMoeQuantInfo
from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
from sglang.srt.layers.moe.utils import get_moe_runner_backend
from sglang.srt.layers.parameter import (
BlockQuantScaleParameter,
ModelWeightParameter,
@@ -525,12 +526,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
self.quant_config = quant_config
self.block_quant = self.quant_config.weight_block_size is not None
self.cutlass_fp8_supported = cutlass_fp8_supported()
self.use_cutlass_fused_experts_fp8 = (
get_bool_env_var("SGLANG_CUTLASS_MOE")
and self.cutlass_fp8_supported
and self.block_quant
and (is_sm100_supported() or is_sm90_supported())
)
def create_weights(
self,
@@ -638,58 +633,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
assert self.quant_config.activation_scheme == "dynamic"
if self.use_cutlass_fused_experts_fp8:
self.ab_strides1 = torch.full(
(num_experts,),
hidden_size,
device=w13_weight.device,
dtype=torch.int64,
)
self.c_strides1 = torch.full(
(num_experts,),
2 * intermediate_size_per_partition,
device=w13_weight.device,
dtype=torch.int64,
)
self.ab_strides2 = torch.full(
(num_experts,),
intermediate_size_per_partition,
device=w2_weight.device,
dtype=torch.int64,
)
self.c_strides2 = torch.full(
(num_experts,),
hidden_size,
device=w2_weight.device,
dtype=torch.int64,
)
self.workspace = torch.empty(
90000, device=w13_weight.device, dtype=torch.uint8
)
self.a_ptr = torch.empty(
num_experts, device=w13_weight.device, dtype=torch.int64
)
self.b_ptr = torch.empty(
num_experts, device=w13_weight.device, dtype=torch.int64
)
self.out_ptr = torch.empty(
num_experts, device=w13_weight.device, dtype=torch.int64
)
self.a_scales_ptr = torch.empty(
num_experts, device=w13_weight.device, dtype=torch.int64
)
self.b_scales_ptr = torch.empty(
num_experts, device=w13_weight.device, dtype=torch.int64
)
self.expert_offsets = torch.empty(
num_experts + 1, device=w13_weight.device, dtype=torch.int32
)
self.problem_sizes1 = torch.empty(
num_experts, 3, device=w13_weight.device, dtype=torch.int32
)
self.problem_sizes2 = torch.empty(
num_experts, 3, device=w13_weight.device, dtype=torch.int32
)
if self._should_use_cutlass_fused_experts():
self._ensure_cutlass_buffers_initialized(layer)
else:
# Allocate 2 scales for w1 and w3 respectively.
@@ -1079,7 +1024,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
if ret is not None:
return StandardCombineInput(hidden_states=ret)
if self.use_cutlass_fused_experts_fp8:
if self._should_use_cutlass_fused_experts():
from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts_fp8
topk_weights, topk_ids, _ = topk_output
@@ -1171,6 +1116,67 @@ class Fp8MoEMethod(FusedMoEMethodBase):
return self.runner.run(dispatch_output, quant_info)
def _should_use_cutlass_fused_experts(self) -> bool:
"""Decide whether to use Cutlass FP8 fused-experts path based on moe runner backend,
with env var override via `SGLANG_CUTLASS_MOE`.
"""
backend = get_moe_runner_backend()
env_force = get_bool_env_var("SGLANG_CUTLASS_MOE")
        # TODO: remove this env var in the future; it should be handled by the MoE runner backend
if env_force:
return True
return (
backend.is_flashinfer_cutlass()
and self.cutlass_fp8_supported
and self.block_quant
and (is_sm100_supported() or is_sm90_supported())
)
def _ensure_cutlass_buffers_initialized(self, layer: Module) -> None:
if getattr(self, "_cutlass_buffers_ready", False):
return
device = layer.w13_weight.device
num_experts = layer.w13_weight.shape[0]
hidden_size = layer.w2_weight.shape[1]
intermediate_size_per_partition = layer.intermediate_size_per_partition
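        # Per-expert stride tensors for the two grouped GEMMs of the Cutlass FP8
        # fused-experts path: GEMM1 maps hidden_size -> 2 * intermediate size
        # (gate and up projections fused in w13), GEMM2 maps intermediate size
        # back to hidden_size (w2 down projection).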
self.ab_strides1 = torch.full(
(num_experts,), hidden_size, device=device, dtype=torch.int64
)
self.c_strides1 = torch.full(
(num_experts,),
2 * intermediate_size_per_partition,
device=device,
dtype=torch.int64,
)
self.ab_strides2 = torch.full(
(num_experts,),
intermediate_size_per_partition,
device=device,
dtype=torch.int64,
)
self.c_strides2 = torch.full(
(num_experts,), hidden_size, device=device, dtype=torch.int64
)
self.workspace = torch.empty(90000, device=device, dtype=torch.uint8)
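        # Scratch metadata for the Cutlass FP8 fused-experts grouped GEMMs:
        # per-expert device-pointer arrays, expert offsets, and per-expert
        # problem sizes.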
self.a_ptr = torch.empty(num_experts, device=device, dtype=torch.int64)
self.b_ptr = torch.empty(num_experts, device=device, dtype=torch.int64)
self.out_ptr = torch.empty(num_experts, device=device, dtype=torch.int64)
self.a_scales_ptr = torch.empty(num_experts, device=device, dtype=torch.int64)
self.b_scales_ptr = torch.empty(num_experts, device=device, dtype=torch.int64)
self.expert_offsets = torch.empty(
num_experts + 1, device=device, dtype=torch.int32
)
self.problem_sizes1 = torch.empty(
num_experts, 3, device=device, dtype=torch.int32
)
self.problem_sizes2 = torch.empty(
num_experts, 3, device=device, dtype=torch.int32
)
self._cutlass_buffers_ready = True
def apply_with_router_logits(
self,
layer: torch.nn.Module,
......
@@ -892,14 +892,19 @@ class ServerArgs:
logger.info(
"Enable FlashInfer AllReduce Fusion on sm100 for DeepseekV3ForCausalLM"
)
if (
self.quantization == "modelopt_fp4"
and self.moe_runner_backend == "auto"
):
if self.moe_runner_backend == "auto":
self.moe_runner_backend = "flashinfer_trtllm"
logger.info(
"Use flashinfer_trtllm as moe runner backend on sm100 for DeepseekV3ForCausalLM"
"Use flashinfer_trtllm as MoE runner backend on sm100 for DeepseekV3ForCausalLM"
)
if self.quantization is None:
                # Default to DeepSeek V3/R1's native FP8 when not explicitly set,
                # because the flashinfer_trtllm MoE runner backend asserts on
                # the quantization config.
self.quantization = "fp8"
logger.info(
"Quantization not specified, default to fp8 for DeepSeek on sm100"
)
elif model_arch in ["GptOssForCausalLM"]:
if (
......