Unverified commit 77225d60 authored by b8zhong, committed by GitHub

Use Flashinfer TRT-LLM as Llama 4 compatible MoE backend (#11928)

parent 9c6e25d2
@@ -643,6 +643,7 @@ class ModelConfig:
             "petit_nvfp4",
         ]
         compatible_quantization_methods = {
+            "modelopt_fp8": ["modelopt"],
             "modelopt_fp4": ["modelopt"],
             "petit_nvfp4": ["modelopt"],
             "w8a8_int8": ["compressed-tensors", "compressed_tensors"],
......
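The new "modelopt_fp8" entry extends the compatibility table that maps a requested quantization override to checkpoint quantization methods it can serve. A rough, hypothetical sketch of how such a table is typically consulted (the helper below is illustrative, not the actual ModelConfig code):

```python
# Hypothetical sketch: accept a requested method if the checkpoint's method
# is listed as compatible (e.g. requesting "modelopt_fp8" for a ModelOpt checkpoint).
compatible_quantization_methods = {
    "modelopt_fp8": ["modelopt"],
    "modelopt_fp4": ["modelopt"],
    "petit_nvfp4": ["modelopt"],
    "w8a8_int8": ["compressed-tensors", "compressed_tensors"],
}

def is_override_compatible(requested: str, checkpoint_method: str) -> bool:
    # Exact match always works; otherwise consult the compatibility table.
    if requested == checkpoint_method:
        return True
    return checkpoint_method in compatible_quantization_methods.get(requested, [])

assert is_override_compatible("modelopt_fp8", "modelopt")
```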
@@ -25,6 +25,7 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
+from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
 from sglang.srt.layers.quantization.fp8_utils import (
     apply_fp8_linear,
     cutlass_fp8_supported,
@@ -468,8 +469,6 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
         # Fp8 moe kernel needs single weight scale for w13 per expert.
         # We take the max of the w1 and w3 scales then dequant and requant each expert.
         if layer.w13_weight_scale.dim() == 2:  # Shape: (num_experts, 2)
-            from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
-
             # Get the maximum scale across w1 and w3 for each expert
             max_w13_scales = layer.w13_weight_scale.max(dim=1).values
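The surrounding comments describe merging the separate w1/w3 scales into a single per-expert scale by dequantizing with the original scale and requantizing with the max. A toy sketch of that idea using plain torch ops (an approximation, not the actual scaled_fp8_quant-based loop in this file):

```python
import torch

# Toy illustration: dequantize each half with its own scale, then requantize
# both halves with the shared max scale so the fused MoE kernel sees one scale.
finfo = torch.finfo(torch.float8_e4m3fn)

w1_fp8 = torch.randn(4, 8).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
w3_fp8 = torch.randn(4, 8).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
s1, s3 = torch.tensor(0.02), torch.tensor(0.05)

max_scale = torch.max(s1, s3)
w1_requant = (w1_fp8.float() * s1 / max_scale).to(torch.float8_e4m3fn)
w3_requant = (w3_fp8.float() * s3 / max_scale).to(torch.float8_e4m3fn)
# Both halves now share `max_scale`.
```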
@@ -517,6 +516,84 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
             layer.w2_input_scale.max(), requires_grad=False
         )

+        # Align FP8 weights to FlashInfer per-tensor kernel layout if enabled
+        if should_use_flashinfer_trtllm_moe():
+            from flashinfer import reorder_rows_for_gated_act_gemm, shuffle_matrix_a
+
+            # 1) Swap W13 halves: [Up, Gate] -> [Gate, Up] expected by FI
+            num_experts, two_n, hidden = layer.w13_weight.shape
+            inter = two_n // 2
+            w13_swapped = (
+                layer.w13_weight.reshape(num_experts, 2, inter, hidden)
+                .flip(dims=[1])
+                .reshape(num_experts, two_n, hidden)
+            )
+
+            # 2) Reorder rows for fused gated activation (W13)
+            w13_interleaved = [
+                reorder_rows_for_gated_act_gemm(w13_swapped[i])
+                for i in range(num_experts)
+            ]
+            w13_interleaved = torch.stack(w13_interleaved).reshape(
+                num_experts, two_n, hidden
+            )
+
+            # 3) Shuffle weights for transposed MMA output (both W13, W2)
+            epilogue_tile_m = 128
+            w13_shuffled = [
+                shuffle_matrix_a(w13_interleaved[i].view(torch.uint8), epilogue_tile_m)
+                for i in range(num_experts)
+            ]
+            w2_shuffled = [
+                shuffle_matrix_a(layer.w2_weight[i].view(torch.uint8), epilogue_tile_m)
+                for i in range(num_experts)
+            ]
+
+            layer.w13_weight = Parameter(
+                torch.stack(w13_shuffled).view(torch.float8_e4m3fn),
+                requires_grad=False,
+            )
+            layer.w2_weight = Parameter(
+                torch.stack(w2_shuffled).view(torch.float8_e4m3fn),
+                requires_grad=False,
+            )
+
+        # Precompute and register per-expert output scaling factors for FI MoE
+        if should_use_flashinfer_trtllm_moe():
+            # Note: w13_input_scale and w2_input_scale are scalar Parameters post-reduction
+            assert (
+                hasattr(layer, "w13_input_scale") and layer.w13_input_scale is not None
+            )
+            assert hasattr(layer, "w2_input_scale") and layer.w2_input_scale is not None
+            assert (
+                hasattr(layer, "w13_weight_scale")
+                and layer.w13_weight_scale is not None
+            )
+            assert (
+                hasattr(layer, "w2_weight_scale") and layer.w2_weight_scale is not None
+            )
+
+            input_scale = layer.w13_input_scale.to(torch.float32)
+            activation_scale = layer.w2_input_scale.to(torch.float32)
+            w13_weight_scale = layer.w13_weight_scale.to(torch.float32)
+            w2_weight_scale = layer.w2_weight_scale.to(torch.float32)
+
+            output1_scales_scalar = (
+                w13_weight_scale * input_scale * (1.0 / activation_scale)
+            )
+            output1_scales_gate_scalar = w13_weight_scale * input_scale
+            output2_scales_scalar = activation_scale * w2_weight_scale
+
+            layer.output1_scales_scalar = Parameter(
+                output1_scales_scalar, requires_grad=False
+            )
+            layer.output1_scales_gate_scalar = Parameter(
+                output1_scales_gate_scalar, requires_grad=False
+            )
+            layer.output2_scales_scalar = Parameter(
+                output2_scales_scalar, requires_grad=False
+            )
+
     def create_moe_runner(
         self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
     ):
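The half-swap in step 1 above relies on a reshape/flip/reshape round trip along the gate/up dimension. A minimal self-check with toy shapes (illustrative sizes only, not the real model dimensions):

```python
import torch

# Toy check of the [Up, Gate] -> [Gate, Up] half-swap used above.
num_experts, inter, hidden = 2, 3, 4
w13 = torch.arange(num_experts * 2 * inter * hidden, dtype=torch.float32).reshape(
    num_experts, 2 * inter, hidden
)

swapped = (
    w13.reshape(num_experts, 2, inter, hidden)
    .flip(dims=[1])
    .reshape(num_experts, 2 * inter, hidden)
)

# The first `inter` rows of each expert now hold what used to be the last
# `inter` rows, and vice versa.
assert torch.equal(swapped[:, :inter], w13[:, inter:])
assert torch.equal(swapped[:, inter:], w13[:, :inter])
```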
@@ -528,6 +605,81 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
         layer: torch.nn.Module,
         dispatch_output: StandardDispatchOutput,
     ) -> CombineInput:
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
+
+        # Fast path: TRT-LLM FP8 per-tensor MoE using BYPASSED TopK routing
+        from sglang.srt.layers.moe.topk import TopKOutputChecker
+
+        if should_use_flashinfer_trtllm_moe() and TopKOutputChecker.format_is_bypassed(
+            topk_output
+        ):
+            router_logits = topk_output.router_logits
+            topk_config = topk_output.topk_config
+
+            # Constraints
+            assert (
+                self.moe_runner_config.activation == "silu"
+            ), "Only silu is supported for flashinfer fp8 moe"
+
+            from flashinfer import RoutingMethodType
+            from flashinfer.fused_moe import trtllm_fp8_per_tensor_scale_moe
+
+            correction_bias = (
+                None
+                if topk_config.correction_bias is None
+                else topk_config.correction_bias
+            )
+
+            # Pre-quantize activations to FP8 per-tensor using provided input scale
+            x_fp8, _ = scaled_fp8_quant(x, layer.w13_input_scale)
+
+            use_routing_scales_on_input = True
+            routed_scaling_factor = self.moe_runner_config.routed_scaling_factor
+
+            # Enforce Llama4 routing for ModelOpt FP8 MoE for now.
+            # TODO(brayden): support other routing methods
+            assert topk_config.top_k == 1, "ModelOpt FP8 MoE requires top_k==1"
+            assert (
+                not topk_config.num_expert_group
+            ), "ModelOpt FP8 MoE does not support expert grouping"
+            assert (
+                not topk_config.topk_group
+            ), "ModelOpt FP8 MoE does not support grouped top-k"
+            routing_method_type = RoutingMethodType.Llama4
+
+            # FlashInfer TRTLLM requires routing_logits (and bias) to be bfloat16
+            routing_logits_cast = router_logits.to(torch.bfloat16)
+            routing_bias_cast = (
+                None if correction_bias is None else correction_bias.to(torch.bfloat16)
+            )
+
+            output = trtllm_fp8_per_tensor_scale_moe(
+                routing_logits=routing_logits_cast,
+                routing_bias=routing_bias_cast,
+                hidden_states=x_fp8,
+                gemm1_weights=layer.w13_weight,
+                output1_scales_scalar=layer.output1_scales_scalar,
+                output1_scales_gate_scalar=layer.output1_scales_gate_scalar,
+                gemm2_weights=layer.w2_weight,
+                output2_scales_scalar=layer.output2_scales_scalar,
+                num_experts=layer.num_experts,
+                top_k=topk_config.top_k,
+                n_group=0,
+                topk_group=0,
+                intermediate_size=layer.w2_weight.shape[2],
+                local_expert_offset=layer.moe_ep_rank * layer.num_local_experts,
+                local_num_experts=layer.num_local_experts,
+                routed_scaling_factor=(
+                    routed_scaling_factor if routed_scaling_factor is not None else 1.0
+                ),
+                use_routing_scales_on_input=use_routing_scales_on_input,
+                tile_tokens_dim=None,
+                routing_method_type=routing_method_type,
+            )
+
+            from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+            return StandardCombineInput(hidden_states=output)
+
         quant_info = TritonMoeQuantInfo(
             w13_weight=layer.w13_weight,
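In the fast path above, scaled_fp8_quant performs per-tensor FP8 activation quantization with the precomputed w13_input_scale. A rough sketch of the per-tensor semantics using plain torch ops (an approximation for intuition, not the kernel's implementation):

```python
import torch

def per_tensor_fp8_quant_sketch(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Rough per-tensor FP8 (e4m3) quantization: divide by a precomputed scale,
    clamp to the representable range, and cast. Dequantization is q.float() * scale."""
    finfo = torch.finfo(torch.float8_e4m3fn)
    q = (x.float() / scale).clamp(finfo.min, finfo.max)
    return q.to(torch.float8_e4m3fn)

x = torch.randn(4, 8)
scale = x.abs().max() / torch.finfo(torch.float8_e4m3fn).max
q = per_tensor_fp8_quant_sketch(x, scale)
x_hat = q.float() * scale  # approximate reconstruction of x
```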
@@ -1384,8 +1536,6 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         alt_stream=None,
     ) -> CombineInput:
-        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
-
         x = dispatch_output.hidden_states
         topk_output = dispatch_output.topk_output
@@ -1398,6 +1548,8 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         # Check if this is a FlashInferFP4MoE layer that should handle its own forward
         if hasattr(layer, "gemm1_weights_fp4_shuffled"):
             # This layer was processed with flashinfer TRTLLM - delegate to its own forward
+            from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
             return StandardCombineInput(hidden_states=layer.forward(x, topk_output))

         if self.enable_flashinfer_cutlass_moe:
@@ -1466,6 +1618,8 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
             if forward_shared_experts is not None:
                 torch.cuda.current_stream().wait_stream(alt_stream)

+            from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
             return StandardCombineInput(hidden_states=output)

         from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
@@ -1487,6 +1641,8 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
             apply_router_weight_on_input=moe_runner_config.apply_router_weight_on_input,
         ).to(x.dtype)

         # Scale by routed_scaling_factor is fused into select_experts.
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
         return StandardCombineInput(hidden_states=output)

     def apply_without_routing_weights(
......
@@ -238,7 +238,7 @@ def get_quant_config(
     if model_config.quantization == "bitsandbytes":
         config["adapter_name_or_path"] = model_name_or_path
     elif model_config.quantization.startswith("modelopt") and (
-        config["producer"]["name"].startswith("modelopt")
+        config.get("producer", {}).get("name", "").startswith("modelopt")
     ):
         quant_algo = config["quantization"]["quant_algo"]
         if quant_algo is None:
......
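The guarded lookup in the hunk above avoids a KeyError for quantization configs that carry no "producer" section. A minimal illustration with hypothetical configs:

```python
# Minimal illustration (hypothetical configs) of the guarded lookup above.
config_with_producer = {"producer": {"name": "modelopt"}, "quantization": {"quant_algo": "FP8"}}
config_without_producer = {"quantization": {"quant_algo": "FP8"}}

for config in (config_with_producer, config_without_producer):
    # config["producer"]["name"] would raise KeyError on the second config;
    # the chained .get() degrades to an empty string instead.
    is_modelopt = config.get("producer", {}).get("name", "").startswith("modelopt")
    print(is_modelopt)  # True, then False
```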
@@ -971,6 +971,11 @@ class ServerArgs:
                 logger.warning(
                     "Use trtllm_mha as attention backend on sm100 for Llama4 model"
                 )
+            if is_sm100_supported() and self.moe_runner_backend == "auto":
+                self.moe_runner_backend = "flashinfer_trtllm"
+                logger.info(
+                    "Use flashinfer_trtllm as MoE runner backend on SM100 for Llama4"
+                )
         elif model_arch in [
             "Gemma2ForCausalLM",
             "Gemma3ForCausalLM",
@@ -1336,8 +1341,10 @@ class ServerArgs:
         if self.moe_runner_backend == "flashinfer_trtllm":
             assert (
-                self.quantization == "modelopt_fp4" or self.quantization == "fp8"
-            ), "modelopt_fp4 or fp8 quantization is required for Flashinfer TRTLLM MoE"
+                self.quantization == "modelopt_fp4"
+                or self.quantization == "modelopt_fp8"
+                or self.quantization == "fp8"
+            ), "modelopt_fp4, modelopt_fp8 or fp8 quantization is required for Flashinfer TRTLLM MoE"
             self.disable_shared_experts_fusion = True
             logger.warning(
                 "FlashInfer TRTLLM MoE is enabled. --disable-shared-experts-fusion is automatically set."
......
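Taken together, the two server-args hunks auto-select the FlashInfer TRT-LLM MoE runner on SM100 for Llama 4 and widen the accepted quantization methods. A hedged, standalone sketch of that selection and validation logic (the function and its wiring are illustrative, not the actual ServerArgs code):

```python
# Sketch of the behavior added above: auto-select the FlashInfer TRT-LLM MoE
# runner on SM100 (Blackwell) and validate the quantization method.
SUPPORTED_TRTLLM_MOE_QUANT = {"modelopt_fp4", "modelopt_fp8", "fp8"}

def pick_moe_runner_backend(requested: str, quantization: str, sm100: bool) -> str:
    backend = requested
    if backend == "auto" and sm100:
        backend = "flashinfer_trtllm"
    if backend == "flashinfer_trtllm" and quantization not in SUPPORTED_TRTLLM_MOE_QUANT:
        raise ValueError(
            "modelopt_fp4, modelopt_fp8 or fp8 quantization is required for Flashinfer TRTLLM MoE"
        )
    return backend

print(pick_moe_runner_backend("auto", "modelopt_fp8", sm100=True))  # flashinfer_trtllm
```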