Unverified Commit 29589512 authored by Cheng Wan, committed by GitHub

[6/N] MoE Refactor: Cleanup MoE-related configs (#8849)

parent 584e1ab2
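This commit collapses the long tail of per-call keyword arguments on `FusedMoEMethodBase.apply` (`activation`, `apply_router_weight_on_input`, `inplace`, `no_combine`, `routed_scaling_factor`, and backend-specific extras) into a single `moe_runner_config: MoeRunnerConfig` parameter. The real class lives in `sglang.srt.layers.moe.moe_runner`; the sketch below is only a hedged reconstruction that mirrors the fields the hunks in this diff actually read, and its defaults are assumptions:

```python
# Hedged sketch of MoeRunnerConfig, inferred from the fields accessed in this diff.
# The actual definition in sglang/srt/layers/moe/moe_runner may carry more fields.
from dataclasses import dataclass
from typing import Optional


@dataclass
class MoeRunnerConfig:
    activation: str = "silu"
    apply_router_weight_on_input: bool = False
    inplace: bool = True
    no_combine: bool = False
    routed_scaling_factor: Optional[float] = None
```

Quantization methods then read `moe_runner_config.activation` and friends instead of loose keywords, and callers construct the config once (for example `MoeRunnerConfig(inplace=True)` in the dbrx and deepseek hunks below).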
@@ -9,6 +9,7 @@ import torch
 from torch import nn

 if TYPE_CHECKING:
+    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
     from sglang.srt.layers.moe.topk import TopKOutput
@@ -100,12 +101,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
         layer: torch.nn.Module,
         x: torch.Tensor,
         topk_output: TopKOutput,
-        *,
-        activation: str = "silu",
-        apply_router_weight_on_input: bool = False,
-        inplace: bool = True,
-        no_combine: bool = False,
-        routed_scaling_factor: Optional[float] = None,
+        moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
         raise NotImplementedError
...
@@ -3,7 +3,7 @@
 from __future__ import annotations

 import logging
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional

 import torch
 from torch.nn import Module
@@ -22,6 +22,7 @@ from sglang.srt.layers.quantization.utils import is_layer_skipped
 from sglang.srt.utils import set_weight_attrs

 if TYPE_CHECKING:
+    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
     from sglang.srt.layers.moe.topk import TopKOutput

 ACTIVATION_SCHEMES = ["static", "dynamic"]
@@ -348,12 +349,7 @@ class BlockInt8MoEMethod(FusedMoEMethodBase):
         layer: torch.nn.Module,
         x: torch.Tensor,
         topk_output: TopKOutput,
-        *,
-        activation: str = "silu",
-        apply_router_weight_on_input: bool = False,
-        inplace: bool = True,
-        no_combine: bool = False,
-        routed_scaling_factor: Optional[float] = None,
+        moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
         from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
@@ -363,15 +359,11 @@ class BlockInt8MoEMethod(FusedMoEMethodBase):
             layer.w13_weight,
             layer.w2_weight,
             topk_output=topk_output,
-            inplace=inplace,
-            activation=activation,
-            apply_router_weight_on_input=apply_router_weight_on_input,
+            moe_runner_config=moe_runner_config,
             use_int8_w8a8=True,
             w1_scale=(layer.w13_weight_scale_inv),
             w2_scale=(layer.w2_weight_scale_inv),
             a1_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
             block_shape=self.quant_config.weight_block_size,
-            no_combine=no_combine,
-            routed_scaling_factor=routed_scaling_factor,
         )
@@ -23,6 +23,7 @@ from sglang.srt.utils import is_cpu, is_cuda, is_hip, is_npu, set_weight_attrs
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
     from sglang.srt.layers.moe.topk import TopKOutput
     from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
         CompressedTensorsConfig,
@@ -269,12 +270,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
         layer: torch.nn.Module,
         x: torch.Tensor,
         topk_output: TopKOutput,
-        *,
-        activation: str = "silu",
-        apply_router_weight_on_input: bool = False,
-        inplace: bool = True,
-        no_combine: bool = False,
-        routed_scaling_factor: Optional[float] = None,
+        moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
         from sglang.srt.layers.moe.fused_moe_triton import fused_experts
@@ -283,8 +279,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
             layer.w13_weight,
             layer.w2_weight,
             topk_output=topk_output,
-            inplace=inplace,
-            activation=activation,
+            moe_runner_config=moe_runner_config,
             use_fp8_w8a8=True,
             per_channel_quant=self.weight_quant.strategy
             == QuantizationStrategy.CHANNEL,
@@ -292,8 +287,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
             w2_scale=layer.w2_weight_scale,
             a1_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
-            apply_router_weight_on_input=apply_router_weight_on_input,
-            routed_scaling_factor=routed_scaling_factor,
         )
@@ -601,12 +594,12 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
         layer: torch.nn.Module,
         x: torch.Tensor,
         topk_output: TopKOutput,
-        *,
-        activation: str = "silu",
-        **kwargs,
+        moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
-        assert activation == "silu", "Only SiLU activation is supported."
+        assert (
+            moe_runner_config.activation == "silu"
+        ), "Only SiLU activation is supported."
         topk_weights, topk_ids, router_logits = topk_output
...
@@ -41,6 +41,7 @@ from sglang.srt.utils import (
 )

 if TYPE_CHECKING:
+    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
     from sglang.srt.layers.moe.topk import TopKOutput

 logger = logging.getLogger(__name__)
@@ -220,22 +221,10 @@ class MxFp4LinearMethod(LinearMethodBase):
         return out

-class MxFp4MoEMethod:
-    def __new__(cls, *args, **kwargs):
-        if not hasattr(cls, "_initialized"):
-            original_init = cls.__init__
-            new_cls = type(
-                cls.__name__,
-                (FusedMoEMethodBase,),
-                {
-                    "__init__": original_init,
-                    **{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
-                },
-            )
-            obj = super(new_cls, new_cls).__new__(new_cls)
-            obj.__init__(*args, **kwargs)
-            return obj
-        return super().__new__(cls)
+class MxFp4MoEMethod(FusedMoEMethodBase):
+    def __init__(self, quant_config: Mxfp4Config):
+        self.quant_config = quant_config

     @staticmethod
     def get_moe_method(
@@ -364,12 +353,7 @@ class W4A4MXFp4MoEDynamicMethod(MxFp4MoEMethod):
         layer: torch.nn.Module,
         x: torch.Tensor,
         topk_output: TopKOutput,
-        *,
-        activation: str = "silu",
-        apply_router_weight_on_input: bool = False,
-        inplace: bool = True,
-        no_combine: bool = False,
-        routed_scaling_factor: Optional[float] = None,
+        moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
         topk_weights, topk_ids, _ = topk_output
@@ -383,7 +367,9 @@ class W4A4MXFp4MoEDynamicMethod(MxFp4MoEMethod):
             w1_scale=layer.w13_weight_scale,
             w2_scale=layer.w2_weight_scale,
             activation=(
-                ActivationType.Silu if activation == "silu" else ActivationType.Gelu
+                ActivationType.Silu
+                if moe_runner_config.activation == "silu"
+                else ActivationType.Gelu
             ),
             doweight_stage1=False,
         )
@@ -497,12 +483,7 @@ class W4A4MXFp4MoEStaticMethod(MxFp4MoEMethod):
         layer: torch.nn.Module,
         x: torch.Tensor,
         topk_output: TopKOutput,
-        *,
-        activation: str = "silu",
-        apply_router_weight_on_input: bool = False,
-        inplace: bool = True,
-        no_combine: bool = False,
-        routed_scaling_factor: Optional[float] = None,
+        moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
         topk_weights, topk_ids, _ = topk_output
@@ -516,7 +497,9 @@ class W4A4MXFp4MoEStaticMethod(MxFp4MoEMethod):
             w1_scale=layer.w13_weight_scale,
             w2_scale=layer.w2_weight_scale,
             activation=(
-                ActivationType.Silu if activation == "silu" else ActivationType.Gelu
+                ActivationType.Silu
+                if moe_runner_config.activation == "silu"
+                else ActivationType.Gelu
             ),
             doweight_stage1=False,
         )
...
@@ -79,6 +79,7 @@ from sglang.srt.utils import (
 )

 if TYPE_CHECKING:
+    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
     from sglang.srt.layers.moe.topk import TopKOutput
     from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config
@@ -982,12 +983,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         layer: torch.nn.Module,
         x: torch.Tensor,
         topk_output: TopKOutput,
-        *,
-        activation: str = "silu",
-        apply_router_weight_on_input: bool = False,
-        inplace: bool = True,
-        no_combine: bool = False,
-        routed_scaling_factor: Optional[float] = None,
+        moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
         from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
@@ -996,7 +992,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             topk_weights, topk_ids, _ = topk_output
             x, topk_weights = apply_topk_weights_cpu(
-                apply_router_weight_on_input, topk_weights, x
+                moe_runner_config.apply_router_weight_on_input, topk_weights, x
             )
             return torch.ops.sgl_kernel.fused_experts_cpu(
@@ -1021,8 +1017,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 layer,
                 x,
                 topk_output,
-                activation,
-                no_combine,
+                moe_runner_config.activation,
+                moe_runner_config.no_combine,
             )
             if ret is not None:
                 return ret
@@ -1060,8 +1056,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 use_fp8_blockscale=True,
             )
             # TODO: Fuse into select_experts
-            if routed_scaling_factor is not None:
-                output *= routed_scaling_factor
+            if moe_runner_config.routed_scaling_factor is not None:
+                output *= moe_runner_config.routed_scaling_factor
             return output
         # Expert fusion with FP8 quantization
         return fused_experts(
@@ -1069,9 +1065,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             layer.w13_weight,
             layer.w2_weight,
             topk_output=topk_output,
-            inplace=inplace and not no_combine,
-            activation=activation,
-            apply_router_weight_on_input=apply_router_weight_on_input,
+            moe_runner_config=moe_runner_config,
             use_fp8_w8a8=True,
             w1_scale=(
                 layer.w13_weight_scale_inv
@@ -1084,26 +1078,32 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             a1_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
             block_shape=self.quant_config.weight_block_size,
-            no_combine=no_combine,
-            routed_scaling_factor=routed_scaling_factor,
         )

     def apply_with_router_logits(
         self,
         layer: torch.nn.Module,
         x: torch.Tensor,
-        router_logits: torch.Tensor,
-        *,
-        activation: str = "silu",
-        routed_scaling_factor: Optional[float] = None,
+        topk_output: TopKOutput,
+        moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
+        activation = moe_runner_config.activation
+        routed_scaling_factor = moe_runner_config.routed_scaling_factor
+        from flashinfer.fused_moe import trtllm_fp8_block_scale_moe
+        from sglang.srt.layers.moe.topk import TopKOutputChecker
+        assert TopKOutputChecker.format_is_bypassed(topk_output)
+        router_logits = topk_output.router_logits
+        topk_config = topk_output.topk_config
         assert (
             activation == "silu"
         ), "Only silu is supported for flashinfer blockscale fp8 moe"
         a_q, a_sf = per_token_group_quant_fp8(x, self.quant_config.weight_block_size[1])
         # NOTE: scales of hidden states have to be transposed!
         a_sf_t = a_sf.t().contiguous()
-        from flashinfer.fused_moe import trtllm_fp8_block_scale_moe
         return trtllm_fp8_block_scale_moe(
             routing_logits=router_logits.to(torch.float32),
@@ -1115,9 +1115,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             gemm2_weights=layer.w2_weight,
             gemm2_weights_scale=layer.w2_weight_scale_inv,
             num_experts=layer.num_experts,
-            top_k=layer.top_k,
-            n_group=layer.num_expert_group,
-            topk_group=layer.topk_group,
+            top_k=topk_config.top_k,
+            n_group=topk_config.num_expert_group,
+            topk_group=topk_config.topk_group,
             intermediate_size=layer.w2_weight.shape[2],
             local_expert_offset=layer.moe_ep_rank * layer.num_local_experts,
             local_num_experts=layer.num_local_experts,
...
@@ -113,6 +113,7 @@ def normalize_e4m3fn_to_e4m3fnuz(
     return weight, weight_scale, input_scale

+# TODO(ch-wan): define these backends in --moe-runner-backend
 def cutlass_block_fp8_supported() -> bool:
     if not get_bool_env_var("SGLANG_SUPPORT_CUTLASS_BLOCK_FP8"):
         return False
...
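The `TODO(ch-wan)` above flags that these capability checks should eventually hang off `--moe-runner-backend`. Elsewhere in this diff the selected backend is queried through `get_moe_runner_backend()` with predicates such as `is_triton_kernel()`, `is_flashinfer_cutlass()`, and `is_flashinfer_mxfp4()`. The following enum-style sketch is offered only as an assumption about that shape; the member names beyond those predicates are guesses, not the real definition:

```python
# Assumption-only sketch: the real backend type lives under sglang.srt.layers.moe
# and is selected by the --moe-runner-backend server argument.
from enum import Enum


class MoeRunnerBackend(Enum):
    AUTO = "auto"
    TRITON = "triton"
    TRITON_KERNEL = "triton_kernel"
    FLASHINFER_CUTLASS = "flashinfer_cutlass"
    FLASHINFER_MXFP4 = "flashinfer_mxfp4"

    def is_triton_kernel(self) -> bool:
        return self is MoeRunnerBackend.TRITON_KERNEL

    def is_flashinfer_cutlass(self) -> bool:
        return self is MoeRunnerBackend.FLASHINFER_CUTLASS

    def is_flashinfer_mxfp4(self) -> bool:
        return self is MoeRunnerBackend.FLASHINFER_MXFP4
```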
@@ -44,6 +44,7 @@ from sglang.srt.layers.quantization.utils import (
 )

 if TYPE_CHECKING:
+    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
     from sglang.srt.layers.moe.topk import TopKOutput

 from sglang.srt.utils import is_cuda
@@ -1056,13 +1057,13 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
         layer: torch.nn.Module,
         x: torch.Tensor,
         topk_output: TopKOutput,
-        *,
-        activation: str = "silu",
-        **kwargs,
+        moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
         # Delay the import to avoid circular dependency
-        assert activation == "silu", "Only SiLU activation is supported."
+        assert (
+            moe_runner_config.activation == "silu"
+        ), "Only SiLU activation is supported."
         # The input must currently be float16
         orig_dtype = x.dtype
...
@@ -28,6 +28,7 @@ from sglang.srt.utils import get_device_capability, is_cuda
 if TYPE_CHECKING:
     from sglang.srt.layers.linear import LinearBase
+    from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE

 try:
     from vllm import _custom_ops as ops
@@ -216,13 +217,13 @@ def check_marlin_supports_layer(layer: LinearBase, group_size: int) -> bool:
     )[0]

-def check_moe_marlin_supports_layer(layer: LinearBase, group_size: int) -> bool:
+def check_moe_marlin_supports_layer(layer: FusedMoE, group_size: int) -> bool:
     hidden_size = layer.hidden_size
     intermediate_size_per_partition = layer.intermediate_size_per_partition
     # apply_router_weight_on_input is not supported for moe marlin
-    supports_router_weight = not layer.apply_router_weight_on_input
+    supports_router_weight = not layer.moe_runner_config.apply_router_weight_on_input
     # moe marlin requires the activation to be silu
-    supports_activation = layer.activation == "silu"
+    supports_activation = layer.moe_runner_config.activation == "silu"
     # gate-up: (n, k) = (intermediate_size_per_partition * 2, hidden_size)
     # down: (n, k) = (hidden_size, intermediate_size_per_partition)
...
@@ -7,8 +7,8 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional
 import torch
 from torch.nn.parameter import Parameter

+from sglang.srt.layers.moe import should_use_flashinfer_trtllm_moe
 from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType
-from sglang.srt.layers.moe.utils import should_use_flashinfer_trtllm_moe
 from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
 from sglang.srt.layers.quantization.base_config import (
     FusedMoEMethodBase,
@@ -30,10 +30,11 @@ from sglang.srt.layers.quantization.utils import (
     requantize_with_max_scale,
 )
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.utils import is_cuda, next_power_of_2

 if TYPE_CHECKING:
+    from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
+    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
     from sglang.srt.layers.moe.topk import TopKOutput

 if is_cuda():
@@ -422,12 +423,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
         layer: torch.nn.Module,
         x: torch.Tensor,
         topk_output: TopKOutput,
-        *,
-        activation: str = "silu",
-        apply_router_weight_on_input: bool = False,
-        inplace: bool = True,
-        no_combine: bool = False,
-        routed_scaling_factor: Optional[float] = None,
+        moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
         from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
@@ -436,15 +432,13 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
             layer.w13_weight,
             layer.w2_weight,
             topk_output=topk_output,
-            inplace=inplace,
-            activation=activation,
+            moe_runner_config=moe_runner_config,
             use_fp8_w8a8=True,
             per_channel_quant=False,  # ModelOpt uses per-tensor quantization
             w1_scale=layer.w13_weight_scale,
             w2_scale=layer.w2_weight_scale,
             a1_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
-            no_combine=no_combine,
         )
@@ -741,8 +735,10 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
     @property
     def enable_flashinfer_cutlass_moe(self) -> bool:
+        from sglang.srt.layers.moe import get_moe_runner_backend
         """Access the global enable_flashinfer_cutlass_moe setting."""
-        return global_server_args_dict.get("enable_flashinfer_cutlass_moe", False)
+        return get_moe_runner_backend().is_flashinfer_cutlass()

     def create_weights(
         self,
@@ -1160,21 +1156,14 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
     def apply(
         self,
-        layer: torch.nn.Module,
+        layer: FusedMoE,
         x: torch.Tensor,
         topk_output: TopKOutput,
-        *,
-        activation: str = "silu",
-        apply_router_weight_on_input: bool = False,
-        inplace: bool = True,
-        no_combine: bool = False,
-        routed_scaling_factor: Optional[float] = None,
-        ep_rank: Optional[int] = None,
-        ep_size: Optional[int] = None,
-        tp_rank: Optional[int] = None,
-        tp_size: Optional[int] = None,
+        moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
-        assert activation == "silu", "Only SiLU activation is supported."
+        assert (
+            moe_runner_config.activation == "silu"
+        ), "Only SiLU activation is supported."
         # Check if this is a FlashInferFP4MoE layer that should handle its own forward
         if hasattr(layer, "gemm1_weights_fp4_shuffled"):
@@ -1183,7 +1172,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         if self.enable_flashinfer_cutlass_moe:
             assert (
-                not apply_router_weight_on_input
+                not moe_runner_config.apply_router_weight_on_input
             ), "apply_router_weight_on_input is not supported for Flashinfer"
             # TRTLLM Cutlass moe takes in activations in BF16/Half/nvfp4 precision
             # and fp4 quantized weights loaded from the checkpoint
@@ -1205,14 +1194,14 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
                     layer.w2_blockscale_swizzled.view(torch.int32),
                     layer.g2_alphas,
                 ],
-                ep_size=ep_size,
-                ep_rank=ep_rank,
-                tp_size=tp_size,
-                tp_rank=tp_rank,
+                ep_size=layer.moe_ep_size,
+                ep_rank=layer.moe_ep_rank,
+                tp_size=layer.moe_tp_size,
+                tp_rank=layer.moe_tp_rank,
                 tune_max_num_tokens=next_power_of_2(x.shape[0]),
             )[0]
-            if routed_scaling_factor is not None:
-                output *= routed_scaling_factor
+            if moe_runner_config.routed_scaling_factor is not None:
+                output *= moe_runner_config.routed_scaling_factor
             return output

         from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
@@ -1231,8 +1220,8 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
             topk_weights=topk_weights,
             topk_ids=topk_ids,
             params=layer.cutlass_moe_params,
-            apply_router_weight_on_input=apply_router_weight_on_input,
+            apply_router_weight_on_input=moe_runner_config.apply_router_weight_on_input,
         ).to(x.dtype)
-        if routed_scaling_factor is not None:
-            output *= routed_scaling_factor
+        if moe_runner_config.routed_scaling_factor is not None:
+            output *= moe_runner_config.routed_scaling_factor
         return output
@@ -22,6 +22,7 @@ from sglang.srt.utils import get_device_capability, set_weight_attrs
 logger = logging.getLogger(__name__)

 if TYPE_CHECKING:
+    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
     from sglang.srt.layers.moe.topk import TopKOutput
@@ -353,17 +354,14 @@ class MoeWNA16Method(FusedMoEMethodBase):
         layer: torch.nn.Module,
         x: torch.Tensor,
         topk_output: TopKOutput,
-        *,
-        activation: str = "silu",
-        apply_router_weight_on_input: bool = False,
-        inplace: bool = True,
-        no_combine: bool = False,
-        routed_scaling_factor: Optional[float] = None,
+        moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
         # avoid circular import
         from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts

-        assert activation == "silu", "Only SiLU activation is supported."
+        assert (
+            moe_runner_config.activation == "silu"
+        ), "Only SiLU activation is supported."
         weight_bits = self.quant_config.weight_bits
         has_zp = self.quant_config.has_zp
@@ -373,8 +371,7 @@ class MoeWNA16Method(FusedMoEMethodBase):
             layer.w13_qweight,
             layer.w2_qweight,
             topk_output=topk_output,
-            inplace=inplace,
-            apply_router_weight_on_input=apply_router_weight_on_input,
+            moe_runner_config=moe_runner_config,
             use_int4_w4a16=weight_bits == 4,
             use_int8_w8a16=weight_bits == 8,
             w1_scale=layer.w13_scales,
@@ -382,8 +379,6 @@ class MoeWNA16Method(FusedMoEMethodBase):
             w1_zp=layer.w13_qzeros if has_zp else None,
             w2_zp=layer.w2_qzeros if has_zp else None,
             block_shape=[0, layer.group_size],
-            no_combine=no_combine,
-            routed_scaling_factor=routed_scaling_factor,
         )

     @staticmethod
@@ -486,16 +481,16 @@ class MoeWNA16Method(FusedMoEMethodBase):
             )
             if "w13_qzeros" in weight_name:
-                tensor = loaded_weight.view(layer.tp_size, -1, loaded_weight.size(1))[
-                    tp_rank
-                ]
+                tensor = loaded_weight.view(
+                    layer.moe_tp_size, -1, loaded_weight.size(1)
+                )[tp_rank]
                 if shard_id == "w1":
                     param.data[expert_id, : shard_size // 2] = tensor
                 else:
                     param.data[expert_id, shard_size // 2 :] = tensor
             elif "w2_qzeros" in weight_name:
                 param.data[expert_id] = loaded_weight.view(
-                    loaded_weight.size(0), layer.tp_size, -1
+                    loaded_weight.size(0), layer.moe_tp_size, -1
                 )[:, tp_rank]
             else:
                 weight_loader(param, loaded_weight, weight_name, shard_id, expert_id)
...
@@ -16,14 +16,13 @@
 from __future__ import annotations

-import importlib.util
 import logging
 from typing import TYPE_CHECKING, List, Optional

 import torch
-import triton.language as tl
 from torch.nn.parameter import Parameter

+from sglang.srt.layers.moe.utils import get_moe_runner_backend
 from sglang.srt.layers.quantization.base_config import (
     FusedMoEMethodBase,
     QuantizationConfig,
@@ -31,7 +30,6 @@ from sglang.srt.layers.quantization.base_config import (
 )
 from sglang.srt.layers.quantization.utils import is_layer_skipped
 from sglang.srt.layers.utils import is_sm100_supported
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.utils import (
     direct_register_custom_op,
     get_bool_env_var,
@@ -60,6 +58,7 @@ if is_flashinfer_available():
 logger = logging.getLogger(__name__)

 if TYPE_CHECKING:
+    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
     from sglang.srt.layers.moe.topk import TopKOutput

 OCP_MX_BLOCK_SIZE = 32
@@ -218,15 +217,13 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         self,
         prefix: str,
     ):
-        from sglang.srt.managers.schedule_batch import global_server_args_dict

         super().__init__()

         self.prefix = prefix
         self.topk_indices_dtype = None
-        self.use_triton_kernels = global_server_args_dict["enable_triton_kernel_moe"]
+        self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()
         self.with_bias = False
-        self.use_flashinfer = global_server_args_dict["enable_flashinfer_mxfp4_moe"]
+        self.use_flashinfer = get_moe_runner_backend().is_flashinfer_mxfp4()
         self.triton_kernel_moe_forward = None
         self.triton_kernel_moe_with_bias_forward = None
@@ -348,6 +345,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                 logger,
                 f"Shuffling MoE weights for FlashInfer MXFP4 moe kernel (layer: {self.prefix}), it might take a while...",
             )
+            # TODO: these values are hardcoded for now, we need to get them from the model
             layer.gemm1_alpha = Parameter(
                 torch.tensor([1.702] * self.num_experts, dtype=torch.float32).cuda(),
                 requires_grad=False,
@@ -573,14 +571,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         layer: torch.nn.Module,
         x: torch.Tensor,
         topk_output: TopKOutput,
-        *,
-        activation: str = "silu",
-        apply_router_weight_on_input: bool = False,
-        inplace: bool = True,
-        no_combine: bool = False,
-        routed_scaling_factor: Optional[float] = None,
-        activation_alpha: Optional[float] = None,
-        swiglu_limit: Optional[float] = None,
+        moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
         if self.use_flashinfer:
             # Based on profiling results, we need to quantize x to mxfp8 here to achieve better performance
@@ -637,9 +628,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                     b1=layer.w13_weight_bias,
                     b2=layer.w2_weight_bias,
                     topk_output=topk_output,
-                    activation=activation,
-                    activation_alpha=activation_alpha,
-                    swiglu_limit=swiglu_limit,
+                    moe_runner_config=moe_runner_config,
                 )
             else:
                 return self.triton_kernel_moe_forward(
@@ -647,6 +636,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                     w1=layer.w13_weight,
                     w2=layer.w2_weight,
                     topk_output=topk_output,
+                    moe_runner_config=moe_runner_config,
                 )
         else:
             from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
@@ -656,13 +646,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                 w1=layer.w13_weight,
                 w2=layer.w2_weight,
                 topk_output=topk_output,
+                moe_runner_config=moe_runner_config,
                 b1=layer.w13_weight_bias,
                 b2=layer.w2_weight_bias,
-                inplace=inplace,
-                activation=activation,
-                apply_router_weight_on_input=apply_router_weight_on_input,
-                no_combine=no_combine,
-                routed_scaling_factor=routed_scaling_factor,
-                activation_alpha=activation_alpha,
-                swiglu_limit=swiglu_limit,
             )
 from __future__ import annotations

-import importlib
+import importlib.util
-from typing import TYPE_CHECKING, Callable, List, Optional
+from typing import TYPE_CHECKING, List, Optional

 import torch
 import torch.nn.functional as F
@@ -24,7 +24,7 @@ from sglang.srt.utils import (
 )

 if TYPE_CHECKING:
-    from sglang.srt.layers.moe.ep_moe.layer import EPMoE
+    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
     from sglang.srt.layers.moe.topk import TopKOutput

 has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None
@@ -221,31 +221,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         layer: torch.nn.Module,
         x: torch.Tensor,
         topk_output: TopKOutput,
-        *,
-        activation: str = "silu",
-        apply_router_weight_on_input: bool = False,
-        inplace: bool = True,
-        no_combine: bool = False,
-        routed_scaling_factor: Optional[float] = None,
-        activation_alpha: Optional[float] = None,
-        swiglu_limit: Optional[float] = None,
+        moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
-        kwargs = {}
-        if activation_alpha is not None:
-            kwargs["activation_alpha"] = activation_alpha
-        if swiglu_limit is not None:
-            kwargs["swiglu_limit"] = swiglu_limit
         return self.forward(
             x=x,
             layer=layer,
             topk_output=topk_output,
-            activation=activation,
-            apply_router_weight_on_input=apply_router_weight_on_input,
-            inplace=inplace,
-            no_combine=no_combine,
-            routed_scaling_factor=routed_scaling_factor,
-            **kwargs,
+            moe_runner_config=moe_runner_config,
         )

     def forward_cuda(
@@ -253,18 +236,12 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         layer: torch.nn.Module,
         x: torch.Tensor,
         topk_output: TopKOutput,
-        *,
-        activation: str = "silu",
-        apply_router_weight_on_input: bool = False,
-        inplace: bool = True,
-        no_combine: bool = False,
-        routed_scaling_factor: Optional[float] = None,
-        activation_alpha: Optional[float] = None,
-        swiglu_limit: Optional[float] = None,
+        moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
         if self.use_triton_kernels:
             if self.with_bias:
+                assert self.triton_kernel_moe_with_bias_forward is not None
                 return self.triton_kernel_moe_with_bias_forward(
                     hidden_states=x,
                     w1=layer.w13_weight,
@@ -272,24 +249,24 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                     b1=layer.w13_weight_bias,
                     b2=layer.w2_weight_bias,
                     topk_output=topk_output,
-                    activation=activation,
-                    activation_alpha=activation_alpha,
-                    swiglu_limit=swiglu_limit,
+                    moe_runner_config=moe_runner_config,
                     w1_pcg=None,
                     w2_pcg=None,
                 )
             else:
+                assert self.triton_kernel_moe_forward is not None
                 return self.triton_kernel_moe_forward(
                     hidden_states=x,
                     w1=layer.w13_weight,
                     w2=layer.w2_weight,
                     topk_output=topk_output,
+                    moe_runner_config=moe_runner_config,
                 )
         else:
             if _use_aiter:
-                assert not no_combine, "unsupported"
+                assert not moe_runner_config.no_combine, "unsupported"
                 topk_weights, topk_ids, _ = topk_output
-                if apply_router_weight_on_input:
+                if moe_runner_config.apply_router_weight_on_input:
                     assert (
                         topk_weights.dim() == 2
                     ), "`topk_weights` should be in shape (num_tokens, topk)"
@@ -309,7 +286,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                     topk_ids,
                     activation=(
                         ActivationType.Silu
-                        if activation == "silu"
+                        if moe_runner_config.activation == "silu"
                         else ActivationType.Gelu
                     ),
                 )
@@ -325,13 +302,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                     b1=getattr(layer, "w13_weight_bias", None),
                     b2=getattr(layer, "w2_weight_bias", None),
                     topk_output=topk_output,
-                    inplace=inplace and not no_combine,
-                    activation=activation,
-                    apply_router_weight_on_input=apply_router_weight_on_input,
-                    no_combine=no_combine,
-                    routed_scaling_factor=routed_scaling_factor,
-                    activation_alpha=activation_alpha,
-                    swiglu_limit=swiglu_limit,
+                    moe_runner_config=moe_runner_config,
                 )

     def forward_cpu(
@@ -339,21 +310,21 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         layer: torch.nn.Module,
         x: torch.Tensor,
         topk_output: TopKOutput,
-        *,
-        activation: str = "silu",
-        apply_router_weight_on_input: bool = False,
-        inplace: bool = True,
-        no_combine: bool = False,
-        routed_scaling_factor: Optional[float] = None,
+        moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
-        assert activation == "silu", f"activation = {activation} is not supported."
+        assert (
+            moe_runner_config.activation == "silu"
+        ), f"activation = {moe_runner_config.activation} is not supported."

-        if use_intel_amx_backend(layer) and not apply_router_weight_on_input:
+        if (
+            use_intel_amx_backend(layer)
+            and not moe_runner_config.apply_router_weight_on_input
+        ):
             from sglang.srt.layers.moe.topk import apply_topk_weights_cpu

             topk_weights, topk_ids, _ = topk_output
             x, topk_weights = apply_topk_weights_cpu(
-                apply_router_weight_on_input, topk_weights, x
+                moe_runner_config.apply_router_weight_on_input, topk_weights, x
             )
             return torch.ops.sgl_kernel.fused_experts_cpu(
                 x,
@@ -378,11 +349,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             layer,
             x,
             topk_output,
-            activation=activation,
-            apply_router_weight_on_input=apply_router_weight_on_input,
-            inplace=inplace,
-            no_combine=no_combine,
-            routed_scaling_factor=routed_scaling_factor,
+            moe_runner_config,
         )

     def forward_npu(
@@ -390,12 +357,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         layer: torch.nn.Module,
         x: torch.Tensor,
         topk_output: TopKOutput,
-        *,
-        activation: str = "silu",
-        apply_router_weight_on_input: bool = False,
-        inplace: bool = True,
-        no_combine: bool = False,
-        routed_scaling_factor: Optional[float] = None,
+        moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
         from sglang.srt.layers.moe.fused_moe_native import moe_forward_native
@@ -403,11 +365,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             layer,
             x,
             topk_output,
-            activation=activation,
-            apply_router_weight_on_input=apply_router_weight_on_input,
-            inplace=inplace,
-            no_combine=no_combine,
-            routed_scaling_factor=routed_scaling_factor,
+            moe_runner_config,
         )

     def forward_tpu(self, *args, **kwargs) -> torch.Tensor:
...
@@ -18,7 +18,9 @@ from sglang.srt.layers.quantization.utils import is_layer_skipped
 from sglang.srt.utils import set_weight_attrs

 if TYPE_CHECKING:
-    from sglang.srt.layers.moe.ep_moe.layer import EPMoE, TopKOutput
+    from sglang.srt.layers.moe import MoeRunnerConfig
+    from sglang.srt.layers.moe.ep_moe.layer import EPMoE
+    from sglang.srt.layers.moe.topk import StandardTopKOutput

 ACTIVATION_SCHEMES = ["static", "dynamic"]
@@ -280,11 +282,8 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
         self,
         layer: EPMoE,
         x: torch.Tensor,
-        topk_output: TopKOutput,
-        activation: str = "silu",
-        apply_router_weight_on_input: bool = False,
-        routed_scaling_factor: Optional[float] = None,
-        **kwargs,
+        topk_output: StandardTopKOutput,
+        moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
         # TODO(ch-wan): move it out of this class
@@ -324,6 +323,6 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
             layer.w13_input_scale,
             layer.w2_input_scale,
         )
-        if routed_scaling_factor is not None:
-            output *= routed_scaling_factor
+        if moe_runner_config.routed_scaling_factor is not None:
+            output *= moe_runner_config.routed_scaling_factor
         return output
@@ -26,7 +26,8 @@ from sglang.srt.layers.quantization.fp8_utils import (
 from sglang.srt.utils import set_weight_attrs

 if TYPE_CHECKING:
-    from sglang.srt.layers.moe.topk import TopKOutput
+    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
+    from sglang.srt.layers.moe.topk import StandardTopKOutput

 _is_fp8_fnuz = is_fp8_fnuz()
@@ -269,13 +270,8 @@ class W8A8FP8MoEMethod(FusedMoEMethodBase):
         self,
         layer: torch.nn.Module,
         x: torch.Tensor,
-        topk_output: TopKOutput,
-        *,
-        activation: str = "silu",
-        apply_router_weight_on_input: bool = False,
-        inplace: bool = True,
-        no_combine: bool = False,
-        routed_scaling_factor: Optional[float] = None,
+        topk_output: StandardTopKOutput,
+        moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
         from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
@@ -284,15 +280,11 @@ class W8A8FP8MoEMethod(FusedMoEMethodBase):
             layer.w13_weight,
             layer.w2_weight,
             topk_output=topk_output,
-            inplace=inplace,
-            apply_router_weight_on_input=apply_router_weight_on_input,
-            activation=activation,
+            moe_runner_config=moe_runner_config,
             use_fp8_w8a8=True,
             per_channel_quant=True,
             w1_scale=(layer.w13_weight_scale),
             w2_scale=(layer.w2_weight_scale),
             a1_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
-            no_combine=no_combine,
-            routed_scaling_factor=routed_scaling_factor,
         )
@@ -49,6 +49,7 @@ from sglang.srt.utils import (
 )

 if TYPE_CHECKING:
+    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
     from sglang.srt.layers.moe.topk import TopKOutput

 _is_cuda = is_cuda()
@@ -487,12 +488,7 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
         layer: torch.nn.Module,
         x: torch.Tensor,
         topk_output: TopKOutput,
-        *,
-        activation: str = "silu",
-        apply_router_weight_on_input: bool = False,
-        inplace: bool = True,
-        no_combine: bool = False,
-        routed_scaling_factor: Optional[float] = None,
+        moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
         from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
@@ -501,7 +497,7 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
             topk_weights, topk_ids, _ = topk_output
             x, topk_weights = apply_topk_weights_cpu(
-                apply_router_weight_on_input, topk_weights, x
+                moe_runner_config.apply_router_weight_on_input, topk_weights, x
             )
             return torch.ops.sgl_kernel.fused_experts_cpu(
                 x,
@@ -525,17 +521,13 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
             layer.w13_weight,
             layer.w2_weight,
             topk_output=topk_output,
-            inplace=inplace,
-            activation=activation,
-            apply_router_weight_on_input=apply_router_weight_on_input,
+            moe_runner_config=moe_runner_config,
             use_int8_w8a8=True,
             per_channel_quant=True,
             w1_scale=(layer.w13_weight_scale),
             w2_scale=(layer.w2_weight_scale),
             a1_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
-            no_combine=no_combine,
-            routed_scaling_factor=routed_scaling_factor,
         )
@@ -982,7 +974,7 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase):
         layer,
         x,
         topk_output: TopKOutput,
-        **kwargs,
+        moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
         topk_weights, topk_ids, _ = topk_output
...
@@ -52,6 +52,7 @@ from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
     ScheduleBatchDisaggregationDecodeMixin,
 )
 from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
+from sglang.srt.layers.moe import is_tbo_enabled
 from sglang.srt.mem_cache.allocator import (
     BaseTokenToKVPoolAllocator,
     SWATokenToKVPoolAllocator,
@@ -84,17 +85,10 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "device",
     "disable_chunked_prefix_cache",
     "disable_radix_cache",
-    "enable_two_batch_overlap",
-    "tbo_token_distribution_threshold",
     "enable_dp_lm_head",
-    "moe_a2a_backend",
-    "deepep_mode",
-    "enable_flashinfer_cutlass_moe",
-    "enable_flashinfer_trtllm_moe",
     "enable_flashinfer_allreduce_fusion",
     "moe_dense_tp_size",
     "ep_dispatch_algorithm",
-    "deepep_config",
     "ep_num_redundant_experts",
     "enable_nan_detection",
     "flashinfer_mla_disable_ragged",
@@ -107,8 +101,6 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "triton_attention_reduce_in_fp32",
     "num_reserved_decode_tokens",
     "weight_loader_disable_mmap",
-    "enable_triton_kernel_moe",
-    "enable_flashinfer_mxfp4_moe",
     "enable_multimodal",
     "enable_symm_mem",
     "quantization",
...
@@ -64,7 +64,7 @@ from sglang.srt.hf_transformers_utils import (
 )
 from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.layers.moe.utils import DeepEPMode, MoeA2ABackend
+from sglang.srt.layers.moe import initialize_moe_config
 from sglang.srt.managers.io_struct import (
     AbortReq,
     CloseSessionReqInput,
@@ -245,6 +245,9 @@ class Scheduler(
             )
         )

+        # Init model config
+        self.model_config = ModelConfig.from_server_args(server_args)

         # Init inter-process communication
         context = zmq.Context(2)
         self.idle_sleeper = None
@@ -292,6 +295,9 @@ class Scheduler(
         # Init tokenizer
         self.init_tokenizer()

+        # Init moe config
+        self.init_moe_config()

         # Set reasoning_parser and think_end_id if --reasoning_parser is enabled
         if self.server_args.reasoning_parser and self.tokenizer:
             reasoning_parser = ReasoningParser(
@@ -538,8 +544,6 @@ class Scheduler(
     def init_tokenizer(self):
         server_args = self.server_args
-        self.model_config = ModelConfig.from_server_args(server_args)
         self.is_generation = self.model_config.is_generation

         if server_args.skip_tokenizer_init:
@@ -761,6 +765,10 @@ class Scheduler(
         # The prefill requests that are in the middle of kv sending
         self.disagg_prefill_inflight_queue: List[Req] = []

+    def init_moe_config(self):
+        if hasattr(self.model_config.hf_config, "num_experts_per_tok"):
+            initialize_moe_config(self.server_args)

     @DynamicGradMode()
     def event_loop_normal(self):
         """A normal scheduler loop."""
@@ -1823,11 +1831,6 @@ class Scheduler(
             disable_cuda_graph=self.server_args.disable_cuda_graph,
             spec_algorithm=self.spec_algorithm,
             speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens,
-            enable_two_batch_overlap=self.server_args.enable_two_batch_overlap,
-            enable_deepep_moe=MoeA2ABackend(
-                self.server_args.moe_a2a_backend
-            ).is_deepep(),
-            deepep_mode=DeepEPMode(self.server_args.deepep_mode),
             require_mlp_tp_gather=require_mlp_tp_gather(self.server_args),
             disable_overlap_schedule=self.server_args.disable_overlap_schedule,
         )
@@ -1922,9 +1925,6 @@ class Scheduler(
         disable_cuda_graph: bool,
         spec_algorithm,
         speculative_num_draft_tokens,
-        enable_two_batch_overlap: bool,
-        enable_deepep_moe: bool,
-        deepep_mode: DeepEPMode,
         require_mlp_tp_gather: bool,
         disable_overlap_schedule: bool,
     ):
@@ -1972,9 +1972,6 @@ class Scheduler(
                 is_extend_in_batch,
                 *tbo_preparer.prepare_all_gather(
                     local_batch,
-                    deepep_mode,
-                    enable_deepep_moe,
-                    enable_two_batch_overlap,
                 ),
             ],
             dtype=torch.int64,
...
@@ -60,7 +60,6 @@ from sglang.srt.layers.dp_attention import (
     initialize_dp_attention,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.layers.moe.utils import DeepEPMode, MoeA2ABackend
 from sglang.srt.layers.quantization import (
     deep_gemm_wrapper,
     monkey_patch_isinstance_for_vllm_base_layer,
@@ -219,8 +218,6 @@ class ModelRunner:
                 # TODO it is indeed not a "server args"
                 "use_mla_backend": self.use_mla_backend,
                 "speculative_algorithm": self.spec_algorithm,
-                "moe_a2a_backend": MoeA2ABackend(server_args.moe_a2a_backend),
-                "deepep_mode": DeepEPMode(server_args.deepep_mode),
             }
         )
...
@@ -32,7 +32,9 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.fused_moe_triton import fused_moe
+from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
+from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
+from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
@@ -104,6 +106,11 @@ class DbrxExperts(nn.Module):
         self.params_dtype = params_dtype

         self.router = DbrxRouter(config, self.params_dtype)
+        self.topk = TopK(
+            self.top_k,
+            renormalize=True,
+        )
+        self.moe_runner_config = MoeRunnerConfig(inplace=True)
         self.ws = nn.Parameter(
             torch.empty(
                 self.num_total_experts,
@@ -169,14 +176,13 @@ class DbrxExperts(nn.Module):
         hidden_states = hidden_states.view(-1, self.d_model)
         # router_logits: (num_tokens, n_experts)
         router_logits = self.router(hidden_states)
+        topk_output = self.topk(hidden_states, router_logits)
         final_hidden_states = fused_moe(
             hidden_states,
             self.ws,
             self.w2s,
-            router_logits,
-            self.top_k,
-            renormalize=True,
-            inplace=True,
+            topk_output,
+            self.moe_runner_config,
         )

         if self.tp_size > 1:
@@ -293,7 +299,7 @@ class DbrxFusedNormAttention(nn.Module):
         position_ids: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
-    ) -> torch.Tensor:
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         residual = hidden_states
         hidden_states = self.norm_1(hidden_states)
         x = self.attn(
...
@@ -37,6 +37,7 @@ from sglang.srt.layers.linear import (
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe.fused_moe_triton import fused_moe
+from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
 from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
@@ -180,7 +181,7 @@ class DeepseekMoE(nn.Module):
             w1=self.w1,
             w2=self.w2,
             topk_output=topk_output,
-            inplace=True,
+            moe_runner_config=MoeRunnerConfig(inplace=True),
         )

         if self.config.n_shared_experts is not None:
...
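For model code, the migration pattern shown in the dbrx and deepseek hunks above is: build a `TopK` module once, pass its output to `fused_moe` together with a `MoeRunnerConfig`, and drop the old `top_k` / `renormalize` / `inplace` keywords. A condensed sketch of that call shape follows; the wrapper function and the example `top_k` value are illustrative assumptions, not an excerpt from the diff:

```python
# Condensed sketch of the new fused_moe call pattern; wiring is illustrative only.
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import TopK

topk = TopK(4, renormalize=True)                    # routing module, built once
moe_runner_config = MoeRunnerConfig(inplace=True)   # runner config, built once


def moe_forward(hidden_states, router_logits, w13_weight, w2_weight):
    # Routing is now computed by the TopK module instead of being passed as kwargs.
    topk_output = topk(hidden_states, router_logits)
    return fused_moe(
        hidden_states,
        w13_weight,
        w2_weight,
        topk_output,
        moe_runner_config,
    )
```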