Unverified Commit a91e90d9 authored by Trevor Morris, committed by GitHub

[2/2] Fuse routed scaling factor into select_experts (#8690)

parent f96413c4
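
This change ("[2/2]") moves the multiplication by routed_scaling_factor out of the quantized MoE apply() paths and into expert selection: the top-k routing weights are pre-scaled, so the affected expert kernels no longer rescale their summed output. Because the combined MoE output is a weighted sum of per-expert outputs, scaling the weights is mathematically equivalent to scaling the output. A minimal sketch of that equivalence (standalone illustration, not code from this commit):

import torch

def moe_combine(weights, expert_outputs):
    # sum_k w_k * E_k(x) for each token
    return torch.einsum("tk,tkd->td", weights, expert_outputs)

tokens, topk, hidden = 4, 2, 8
weights = torch.rand(tokens, topk)
expert_outputs = torch.randn(tokens, topk, hidden)
s = 2.5  # routed_scaling_factor

scaled_after = s * moe_combine(weights, expert_outputs)    # old: scale the MoE output
scaled_in_topk = moe_combine(s * weights, expert_outputs)  # new: scale the routing weights
assert torch.allclose(scaled_after, scaled_in_topk, atol=1e-5)
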
@@ -28,6 +28,7 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
+from sglang.srt.layers.quantization.fp8 import Fp8MoEMethod
 from sglang.srt.layers.quantization.modelopt_quant import ModelOptNvFp4FusedMoEMethod
 from sglang.srt.layers.quantization.unquant import UnquantizedFusedMoEMethod
 from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -923,6 +924,12 @@ class FusedMoE(torch.nn.Module):
             for shard_id in ["w1", "w2", "w3"]
         ]
 
+    def should_fuse_routed_scaling_factor_in_topk(self):
+        return isinstance(self.quant_method, ModelOptNvFp4FusedMoEMethod) or (
+            isinstance(self.quant_method, Fp8MoEMethod)
+            and self.quant_method.use_cutlass_fused_experts_fp8
+        )
+
 
 class FlashInferFusedMoE(FusedMoE):
     def __init__(self, *args, **kwargs):
...
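
The new FusedMoE helper reports whether the active quantization backend expects pre-scaled top-k weights; it returns True only for the ModelOpt NvFP4 fused-MoE path and for the FP8 path when its Cutlass fused-experts kernel is enabled. The DeepseekV2MoE hunk further down consumes it when building the router; a self-contained sketch of that flow (the stub class is illustrative, not SGLang code):

class _ExpertsStub:
    # Stand-in for a FusedMoE layer; the real method inspects self.quant_method.
    def should_fuse_routed_scaling_factor_in_topk(self) -> bool:
        return True

experts = _ExpertsStub()
router_kwargs = dict(
    routed_scaling_factor=2.5,
    apply_routed_scaling_factor_on_output=(
        experts.should_fuse_routed_scaling_factor_in_topk()
    ),
)
print(router_kwargs)
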
@@ -197,6 +197,7 @@ class TopK(CustomOp):
         scoring_func: str = "softmax",
         correction_bias: Optional[torch.Tensor] = None,
         routed_scaling_factor: Optional[float] = None,
+        apply_routed_scaling_factor_on_output: Optional[bool] = False,
     ):
         # NOTE: scoring_func is not used for now, but we keep it for future use
         # see https://github.com/sgl-project/sglang/pull/4505 for more details
@@ -215,6 +216,7 @@ class TopK(CustomOp):
             custom_routing_function=custom_routing_function,
             correction_bias=correction_bias,
             routed_scaling_factor=routed_scaling_factor,
+            apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
         )
         self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()
@@ -433,6 +435,7 @@ def grouped_topk_gpu(
     routed_scaling_factor: Optional[float] = None,
     num_token_non_padded: Optional[torch.Tensor] = None,
     expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
+    apply_routed_scaling_factor_on_output: Optional[bool] = False,
 ):
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
@@ -480,6 +483,8 @@ def grouped_topk_gpu(
             else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
         )
         topk_weights = topk_weights / topk_weights_sum
+        if apply_routed_scaling_factor_on_output:
+            topk_weights *= routed_scaling_factor
 
     topk_weights, topk_ids = topk_weights.to(torch.float32), topk_ids.to(torch.int32)
     topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
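
With the flag set, grouped_topk_gpu renormalizes the selected weights and then multiplies them by routed_scaling_factor before the final cast, so downstream expert kernels receive already-scaled weights. A tiny numeric sketch of that step (made-up values, factor 2.5 as used elsewhere in this PR):

import torch

routed_scaling_factor = 2.5
topk_weights = torch.tensor([[0.5, 0.3, 0.2]])

topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)  # renormalize
topk_weights *= routed_scaling_factor  # apply_routed_scaling_factor_on_output=True
print(topk_weights)  # tensor([[1.2500, 0.7500, 0.5000]])
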
@@ -528,6 +533,7 @@ def biased_grouped_topk_impl(
     routed_scaling_factor: Optional[float] = None,
     num_token_non_padded: Optional[torch.Tensor] = None,
     expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
+    apply_routed_scaling_factor_on_output: Optional[bool] = False,
 ):
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
@@ -579,6 +585,8 @@ def biased_grouped_topk_impl(
             else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
         )
         topk_weights = topk_weights / topk_weights_sum
+        if apply_routed_scaling_factor_on_output:
+            topk_weights *= routed_scaling_factor
 
     topk_weights, topk_ids = topk_weights.to(torch.float32), topk_ids.to(torch.int32)
     topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
@@ -621,6 +629,7 @@ def biased_grouped_topk_gpu(
     routed_scaling_factor: Optional[float] = None,
     num_token_non_padded: Optional[torch.Tensor] = None,
     expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
+    apply_routed_scaling_factor_on_output: Optional[bool] = False,
 ):
     assert (
         routed_scaling_factor is not None
@@ -640,6 +649,7 @@ def biased_grouped_topk_gpu(
             topk,
             num_fused_shared_experts,
             routed_scaling_factor,
+            apply_routed_scaling_factor_on_output,
         )
         # TODO merge into kernel
         if (expert_location_dispatch_info is not None) or (
@@ -650,6 +660,7 @@ def biased_grouped_topk_gpu(
             )
         return topk_weights, topk_ids
     elif _use_aiter:
+        assert not apply_routed_scaling_factor_on_output, "Not implemented"
         token = gating_output.shape[0]
         device = gating_output.device
         assert (
@@ -681,6 +692,7 @@ def biased_grouped_topk_gpu(
             routed_scaling_factor=routed_scaling_factor,
             num_token_non_padded=num_token_non_padded,
             expert_location_dispatch_info=expert_location_dispatch_info,
+            apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
         )
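
In the fused CUDA path the flag is simply forwarded as an extra positional argument to the moe_fused_gate kernel call, while the aiter path guards against it with an assert because that backend has not implemented the fused scaling. A sketch of that guard pattern under an assumed dispatcher name (the real logic lives in biased_grouped_topk_gpu):

def route(gating_output, apply_routed_scaling_factor_on_output=False, backend="cuda"):
    # Hypothetical dispatcher used only to illustrate the guard.
    if backend == "aiter":
        # This backend still expects unscaled weights, so refuse the fused mode.
        assert not apply_routed_scaling_factor_on_output, "Not implemented"
    ...
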
@@ -743,6 +755,9 @@ def select_experts(
     correction_bias = topk_config.correction_bias
     torch_native = topk_config.torch_native
     routed_scaling_factor = topk_config.routed_scaling_factor
+    apply_routed_scaling_factor_on_output = (
+        topk_config.apply_routed_scaling_factor_on_output
+    )
 
     router_logits, correction_bias = (
         expert_location_dispatch.transform_select_experts_inputs(
@@ -768,6 +783,7 @@ def select_experts(
                 routed_scaling_factor=routed_scaling_factor,
                 num_token_non_padded=num_token_non_padded,
                 expert_location_dispatch_info=expert_location_dispatch_info,
+                apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
             )
         else:
             topk_weights, topk_ids = biased_grouped_topk(
@@ -782,12 +798,14 @@ def select_experts(
                 routed_scaling_factor=routed_scaling_factor,
                 num_token_non_padded=num_token_non_padded,
                 expert_location_dispatch_info=expert_location_dispatch_info,
+                apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
             )
     elif torch_native and custom_routing_function is None:
         assert (
             num_token_non_padded is None
         ), "num_token_non_padded is not yet supported in fused_topk_native"
         assert expert_location_dispatch_info is None
+        assert not apply_routed_scaling_factor_on_output, "Not implemented"
        topk_weights, topk_ids = fused_topk_native(
             hidden_states=hidden_states,
             gating_output=router_logits,
@@ -795,6 +813,7 @@ def select_experts(
             renormalize=renormalize,
         )
     elif custom_routing_function is None:
+        assert not apply_routed_scaling_factor_on_output, "Not implemented"
         # Qwen3MOE uses fused_topk
         topk_weights, topk_ids = fused_topk(
             hidden_states=hidden_states,
@@ -809,6 +828,7 @@ def select_experts(
             num_token_non_padded is None
         ), "num_token_non_padded is not yet supported in custom_routing_function"
         assert expert_location_dispatch_info is None
+        assert not apply_routed_scaling_factor_on_output, "Not implemented"
         topk_weights, topk_ids = custom_routing_function(
             hidden_states=hidden_states,
             gating_output=router_logits,
...
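
Note that select_experts only threads the new flag into the grouped and biased-grouped top-k branches; the torch-native, fused_topk, and custom-routing branches assert that it is off. The contract is to scale exactly once: either the router pre-scales the weights, or the expert combine step scales its output, never both. A minimal sketch of that contract (illustrative helper, not SGLang API):

import torch

def combine_experts(topk_weights, expert_outputs, routed_scaling_factor, scaled_in_topk):
    # topk_weights: [tokens, topk]; expert_outputs: [tokens, topk, hidden]
    out = torch.einsum("tk,tkd->td", topk_weights, expert_outputs)
    if not scaled_in_topk and routed_scaling_factor is not None:
        out = out * routed_scaling_factor  # legacy behaviour: scale the output here
    return out
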
@@ -514,6 +514,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         self.quant_config = quant_config
         self.block_quant = self.quant_config.weight_block_size is not None
         self.cutlass_fp8_supported = cutlass_fp8_supported()
+        self.use_cutlass_fused_experts_fp8 = (
+            get_bool_env_var("SGLANG_CUTLASS_MOE")
+            and self.cutlass_fp8_supported
+            and self.block_quant
+            and (is_sm100_supported() or is_sm90_supported())
+        )
 
     def create_weights(
         self,
@@ -1021,12 +1027,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         if ret is not None:
             return ret
 
-        if (
-            get_bool_env_var("SGLANG_CUTLASS_MOE")
-            and self.cutlass_fp8_supported
-            and self.block_quant
-            and (is_sm100_supported() or is_sm90_supported())
-        ):
+        if self.use_cutlass_fused_experts_fp8:
             from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts_fp8
 
             topk_weights, topk_ids, _ = topk_output
@@ -1053,9 +1054,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 self.problem_sizes2,
                 use_fp8_blockscale=True,
             )
-            # TODO: Fuse into select_experts
-            if moe_runner_config.routed_scaling_factor is not None:
-                output *= moe_runner_config.routed_scaling_factor
+            # Scale by routed_scaling_factor is fused into select_experts.
             return output
         # Expert fusion with FP8 quantization
         return fused_experts(
...
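
Two things happen in this file: the Cutlass eligibility check (environment variable plus quantization and SM-version conditions) is hoisted from apply() into __init__ as use_cutlass_fused_experts_fp8, so the forward pass and should_fuse_routed_scaling_factor_in_topk() both read one precomputed flag, and the post-kernel multiplication by routed_scaling_factor is dropped because the top-k weights are already scaled. A hedged sketch of the hoisting pattern (the helper below stands in for get_bool_env_var and the SM probes):

import os

def _env_flag(name: str) -> bool:
    # Stand-in for sglang's get_bool_env_var.
    return os.environ.get(name, "false").lower() in ("1", "true", "yes")

class Fp8MoEMethodSketch:
    def __init__(self, cutlass_fp8_supported: bool, block_quant: bool, sm90_or_sm100: bool):
        # Decide once at construction time; apply() and the router both reuse it.
        self.use_cutlass_fused_experts_fp8 = (
            _env_flag("SGLANG_CUTLASS_MOE")
            and cutlass_fp8_supported
            and block_quant
            and sm90_or_sm100
        )
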
@@ -1305,8 +1305,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
                 tp_rank=layer.moe_tp_rank,
                 tune_max_num_tokens=next_power_of_2(x.shape[0]),
             )[0]
-            if moe_runner_config.routed_scaling_factor is not None:
-                output *= moe_runner_config.routed_scaling_factor
+            # Scale by routed_scaling_factor is fused into select_experts.
             if should_use_flashinfer_cutlass_moe_fp4_allgather():
                 output, global_output = get_local_dp_buffer(), output
                 get_tp_group().reduce_scatterv(
@@ -1332,6 +1331,5 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
             params=layer.cutlass_moe_params,
             apply_router_weight_on_input=moe_runner_config.apply_router_weight_on_input,
         ).to(x.dtype)
-        if moe_runner_config.routed_scaling_factor is not None:
-            output *= moe_runner_config.routed_scaling_factor
+        # Scale by routed_scaling_factor is fused into select_experts.
         return output
@@ -319,17 +319,6 @@ class DeepseekV2MoE(nn.Module):
             config=config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn
         )
 
-        self.topk = TopK(
-            top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
-            renormalize=config.norm_topk_prob,
-            use_grouped_topk=True,
-            num_expert_group=config.n_group,
-            num_fused_shared_experts=self.num_fused_shared_experts,
-            topk_group=config.topk_group,
-            correction_bias=self.gate.e_score_correction_bias,
-            routed_scaling_factor=self.routed_scaling_factor,
-        )
-
         self.experts = get_moe_impl_class()(
             num_experts=config.n_routed_experts
             + self.num_fused_shared_experts
@@ -344,6 +333,18 @@ class DeepseekV2MoE(nn.Module):
             prefix=add_prefix("experts", prefix),
         )
 
+        self.topk = TopK(
+            top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
+            renormalize=config.norm_topk_prob,
+            use_grouped_topk=True,
+            num_expert_group=config.n_group,
+            num_fused_shared_experts=self.num_fused_shared_experts,
+            topk_group=config.topk_group,
+            correction_bias=self.gate.e_score_correction_bias,
+            routed_scaling_factor=self.routed_scaling_factor,
+            apply_routed_scaling_factor_on_output=self.experts.should_fuse_routed_scaling_factor_in_topk(),
+        )
+
         self.shared_experts_is_int8 = False
         self.shared_experts_is_fp8 = False
         self.shared_experts_weight_block_size = None
...
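
The TopK construction is moved below self.experts because the new keyword queries the freshly built experts layer; keeping the old order would read an attribute that does not exist yet. A minimal ordering sketch (names are illustrative, not the DeepseekV2MoE code):

class MoEBlockSketch:
    def __init__(self, make_experts, make_topk):
        # Build the expert layer first so its capabilities can configure routing.
        self.experts = make_experts()
        self.topk = make_topk(
            apply_routed_scaling_factor_on_output=(
                self.experts.should_fuse_routed_scaling_factor_in_topk()
            ),
        )
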
@@ -19,7 +19,10 @@ from sglang.srt.layers.moe.topk import biased_grouped_topk
     ],
 )
 @pytest.mark.parametrize("num_fused_shared_experts", [0, 1, 2])
-def test_moe_fused_gate_combined(seq_length, params, num_fused_shared_experts):
+@pytest.mark.parametrize("apply_routed_scaling_factor_on_output", [False, True])
+def test_moe_fused_gate_combined(
+    seq_length, params, num_fused_shared_experts, apply_routed_scaling_factor_on_output
+):
     num_experts, num_expert_group, topk_group, topk = params
     dtype = torch.float32
@@ -37,6 +40,7 @@ def test_moe_fused_gate_combined(seq_length, params, num_fused_shared_experts):
         topk=topk,
         num_fused_shared_experts=num_fused_shared_experts,
         routed_scaling_factor=2.5,
+        apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
     )
     ref_output, ref_indices = biased_grouped_topk(
         scores,
@@ -48,6 +52,7 @@ def test_moe_fused_gate_combined(seq_length, params, num_fused_shared_experts):
         topk_group=topk_group,
         num_fused_shared_experts=num_fused_shared_experts,
         routed_scaling_factor=2.5,
+        apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
     )
     # When num_fused_shared_experts > 0, ignore the comparison of the last topk dimension
...
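
The test parametrizes the new flag, so every existing (seq_length, params, num_fused_shared_experts) case now runs with fusion off and on, and the fused gate and the biased_grouped_topk reference must agree on the pre-scaled weights. A self-contained illustration of how the extra boolean axis multiplies the case count (not the real test, which needs the CUDA kernels):

import pytest

@pytest.mark.parametrize("num_fused_shared_experts", [0, 1, 2])
@pytest.mark.parametrize("apply_routed_scaling_factor_on_output", [False, True])
def test_flag_matrix(num_fused_shared_experts, apply_routed_scaling_factor_on_output):
    # pytest generates 3 x 2 = 6 cases for these two axes alone.
    assert isinstance(apply_routed_scaling_factor_on_output, bool)
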