Unverified Commit 29589512 authored by Cheng Wan, committed by GitHub

[6/N] MoE Refactor: Cleanup MoE-related configs (#8849)

parent 584e1ab2
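
This commit collapses the per-call MoE keyword arguments (activation, apply_router_weight_on_input, inplace, no_combine, routed_scaling_factor, ...) into a single moe_runner_config parameter on every FusedMoEMethodBase.apply implementation. A minimal sketch of the config object this implies, assuming a plain dataclass whose field names are taken from the moe_runner_config attributes accessed in the diff below and whose defaults mirror the old keyword defaults:

```python
# Sketch only: the field names come from moe_runner_config.<attr> accesses in
# this diff; modeling it as a dataclass and these default values are assumptions.
from dataclasses import dataclass
from typing import Optional


@dataclass
class MoeRunnerConfig:
    activation: str = "silu"
    apply_router_weight_on_input: bool = False
    inplace: bool = True
    no_combine: bool = False
    routed_scaling_factor: Optional[float] = None
```
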
......@@ -9,6 +9,7 @@ import torch
from torch import nn
if TYPE_CHECKING:
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import TopKOutput
......@@ -100,12 +101,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
layer: torch.nn.Module,
x: torch.Tensor,
topk_output: TopKOutput,
*,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
moe_runner_config: MoeRunnerConfig,
) -> torch.Tensor:
raise NotImplementedError
......
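
The same signature change is applied to every quantization backend in the rest of the diff. Shown schematically below (SomeMoEMethod is a placeholder name, not verbatim file contents): the keyword-only knobs are removed and a single consolidated parameter takes their place.

```python
# Schematic before/after of the apply() signature; SomeMoEMethod is illustrative.
class SomeMoEMethod:
    # Before: per-call keyword-only knobs on every backend.
    def apply_before(self, layer, x, topk_output, *, activation="silu",
                     apply_router_weight_on_input=False, inplace=True,
                     no_combine=False, routed_scaling_factor=None):
        ...

    # After: the knobs travel together in one MoeRunnerConfig object.
    def apply_after(self, layer, x, topk_output,
                    moe_runner_config: "MoeRunnerConfig"):
        ...
```
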
......@@ -3,7 +3,7 @@
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Dict, List, Optional
import torch
from torch.nn import Module
......@@ -22,6 +22,7 @@ from sglang.srt.layers.quantization.utils import is_layer_skipped
from sglang.srt.utils import set_weight_attrs
if TYPE_CHECKING:
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import TopKOutput
ACTIVATION_SCHEMES = ["static", "dynamic"]
......@@ -348,12 +349,7 @@ class BlockInt8MoEMethod(FusedMoEMethodBase):
layer: torch.nn.Module,
x: torch.Tensor,
topk_output: TopKOutput,
*,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
moe_runner_config: MoeRunnerConfig,
) -> torch.Tensor:
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
......@@ -363,15 +359,11 @@ class BlockInt8MoEMethod(FusedMoEMethodBase):
layer.w13_weight,
layer.w2_weight,
topk_output=topk_output,
inplace=inplace,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
moe_runner_config=moe_runner_config,
use_int8_w8a8=True,
w1_scale=(layer.w13_weight_scale_inv),
w2_scale=(layer.w2_weight_scale_inv),
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
block_shape=self.quant_config.weight_block_size,
no_combine=no_combine,
routed_scaling_factor=routed_scaling_factor,
)
......@@ -23,6 +23,7 @@ from sglang.srt.utils import is_cpu, is_cuda, is_hip, is_npu, set_weight_attrs
if TYPE_CHECKING:
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import TopKOutput
from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
CompressedTensorsConfig,
......@@ -269,12 +270,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
layer: torch.nn.Module,
x: torch.Tensor,
topk_output: TopKOutput,
*,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
moe_runner_config: MoeRunnerConfig,
) -> torch.Tensor:
from sglang.srt.layers.moe.fused_moe_triton import fused_experts
......@@ -283,8 +279,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
layer.w13_weight,
layer.w2_weight,
topk_output=topk_output,
inplace=inplace,
activation=activation,
moe_runner_config=moe_runner_config,
use_fp8_w8a8=True,
per_channel_quant=self.weight_quant.strategy
== QuantizationStrategy.CHANNEL,
......@@ -292,8 +287,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
apply_router_weight_on_input=apply_router_weight_on_input,
routed_scaling_factor=routed_scaling_factor,
)
......@@ -601,12 +594,12 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
layer: torch.nn.Module,
x: torch.Tensor,
topk_output: TopKOutput,
*,
activation: str = "silu",
**kwargs,
moe_runner_config: MoeRunnerConfig,
) -> torch.Tensor:
assert activation == "silu", "Only SiLU activation is supported."
assert (
moe_runner_config.activation == "silu"
), "Only SiLU activation is supported."
topk_weights, topk_ids, router_logits = topk_output
......
......@@ -41,6 +41,7 @@ from sglang.srt.utils import (
)
if TYPE_CHECKING:
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import TopKOutput
logger = logging.getLogger(__name__)
......@@ -220,22 +221,10 @@ class MxFp4LinearMethod(LinearMethodBase):
return out
class MxFp4MoEMethod:
def __new__(cls, *args, **kwargs):
if not hasattr(cls, "_initialized"):
original_init = cls.__init__
new_cls = type(
cls.__name__,
(FusedMoEMethodBase,),
{
"__init__": original_init,
**{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
},
)
obj = super(new_cls, new_cls).__new__(new_cls)
obj.__init__(*args, **kwargs)
return obj
return super().__new__(cls)
class MxFp4MoEMethod(FusedMoEMethodBase):
def __init__(self, quant_config: Mxfp4Config):
self.quant_config = quant_config
@staticmethod
def get_moe_method(
......@@ -364,12 +353,7 @@ class W4A4MXFp4MoEDynamicMethod(MxFp4MoEMethod):
layer: torch.nn.Module,
x: torch.Tensor,
topk_output: TopKOutput,
*,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
moe_runner_config: MoeRunnerConfig,
) -> torch.Tensor:
topk_weights, topk_ids, _ = topk_output
......@@ -383,7 +367,9 @@ class W4A4MXFp4MoEDynamicMethod(MxFp4MoEMethod):
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
activation=(
ActivationType.Silu if activation == "silu" else ActivationType.Gelu
ActivationType.Silu
if moe_runner_config.activation == "silu"
else ActivationType.Gelu
),
doweight_stage1=False,
)
......@@ -497,12 +483,7 @@ class W4A4MXFp4MoEStaticMethod(MxFp4MoEMethod):
layer: torch.nn.Module,
x: torch.Tensor,
topk_output: TopKOutput,
*,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
moe_runner_config: MoeRunnerConfig,
) -> torch.Tensor:
topk_weights, topk_ids, _ = topk_output
......@@ -516,7 +497,9 @@ class W4A4MXFp4MoEStaticMethod(MxFp4MoEMethod):
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
activation=(
ActivationType.Silu if activation == "silu" else ActivationType.Gelu
ActivationType.Silu
if moe_runner_config.activation == "silu"
else ActivationType.Gelu
),
doweight_stage1=False,
)
......
......@@ -79,6 +79,7 @@ from sglang.srt.utils import (
)
if TYPE_CHECKING:
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import TopKOutput
from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config
......@@ -982,12 +983,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer: torch.nn.Module,
x: torch.Tensor,
topk_output: TopKOutput,
*,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
moe_runner_config: MoeRunnerConfig,
) -> torch.Tensor:
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
......@@ -996,7 +992,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
topk_weights, topk_ids, _ = topk_output
x, topk_weights = apply_topk_weights_cpu(
apply_router_weight_on_input, topk_weights, x
moe_runner_config.apply_router_weight_on_input, topk_weights, x
)
return torch.ops.sgl_kernel.fused_experts_cpu(
......@@ -1021,8 +1017,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer,
x,
topk_output,
activation,
no_combine,
moe_runner_config.activation,
moe_runner_config.no_combine,
)
if ret is not None:
return ret
......@@ -1060,8 +1056,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
use_fp8_blockscale=True,
)
# TODO: Fuse into select_experts
if routed_scaling_factor is not None:
output *= routed_scaling_factor
if moe_runner_config.routed_scaling_factor is not None:
output *= moe_runner_config.routed_scaling_factor
return output
# Expert fusion with FP8 quantization
return fused_experts(
......@@ -1069,9 +1065,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer.w13_weight,
layer.w2_weight,
topk_output=topk_output,
inplace=inplace and not no_combine,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
moe_runner_config=moe_runner_config,
use_fp8_w8a8=True,
w1_scale=(
layer.w13_weight_scale_inv
......@@ -1084,26 +1078,32 @@ class Fp8MoEMethod(FusedMoEMethodBase):
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
block_shape=self.quant_config.weight_block_size,
no_combine=no_combine,
routed_scaling_factor=routed_scaling_factor,
)
def apply_with_router_logits(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
*,
activation: str = "silu",
routed_scaling_factor: Optional[float] = None,
topk_output: TopKOutput,
moe_runner_config: MoeRunnerConfig,
) -> torch.Tensor:
activation = moe_runner_config.activation
routed_scaling_factor = moe_runner_config.routed_scaling_factor
from flashinfer.fused_moe import trtllm_fp8_block_scale_moe
from sglang.srt.layers.moe.topk import TopKOutputChecker
assert TopKOutputChecker.format_is_bypassed(topk_output)
router_logits = topk_output.router_logits
topk_config = topk_output.topk_config
assert (
activation == "silu"
), "Only silu is supported for flashinfer blockscale fp8 moe"
a_q, a_sf = per_token_group_quant_fp8(x, self.quant_config.weight_block_size[1])
# NOTE: scales of hidden states have to be transposed!
a_sf_t = a_sf.t().contiguous()
from flashinfer.fused_moe import trtllm_fp8_block_scale_moe
return trtllm_fp8_block_scale_moe(
routing_logits=router_logits.to(torch.float32),
......@@ -1115,9 +1115,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
gemm2_weights=layer.w2_weight,
gemm2_weights_scale=layer.w2_weight_scale_inv,
num_experts=layer.num_experts,
top_k=layer.top_k,
n_group=layer.num_expert_group,
topk_group=layer.topk_group,
top_k=topk_config.top_k,
n_group=topk_config.num_expert_group,
topk_group=topk_config.topk_group,
intermediate_size=layer.w2_weight.shape[2],
local_expert_offset=layer.moe_ep_rank * layer.num_local_experts,
local_num_experts=layer.num_local_experts,
......
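
In the Fp8MoEMethod hunk above, apply_with_router_logits now receives routing information through a bypassed TopKOutput instead of raw router_logits plus layer attributes. A rough sketch of the data it expects, where the class names BypassedTopKOutput and TopKConfig are illustrative assumptions and only the attribute names (router_logits, topk_config, top_k, num_expert_group, topk_group) come from the diff:

```python
# Illustrative only: class names are assumed; attribute names appear in the diff.
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class TopKConfig:
    top_k: int
    num_expert_group: Optional[int] = None
    topk_group: Optional[int] = None


@dataclass
class BypassedTopKOutput:
    router_logits: torch.Tensor  # (num_tokens, num_experts), consumed by the kernel
    topk_config: TopKConfig      # routing hyperparameters, no longer read off the layer
```
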
......@@ -113,6 +113,7 @@ def normalize_e4m3fn_to_e4m3fnuz(
return weight, weight_scale, input_scale
# TODO(ch-wan): define these backends in --moe-runner-backend
def cutlass_block_fp8_supported() -> bool:
if not get_bool_env_var("SGLANG_SUPPORT_CUTLASS_BLOCK_FP8"):
return False
......
......@@ -44,6 +44,7 @@ from sglang.srt.layers.quantization.utils import (
)
if TYPE_CHECKING:
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import TopKOutput
from sglang.srt.utils import is_cuda
......@@ -1056,13 +1057,13 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
layer: torch.nn.Module,
x: torch.Tensor,
topk_output: TopKOutput,
*,
activation: str = "silu",
**kwargs,
moe_runner_config: MoeRunnerConfig,
) -> torch.Tensor:
# Delay the import to avoid circular dependency
assert activation == "silu", "Only SiLU activation is supported."
assert (
moe_runner_config.activation == "silu"
), "Only SiLU activation is supported."
# The input must currently be float16
orig_dtype = x.dtype
......
......@@ -28,6 +28,7 @@ from sglang.srt.utils import get_device_capability, is_cuda
if TYPE_CHECKING:
from sglang.srt.layers.linear import LinearBase
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
try:
from vllm import _custom_ops as ops
......@@ -216,13 +217,13 @@ def check_marlin_supports_layer(layer: LinearBase, group_size: int) -> bool:
)[0]
def check_moe_marlin_supports_layer(layer: LinearBase, group_size: int) -> bool:
def check_moe_marlin_supports_layer(layer: FusedMoE, group_size: int) -> bool:
hidden_size = layer.hidden_size
intermediate_size_per_partition = layer.intermediate_size_per_partition
# apply_router_weight_on_input is not supported for moe marlin
supports_router_weight = not layer.apply_router_weight_on_input
supports_router_weight = not layer.moe_runner_config.apply_router_weight_on_input
# moe marlin requires the activation to be silu
supports_activation = layer.activation == "silu"
supports_activation = layer.moe_runner_config.activation == "silu"
# gate-up: (n, k) = (intermediate_size_per_partition * 2, hidden_size)
# down: (n, k) = (hidden_size, intermediate_size_per_partition)
......
......@@ -7,8 +7,8 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional
import torch
from torch.nn.parameter import Parameter
from sglang.srt.layers.moe import should_use_flashinfer_trtllm_moe
from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType
from sglang.srt.layers.moe.utils import should_use_flashinfer_trtllm_moe
from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
from sglang.srt.layers.quantization.base_config import (
FusedMoEMethodBase,
......@@ -30,10 +30,11 @@ from sglang.srt.layers.quantization.utils import (
requantize_with_max_scale,
)
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.utils import is_cuda, next_power_of_2
if TYPE_CHECKING:
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import TopKOutput
if is_cuda():
......@@ -422,12 +423,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
layer: torch.nn.Module,
x: torch.Tensor,
topk_output: TopKOutput,
*,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
moe_runner_config: MoeRunnerConfig,
) -> torch.Tensor:
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
......@@ -436,15 +432,13 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
layer.w13_weight,
layer.w2_weight,
topk_output=topk_output,
inplace=inplace,
activation=activation,
moe_runner_config=moe_runner_config,
use_fp8_w8a8=True,
per_channel_quant=False, # ModelOpt uses per-tensor quantization
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
no_combine=no_combine,
)
......@@ -741,8 +735,10 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
@property
def enable_flashinfer_cutlass_moe(self) -> bool:
from sglang.srt.layers.moe import get_moe_runner_backend
"""Access the global enable_flashinfer_cutlass_moe setting."""
return global_server_args_dict.get("enable_flashinfer_cutlass_moe", False)
return get_moe_runner_backend().is_flashinfer_cutlass()
def create_weights(
self,
......@@ -1160,21 +1156,14 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
def apply(
self,
layer: torch.nn.Module,
layer: FusedMoE,
x: torch.Tensor,
topk_output: TopKOutput,
*,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
ep_rank: Optional[int] = None,
ep_size: Optional[int] = None,
tp_rank: Optional[int] = None,
tp_size: Optional[int] = None,
moe_runner_config: MoeRunnerConfig,
) -> torch.Tensor:
assert activation == "silu", "Only SiLU activation is supported."
assert (
moe_runner_config.activation == "silu"
), "Only SiLU activation is supported."
# Check if this is a FlashInferFP4MoE layer that should handle its own forward
if hasattr(layer, "gemm1_weights_fp4_shuffled"):
......@@ -1183,7 +1172,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
if self.enable_flashinfer_cutlass_moe:
assert (
not apply_router_weight_on_input
not moe_runner_config.apply_router_weight_on_input
), "apply_router_weight_on_input is not supported for Flashinfer"
# TRTLLM Cutlass moe takes in activations in BF16/Half/nvfp4 precision
# and fp4 quantized weights loaded from the checkpoint
......@@ -1205,14 +1194,14 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
layer.w2_blockscale_swizzled.view(torch.int32),
layer.g2_alphas,
],
ep_size=ep_size,
ep_rank=ep_rank,
tp_size=tp_size,
tp_rank=tp_rank,
ep_size=layer.moe_ep_size,
ep_rank=layer.moe_ep_rank,
tp_size=layer.moe_tp_size,
tp_rank=layer.moe_tp_rank,
tune_max_num_tokens=next_power_of_2(x.shape[0]),
)[0]
if routed_scaling_factor is not None:
output *= routed_scaling_factor
if moe_runner_config.routed_scaling_factor is not None:
output *= moe_runner_config.routed_scaling_factor
return output
from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
......@@ -1231,8 +1220,8 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
topk_weights=topk_weights,
topk_ids=topk_ids,
params=layer.cutlass_moe_params,
apply_router_weight_on_input=apply_router_weight_on_input,
apply_router_weight_on_input=moe_runner_config.apply_router_weight_on_input,
).to(x.dtype)
if routed_scaling_factor is not None:
output *= routed_scaling_factor
if moe_runner_config.routed_scaling_factor is not None:
output *= moe_runner_config.routed_scaling_factor
return output
......@@ -22,6 +22,7 @@ from sglang.srt.utils import get_device_capability, set_weight_attrs
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import TopKOutput
......@@ -353,17 +354,14 @@ class MoeWNA16Method(FusedMoEMethodBase):
layer: torch.nn.Module,
x: torch.Tensor,
topk_output: TopKOutput,
*,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
moe_runner_config: MoeRunnerConfig,
) -> torch.Tensor:
# avoid circular import
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
assert activation == "silu", "Only SiLU activation is supported."
assert (
moe_runner_config.activation == "silu"
), "Only SiLU activation is supported."
weight_bits = self.quant_config.weight_bits
has_zp = self.quant_config.has_zp
......@@ -373,8 +371,7 @@ class MoeWNA16Method(FusedMoEMethodBase):
layer.w13_qweight,
layer.w2_qweight,
topk_output=topk_output,
inplace=inplace,
apply_router_weight_on_input=apply_router_weight_on_input,
moe_runner_config=moe_runner_config,
use_int4_w4a16=weight_bits == 4,
use_int8_w8a16=weight_bits == 8,
w1_scale=layer.w13_scales,
......@@ -382,8 +379,6 @@ class MoeWNA16Method(FusedMoEMethodBase):
w1_zp=layer.w13_qzeros if has_zp else None,
w2_zp=layer.w2_qzeros if has_zp else None,
block_shape=[0, layer.group_size],
no_combine=no_combine,
routed_scaling_factor=routed_scaling_factor,
)
@staticmethod
......@@ -486,16 +481,16 @@ class MoeWNA16Method(FusedMoEMethodBase):
)
if "w13_qzeros" in weight_name:
tensor = loaded_weight.view(layer.tp_size, -1, loaded_weight.size(1))[
tp_rank
]
tensor = loaded_weight.view(
layer.moe_tp_size, -1, loaded_weight.size(1)
)[tp_rank]
if shard_id == "w1":
param.data[expert_id, : shard_size // 2] = tensor
else:
param.data[expert_id, shard_size // 2 :] = tensor
elif "w2_qzeros" in weight_name:
param.data[expert_id] = loaded_weight.view(
loaded_weight.size(0), layer.tp_size, -1
loaded_weight.size(0), layer.moe_tp_size, -1
)[:, tp_rank]
else:
weight_loader(param, loaded_weight, weight_name, shard_id, expert_id)
......
......@@ -16,14 +16,13 @@
from __future__ import annotations
import importlib.util
import logging
from typing import TYPE_CHECKING, List, Optional
import torch
import triton.language as tl
from torch.nn.parameter import Parameter
from sglang.srt.layers.moe.utils import get_moe_runner_backend
from sglang.srt.layers.quantization.base_config import (
FusedMoEMethodBase,
QuantizationConfig,
......@@ -31,7 +30,6 @@ from sglang.srt.layers.quantization.base_config import (
)
from sglang.srt.layers.quantization.utils import is_layer_skipped
from sglang.srt.layers.utils import is_sm100_supported
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.utils import (
direct_register_custom_op,
get_bool_env_var,
......@@ -60,6 +58,7 @@ if is_flashinfer_available():
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import TopKOutput
OCP_MX_BLOCK_SIZE = 32
......@@ -218,15 +217,13 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
self,
prefix: str,
):
from sglang.srt.managers.schedule_batch import global_server_args_dict
super().__init__()
self.prefix = prefix
self.topk_indices_dtype = None
self.use_triton_kernels = global_server_args_dict["enable_triton_kernel_moe"]
self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()
self.with_bias = False
self.use_flashinfer = global_server_args_dict["enable_flashinfer_mxfp4_moe"]
self.use_flashinfer = get_moe_runner_backend().is_flashinfer_mxfp4()
self.triton_kernel_moe_forward = None
self.triton_kernel_moe_with_bias_forward = None
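
Backend selection now goes through the moe module's runner-backend accessor instead of global_server_args_dict flags. A hedged sketch of the accessor pattern implied by these calls; the enum name and member values are assumptions, while get_moe_runner_backend and the three predicates (is_triton_kernel, is_flashinfer_cutlass, is_flashinfer_mxfp4) are the ones used in this diff:

```python
# Sketch of the accessor pattern; enum name and members are assumptions.
import enum


class MoeRunnerBackend(enum.Enum):
    AUTO = "auto"
    TRITON_KERNEL = "triton_kernel"
    FLASHINFER_CUTLASS = "flashinfer_cutlass"
    FLASHINFER_MXFP4 = "flashinfer_mxfp4"

    def is_triton_kernel(self) -> bool:
        return self is MoeRunnerBackend.TRITON_KERNEL

    def is_flashinfer_cutlass(self) -> bool:
        return self is MoeRunnerBackend.FLASHINFER_CUTLASS

    def is_flashinfer_mxfp4(self) -> bool:
        return self is MoeRunnerBackend.FLASHINFER_MXFP4


_MOE_RUNNER_BACKEND = MoeRunnerBackend.AUTO  # set once from server args


def get_moe_runner_backend() -> MoeRunnerBackend:
    return _MOE_RUNNER_BACKEND
```
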
......@@ -348,6 +345,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
logger,
f"Shuffling MoE weights for FlashInfer MXFP4 moe kernel (layer: {self.prefix}), it might take a while...",
)
# TODO: these values are hardcoded for now, we need to get them from the model
layer.gemm1_alpha = Parameter(
torch.tensor([1.702] * self.num_experts, dtype=torch.float32).cuda(),
requires_grad=False,
......@@ -573,14 +571,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
layer: torch.nn.Module,
x: torch.Tensor,
topk_output: TopKOutput,
*,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
activation_alpha: Optional[float] = None,
swiglu_limit: Optional[float] = None,
moe_runner_config: MoeRunnerConfig,
) -> torch.Tensor:
if self.use_flashinfer:
# Based on profiling results, we need to quantize x to mxfp8 here to achieve better performance
......@@ -637,9 +628,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
b1=layer.w13_weight_bias,
b2=layer.w2_weight_bias,
topk_output=topk_output,
activation=activation,
activation_alpha=activation_alpha,
swiglu_limit=swiglu_limit,
moe_runner_config=moe_runner_config,
)
else:
return self.triton_kernel_moe_forward(
......@@ -647,6 +636,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_output=topk_output,
moe_runner_config=moe_runner_config,
)
else:
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
......@@ -656,13 +646,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_output=topk_output,
moe_runner_config=moe_runner_config,
b1=layer.w13_weight_bias,
b2=layer.w2_weight_bias,
inplace=inplace,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
no_combine=no_combine,
routed_scaling_factor=routed_scaling_factor,
activation_alpha=activation_alpha,
swiglu_limit=swiglu_limit,
)
from __future__ import annotations
import importlib
from typing import TYPE_CHECKING, Callable, List, Optional
import importlib.util
from typing import TYPE_CHECKING, List, Optional
import torch
import torch.nn.functional as F
......@@ -24,7 +24,7 @@ from sglang.srt.utils import (
)
if TYPE_CHECKING:
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import TopKOutput
has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None
......@@ -221,31 +221,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
layer: torch.nn.Module,
x: torch.Tensor,
topk_output: TopKOutput,
*,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
activation_alpha: Optional[float] = None,
swiglu_limit: Optional[float] = None,
moe_runner_config: MoeRunnerConfig,
) -> torch.Tensor:
kwargs = {}
if activation_alpha is not None:
kwargs["activation_alpha"] = activation_alpha
if swiglu_limit is not None:
kwargs["swiglu_limit"] = swiglu_limit
return self.forward(
x=x,
layer=layer,
topk_output=topk_output,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
inplace=inplace,
no_combine=no_combine,
routed_scaling_factor=routed_scaling_factor,
**kwargs,
moe_runner_config=moe_runner_config,
)
def forward_cuda(
......@@ -253,18 +236,12 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
layer: torch.nn.Module,
x: torch.Tensor,
topk_output: TopKOutput,
*,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
activation_alpha: Optional[float] = None,
swiglu_limit: Optional[float] = None,
moe_runner_config: MoeRunnerConfig,
) -> torch.Tensor:
if self.use_triton_kernels:
if self.with_bias:
assert self.triton_kernel_moe_with_bias_forward is not None
return self.triton_kernel_moe_with_bias_forward(
hidden_states=x,
w1=layer.w13_weight,
......@@ -272,24 +249,24 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
b1=layer.w13_weight_bias,
b2=layer.w2_weight_bias,
topk_output=topk_output,
activation=activation,
activation_alpha=activation_alpha,
swiglu_limit=swiglu_limit,
moe_runner_config=moe_runner_config,
w1_pcg=None,
w2_pcg=None,
)
else:
assert self.triton_kernel_moe_forward is not None
return self.triton_kernel_moe_forward(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_output=topk_output,
moe_runner_config=moe_runner_config,
)
else:
if _use_aiter:
assert not no_combine, "unsupported"
assert not moe_runner_config.no_combine, "unsupported"
topk_weights, topk_ids, _ = topk_output
if apply_router_weight_on_input:
if moe_runner_config.apply_router_weight_on_input:
assert (
topk_weights.dim() == 2
), "`topk_weights` should be in shape (num_tokens, topk)"
......@@ -309,7 +286,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
topk_ids,
activation=(
ActivationType.Silu
if activation == "silu"
if moe_runner_config.activation == "silu"
else ActivationType.Gelu
),
)
......@@ -325,13 +302,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
b1=getattr(layer, "w13_weight_bias", None),
b2=getattr(layer, "w2_weight_bias", None),
topk_output=topk_output,
inplace=inplace and not no_combine,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
no_combine=no_combine,
routed_scaling_factor=routed_scaling_factor,
activation_alpha=activation_alpha,
swiglu_limit=swiglu_limit,
moe_runner_config=moe_runner_config,
)
def forward_cpu(
......@@ -339,21 +310,21 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
layer: torch.nn.Module,
x: torch.Tensor,
topk_output: TopKOutput,
*,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
moe_runner_config: MoeRunnerConfig,
) -> torch.Tensor:
assert activation == "silu", f"activation = {activation} is not supported."
if use_intel_amx_backend(layer) and not apply_router_weight_on_input:
assert (
moe_runner_config.activation == "silu"
), f"activation = {moe_runner_config.activation} is not supported."
if (
use_intel_amx_backend(layer)
and not moe_runner_config.apply_router_weight_on_input
):
from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
topk_weights, topk_ids, _ = topk_output
x, topk_weights = apply_topk_weights_cpu(
apply_router_weight_on_input, topk_weights, x
moe_runner_config.apply_router_weight_on_input, topk_weights, x
)
return torch.ops.sgl_kernel.fused_experts_cpu(
x,
......@@ -378,11 +349,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
layer,
x,
topk_output,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
inplace=inplace,
no_combine=no_combine,
routed_scaling_factor=routed_scaling_factor,
moe_runner_config,
)
def forward_npu(
......@@ -390,12 +357,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
layer: torch.nn.Module,
x: torch.Tensor,
topk_output: TopKOutput,
*,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
moe_runner_config: MoeRunnerConfig,
) -> torch.Tensor:
from sglang.srt.layers.moe.fused_moe_native import moe_forward_native
......@@ -403,11 +365,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
layer,
x,
topk_output,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
inplace=inplace,
no_combine=no_combine,
routed_scaling_factor=routed_scaling_factor,
moe_runner_config,
)
def forward_tpu(self, *args, **kwargs) -> torch.Tensor:
......
......@@ -18,7 +18,9 @@ from sglang.srt.layers.quantization.utils import is_layer_skipped
from sglang.srt.utils import set_weight_attrs
if TYPE_CHECKING:
from sglang.srt.layers.moe.ep_moe.layer import EPMoE, TopKOutput
from sglang.srt.layers.moe import MoeRunnerConfig
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
from sglang.srt.layers.moe.topk import StandardTopKOutput
ACTIVATION_SCHEMES = ["static", "dynamic"]
......@@ -280,11 +282,8 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
self,
layer: EPMoE,
x: torch.Tensor,
topk_output: TopKOutput,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
routed_scaling_factor: Optional[float] = None,
**kwargs,
topk_output: StandardTopKOutput,
moe_runner_config: MoeRunnerConfig,
) -> torch.Tensor:
# TODO(ch-wan): move it out of this class
......@@ -324,6 +323,6 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
layer.w13_input_scale,
layer.w2_input_scale,
)
if routed_scaling_factor is not None:
output *= routed_scaling_factor
if moe_runner_config.routed_scaling_factor is not None:
output *= moe_runner_config.routed_scaling_factor
return output
......@@ -26,7 +26,8 @@ from sglang.srt.layers.quantization.fp8_utils import (
from sglang.srt.utils import set_weight_attrs
if TYPE_CHECKING:
from sglang.srt.layers.moe.topk import TopKOutput
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import StandardTopKOutput
_is_fp8_fnuz = is_fp8_fnuz()
......@@ -269,13 +270,8 @@ class W8A8FP8MoEMethod(FusedMoEMethodBase):
self,
layer: torch.nn.Module,
x: torch.Tensor,
topk_output: TopKOutput,
*,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
topk_output: StandardTopKOutput,
moe_runner_config: MoeRunnerConfig,
) -> torch.Tensor:
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
......@@ -284,15 +280,11 @@ class W8A8FP8MoEMethod(FusedMoEMethodBase):
layer.w13_weight,
layer.w2_weight,
topk_output=topk_output,
inplace=inplace,
apply_router_weight_on_input=apply_router_weight_on_input,
activation=activation,
moe_runner_config=moe_runner_config,
use_fp8_w8a8=True,
per_channel_quant=True,
w1_scale=(layer.w13_weight_scale),
w2_scale=(layer.w2_weight_scale),
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
no_combine=no_combine,
routed_scaling_factor=routed_scaling_factor,
)
......@@ -49,6 +49,7 @@ from sglang.srt.utils import (
)
if TYPE_CHECKING:
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import TopKOutput
_is_cuda = is_cuda()
......@@ -487,12 +488,7 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
layer: torch.nn.Module,
x: torch.Tensor,
topk_output: TopKOutput,
*,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
moe_runner_config: MoeRunnerConfig,
) -> torch.Tensor:
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
......@@ -501,7 +497,7 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
topk_weights, topk_ids, _ = topk_output
x, topk_weights = apply_topk_weights_cpu(
apply_router_weight_on_input, topk_weights, x
moe_runner_config.apply_router_weight_on_input, topk_weights, x
)
return torch.ops.sgl_kernel.fused_experts_cpu(
x,
......@@ -525,17 +521,13 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
layer.w13_weight,
layer.w2_weight,
topk_output=topk_output,
inplace=inplace,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
moe_runner_config=moe_runner_config,
use_int8_w8a8=True,
per_channel_quant=True,
w1_scale=(layer.w13_weight_scale),
w2_scale=(layer.w2_weight_scale),
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
no_combine=no_combine,
routed_scaling_factor=routed_scaling_factor,
)
......@@ -982,7 +974,7 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase):
layer,
x,
topk_output: TopKOutput,
**kwargs,
moe_runner_config: MoeRunnerConfig,
) -> torch.Tensor:
topk_weights, topk_ids, _ = topk_output
......
......@@ -52,6 +52,7 @@ from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
ScheduleBatchDisaggregationDecodeMixin,
)
from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
from sglang.srt.layers.moe import is_tbo_enabled
from sglang.srt.mem_cache.allocator import (
BaseTokenToKVPoolAllocator,
SWATokenToKVPoolAllocator,
......@@ -84,17 +85,10 @@ GLOBAL_SERVER_ARGS_KEYS = [
"device",
"disable_chunked_prefix_cache",
"disable_radix_cache",
"enable_two_batch_overlap",
"tbo_token_distribution_threshold",
"enable_dp_lm_head",
"moe_a2a_backend",
"deepep_mode",
"enable_flashinfer_cutlass_moe",
"enable_flashinfer_trtllm_moe",
"enable_flashinfer_allreduce_fusion",
"moe_dense_tp_size",
"ep_dispatch_algorithm",
"deepep_config",
"ep_num_redundant_experts",
"enable_nan_detection",
"flashinfer_mla_disable_ragged",
......@@ -107,8 +101,6 @@ GLOBAL_SERVER_ARGS_KEYS = [
"triton_attention_reduce_in_fp32",
"num_reserved_decode_tokens",
"weight_loader_disable_mmap",
"enable_triton_kernel_moe",
"enable_flashinfer_mxfp4_moe",
"enable_multimodal",
"enable_symm_mem",
"quantization",
......
......@@ -64,7 +64,7 @@ from sglang.srt.hf_transformers_utils import (
)
from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.moe.utils import DeepEPMode, MoeA2ABackend
from sglang.srt.layers.moe import initialize_moe_config
from sglang.srt.managers.io_struct import (
AbortReq,
CloseSessionReqInput,
......@@ -245,6 +245,9 @@ class Scheduler(
)
)
# Init model config
self.model_config = ModelConfig.from_server_args(server_args)
# Init inter-process communication
context = zmq.Context(2)
self.idle_sleeper = None
......@@ -292,6 +295,9 @@ class Scheduler(
# Init tokenizer
self.init_tokenizer()
# Init moe config
self.init_moe_config()
# Set reasoning_parser and think_end_id if --reasoning_parser is enabled
if self.server_args.reasoning_parser and self.tokenizer:
reasoning_parser = ReasoningParser(
......@@ -538,8 +544,6 @@ class Scheduler(
def init_tokenizer(self):
server_args = self.server_args
self.model_config = ModelConfig.from_server_args(server_args)
self.is_generation = self.model_config.is_generation
if server_args.skip_tokenizer_init:
......@@ -761,6 +765,10 @@ class Scheduler(
# The prefill requests that are in the middle of kv sending
self.disagg_prefill_inflight_queue: List[Req] = []
def init_moe_config(self):
if hasattr(self.model_config.hf_config, "num_experts_per_tok"):
initialize_moe_config(self.server_args)
@DynamicGradMode()
def event_loop_normal(self):
"""A normal scheduler loop."""
......@@ -1823,11 +1831,6 @@ class Scheduler(
disable_cuda_graph=self.server_args.disable_cuda_graph,
spec_algorithm=self.spec_algorithm,
speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens,
enable_two_batch_overlap=self.server_args.enable_two_batch_overlap,
enable_deepep_moe=MoeA2ABackend(
self.server_args.moe_a2a_backend
).is_deepep(),
deepep_mode=DeepEPMode(self.server_args.deepep_mode),
require_mlp_tp_gather=require_mlp_tp_gather(self.server_args),
disable_overlap_schedule=self.server_args.disable_overlap_schedule,
)
......@@ -1922,9 +1925,6 @@ class Scheduler(
disable_cuda_graph: bool,
spec_algorithm,
speculative_num_draft_tokens,
enable_two_batch_overlap: bool,
enable_deepep_moe: bool,
deepep_mode: DeepEPMode,
require_mlp_tp_gather: bool,
disable_overlap_schedule: bool,
):
......@@ -1972,9 +1972,6 @@ class Scheduler(
is_extend_in_batch,
*tbo_preparer.prepare_all_gather(
local_batch,
deepep_mode,
enable_deepep_moe,
enable_two_batch_overlap,
),
],
dtype=torch.int64,
......
......@@ -60,7 +60,6 @@ from sglang.srt.layers.dp_attention import (
initialize_dp_attention,
)
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.moe.utils import DeepEPMode, MoeA2ABackend
from sglang.srt.layers.quantization import (
deep_gemm_wrapper,
monkey_patch_isinstance_for_vllm_base_layer,
......@@ -219,8 +218,6 @@ class ModelRunner:
# TODO it is indeed not a "server args"
"use_mla_backend": self.use_mla_backend,
"speculative_algorithm": self.spec_algorithm,
"moe_a2a_backend": MoeA2ABackend(server_args.moe_a2a_backend),
"deepep_mode": DeepEPMode(server_args.deepep_mode),
}
)
......
......@@ -32,7 +32,9 @@ from sglang.srt.layers.linear import (
RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.fused_moe_triton import fused_moe
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
......@@ -104,6 +106,11 @@ class DbrxExperts(nn.Module):
self.params_dtype = params_dtype
self.router = DbrxRouter(config, self.params_dtype)
self.topk = TopK(
self.top_k,
renormalize=True,
)
self.moe_runner_config = MoeRunnerConfig(inplace=True)
self.ws = nn.Parameter(
torch.empty(
self.num_total_experts,
......@@ -169,14 +176,13 @@ class DbrxExperts(nn.Module):
hidden_states = hidden_states.view(-1, self.d_model)
# router_logits: (num_tokens, n_experts)
router_logits = self.router(hidden_states)
topk_output = self.topk(hidden_states, router_logits)
final_hidden_states = fused_moe(
hidden_states,
self.ws,
self.w2s,
router_logits,
self.top_k,
renormalize=True,
inplace=True,
topk_output,
self.moe_runner_config,
)
if self.tp_size > 1:
......@@ -293,7 +299,7 @@ class DbrxFusedNormAttention(nn.Module):
position_ids: torch.Tensor,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
) -> torch.Tensor:
) -> Tuple[torch.Tensor, torch.Tensor]:
residual = hidden_states
hidden_states = self.norm_1(hidden_states)
x = self.attn(
......
......@@ -37,6 +37,7 @@ from sglang.srt.layers.linear import (
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.fused_moe_triton import fused_moe
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
......@@ -180,7 +181,7 @@ class DeepseekMoE(nn.Module):
w1=self.w1,
w2=self.w2,
topk_output=topk_output,
inplace=True,
moe_runner_config=MoeRunnerConfig(inplace=True),
)
if self.config.n_shared_experts is not None:
......