Unverified Commit 3fa62da7 authored by Cheng Wan, committed by GitHub

[7/N] MoE Refactor: the implementation of new framework (#9269)

parent dbb1235d
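
This commit migrates the quantized MoE method classes to the new dispatcher-based framework: each FusedMoEMethodBase subclass gains a create_moe_runner(layer, moe_runner_config) hook, and apply() now consumes a StandardDispatchOutput and returns a CombineInput instead of taking raw tensors plus a per-call MoeRunnerConfig. A minimal sketch of the interface as it is used in the hunks below; the field names are taken from the diff, while the real definitions live in sglang.srt.layers.moe.token_dispatcher and may carry additional fields:

    # Sketch only: shapes inferred from this diff, not the actual class definitions.
    from dataclasses import dataclass
    import torch

    @dataclass
    class StandardDispatchOutput:
        hidden_states: torch.Tensor   # tokens handed to the method by the dispatcher
        topk_output: tuple            # (topk_weights, topk_ids, router_logits)

    @dataclass
    class StandardCombineInput:
        hidden_states: torch.Tensor   # expert outputs handed back for the combine step

    class SomeQuantMoEMethod:
        def create_moe_runner(self, layer, moe_runner_config):
            # Called once; replaces threading moe_runner_config through every apply().
            self.moe_runner_config = moe_runner_config

        def apply(self, layer, dispatch_output: StandardDispatchOutput) -> StandardCombineInput:
            x = dispatch_output.hidden_states
            topk_weights, topk_ids, _ = dispatch_output.topk_output
            out = x  # a real method runs its fused expert kernel here
            return StandardCombineInput(hidden_states=out)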
@@ -45,7 +45,10 @@ from sglang.srt.layers.quantization.utils import (

 if TYPE_CHECKING:
     from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
-    from sglang.srt.layers.moe.topk import TopKOutput
+    from sglang.srt.layers.moe.token_dispatcher import (
+        StandardDispatchOutput,
+        CombineInput,
+    )

 from sglang.srt.utils import is_cuda

@@ -838,19 +841,14 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
         from sglang.srt.layers.linear import set_weight_attrs
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported

-        intermediate_size = extra_weight_attrs.pop("intermediate_size")
-        self.is_k_full = (not self.quant_config.desc_act) or (
-            intermediate_size_per_partition == intermediate_size
-        )
+        self.is_k_full = (not self.quant_config.desc_act) or layer.moe_tp_size == 1

         if self.quant_config.group_size != -1:
             scales_size13 = hidden_size // self.quant_config.group_size
-            w2_scales_size = (
-                intermediate_size
-                if self.quant_config.desc_act
-                else intermediate_size_per_partition
-            )
+            if self.quant_config.desc_act:
+                w2_scales_size = intermediate_size_per_partition
+            else:
+                w2_scales_size = intermediate_size_per_partition * layer.moe_tp_size
             scales_size2 = w2_scales_size // self.quant_config.group_size
             strategy = FusedMoeWeightScaleSupported.GROUP.value
         else:

@@ -1052,17 +1050,26 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
         )
         replace_parameter(layer, "w2_scales", marlin_w2_scales)

+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+
     def apply(
         self,
         layer: torch.nn.Module,
-        x: torch.Tensor,
-        topk_output: TopKOutput,
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
+
         # Delay the import to avoid circular dependency
         assert (
-            moe_runner_config.activation == "silu"
+            self.moe_runner_config.activation == "silu"
         ), "Only SiLU activation is supported."

         # The input must currently be float16
@@ -1071,7 +1078,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):

         topk_weights, topk_ids, router_logits = topk_output

-        return fused_marlin_moe(
+        output = fused_marlin_moe(
             x,
             layer.w13_qweight,
             layer.w2_qweight,
@@ -1087,3 +1094,4 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
             num_bits=self.quant_config.weight_bits,
             is_k_full=self.is_k_full,
         ).to(orig_dtype)
+        return StandardCombineInput(hidden_states=output)
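
The GPTQ Marlin hunks above show the first of the two migration patterns in this commit: methods that keep their own fused kernel (fused_marlin_moe here) simply unpack the dispatch output, run the kernel, and wrap the result in StandardCombineInput. A hedged sketch of how a caller might drive this interface; the dispatcher method names below are illustrative assumptions, not taken from this diff:

    # Hypothetical call sequence around a quantization method's apply().
    def moe_forward(layer, hidden_states, topk_output, quant_method, dispatcher):
        dispatch_output = dispatcher.dispatch(hidden_states, topk_output)  # StandardDispatchOutput
        combine_input = quant_method.apply(layer, dispatch_output)         # CombineInput
        return dispatcher.combine(combine_input)                           # final hidden states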
@@ -10,10 +10,14 @@ from torch.nn.parameter import Parameter
 from sglang.srt.distributed import get_tp_group
 from sglang.srt.layers.dp_attention import get_dp_global_num_tokens, get_local_dp_buffer
 from sglang.srt.layers.moe import (
+    MoeRunner,
+    MoeRunnerBackend,
+    MoeRunnerConfig,
     should_use_flashinfer_cutlass_moe_fp4_allgather,
     should_use_flashinfer_trtllm_moe,
 )
 from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType
+from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
 from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
 from sglang.srt.layers.quantization.base_config import (
     FusedMoEMethodBase,

@@ -39,8 +43,10 @@ from sglang.srt.utils import is_cuda, next_power_of_2

 if TYPE_CHECKING:
     from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
-    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
-    from sglang.srt.layers.moe.topk import TopKOutput
+    from sglang.srt.layers.moe.token_dispatcher import (
+        CombineInput,
+        StandardDispatchOutput,
+    )

 if is_cuda():
     from sgl_kernel import scaled_fp4_quant

@@ -322,7 +328,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
         layer: torch.nn.Module,
         num_experts: int,
         hidden_size: int,
-        intermediate_size: int,
+        intermediate_size_per_partition: int,
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):

@@ -338,7 +344,10 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
         w13_weight = ModelWeightParameter(
             data=torch.empty(
-                num_experts, 2 * intermediate_size, hidden_size, dtype=weight_dtype
+                num_experts,
+                2 * intermediate_size_per_partition,
+                hidden_size,
+                dtype=weight_dtype,
             ),
             input_dim=2,
             output_dim=1,

@@ -348,7 +357,10 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
         w2_weight = ModelWeightParameter(
             data=torch.empty(
-                num_experts, hidden_size, intermediate_size, dtype=weight_dtype
+                num_experts,
+                hidden_size,
+                intermediate_size_per_partition,
+                dtype=weight_dtype,
             ),
             input_dim=2,
             output_dim=1,

@@ -414,28 +426,28 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
         max_w13_scales = layer.w13_weight_scale.max(dim=1).values

         # Requantize each expert's weights using the combined scale
-        # w13_weight has shape (num_experts, 2 * intermediate_size, hidden_size)
-        # where the first intermediate_size rows are w1, the next are w3
-        intermediate_size = layer.w13_weight.shape[1] // 2
+        # w13_weight has shape (num_experts, 2 * intermediate_size_per_partition, hidden_size)
+        # where the first intermediate_size_per_partition rows are w1, the next are w3
+        intermediate_size_per_partition = layer.w13_weight.shape[1] // 2
         for expert_id in range(layer.w13_weight.shape[0]):
             start = 0
             for shard_id in range(2):  # w1 and w3
                 # Dequantize using the original scale for this shard
                 dq_weight = per_tensor_dequantize(
                     layer.w13_weight[expert_id][
-                        start : start + intermediate_size, :
+                        start : start + intermediate_size_per_partition, :
                     ],
                     layer.w13_weight_scale[expert_id][shard_id],
                 )
                 # Requantize using the combined max scale
                 (
                     layer.w13_weight[expert_id][
-                        start : start + intermediate_size, :
+                        start : start + intermediate_size_per_partition, :
                     ],
                     _,
                 ) = scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
-                start += intermediate_size
+                start += intermediate_size_per_partition

         # Update the scale parameter to be per-expert instead of per-shard
         layer.w13_weight_scale = Parameter(max_w13_scales, requires_grad=False)

@@ -457,29 +469,31 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
                 layer.w2_input_scale.max(), requires_grad=False
             )

+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
+
     def apply(
         self,
         layer: torch.nn.Module,
-        x: torch.Tensor,
-        topk_output: TopKOutput,
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
-        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
-
-        return fused_experts(
-            x,
-            layer.w13_weight,
-            layer.w2_weight,
-            topk_output=topk_output,
-            moe_runner_config=moe_runner_config,
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+
+        quant_info = TritonMoeQuantInfo(
+            w13_weight=layer.w13_weight,
+            w2_weight=layer.w2_weight,
             use_fp8_w8a8=True,
-            per_channel_quant=False,  # ModelOpt uses per-tensor quantization
-            w1_scale=layer.w13_weight_scale,
+            per_channel_quant=False,
+            w13_scale=layer.w13_weight_scale,
             w2_scale=layer.w2_weight_scale,
-            a1_scale=layer.w13_input_scale,
+            a13_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
         )
+        return self.runner.run(dispatch_output, quant_info)


 class ModelOptFp4Config(QuantizationConfig):
     """Config class for FP4."""

@@ -1278,21 +1292,32 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
             # FlashInfer CUTLASS kernel assumes [Up, Gate] Proj as W13
             return self.enable_flashinfer_cutlass_moe

+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+
     def apply(
         self,
         layer: FusedMoE,
-        x: torch.Tensor,
-        topk_output: TopKOutput,
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
+
         assert (
-            moe_runner_config.activation == "silu"
+            self.moe_runner_config.activation == "silu"
         ), "Only SiLU activation is supported."
+
+        moe_runner_config = self.moe_runner_config
+
         # Check if this is a FlashInferFP4MoE layer that should handle its own forward
         if hasattr(layer, "gemm1_weights_fp4_shuffled"):
             # This layer was processed with flashinfer TRTLLM - delegate to its own forward
-            return layer.forward(x, topk_output)
+            return StandardCombineInput(hidden_states=layer.forward(x, topk_output))

         if self.enable_flashinfer_cutlass_moe:
             assert (
@@ -1345,13 +1370,12 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
                     tp_rank=layer.moe_tp_rank,
                     tune_max_num_tokens=next_power_of_2(x.shape[0]),
                 )[0]
-            # Scale by routed_scaling_factor is fused into select_experts.
             if should_use_flashinfer_cutlass_moe_fp4_allgather():
                 output, global_output = get_local_dp_buffer(), output
                 get_tp_group().reduce_scatterv(
                     global_output, output=output, sizes=get_dp_global_num_tokens()
                 )
-            return output
+            return StandardCombineInput(hidden_states=output)

         from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4

@@ -1372,4 +1396,5 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
             apply_router_weight_on_input=moe_runner_config.apply_router_weight_on_input,
         ).to(x.dtype)
         # Scale by routed_scaling_factor is fused into select_experts.
-        return output
+
+        return StandardCombineInput(hidden_states=output)
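
The ModelOptFp8MoEMethod hunks in the file above illustrate the second pattern: rather than calling fused_experts directly, the method describes its weights and scales in a TritonMoeQuantInfo and hands the dispatch output to a MoeRunner built in create_moe_runner. A minimal sketch of that wiring, with argument names copied from the hunks (anything beyond them is an assumption):

    # Sketch of the Triton-runner delegation used by several methods in this commit.
    from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend
    from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo

    class SomeFp8MoEMethod:
        def create_moe_runner(self, layer, moe_runner_config):
            # Store the config and build a Triton-backed runner once.
            self.moe_runner_config = moe_runner_config
            self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)

        def apply(self, layer, dispatch_output):
            # Describe the quantized weights; the runner performs the fused expert computation.
            quant_info = TritonMoeQuantInfo(
                w13_weight=layer.w13_weight,
                w2_weight=layer.w2_weight,
                use_fp8_w8a8=True,
                w13_scale=layer.w13_weight_scale,
                w2_scale=layer.w2_weight_scale,
                a13_scale=layer.w13_input_scale,
                a2_scale=layer.w2_input_scale,
            )
            return self.runner.run(dispatch_output, quant_info)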
@@ -9,6 +9,8 @@ import torch

 from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.distributed.parallel_state import get_tp_group
+from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
+from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
 from sglang.srt.layers.quantization.awq import AWQConfig
 from sglang.srt.layers.quantization.base_config import (
     FusedMoEMethodBase,

@@ -22,8 +24,10 @@ from sglang.srt.utils import get_device_capability, set_weight_attrs
 logger = logging.getLogger(__name__)

 if TYPE_CHECKING:
-    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
-    from sglang.srt.layers.moe.topk import TopKOutput
+    from sglang.srt.layers.moe.token_dispatcher import (
+        CombineInput,
+        StandardDispatchOutput,
+    )


 def get_weight_perm(num_bits: int):

@@ -349,37 +353,36 @@ class MoeWNA16Method(FusedMoEMethodBase):
             layer.register_parameter(key, param)
             set_weight_attrs(param, extra_weight_attrs)

+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
+
     def apply(
         self,
         layer: torch.nn.Module,
-        x: torch.Tensor,
-        topk_output: TopKOutput,
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
-        # avoid circular import
-        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:

         assert (
-            moe_runner_config.activation == "silu"
+            self.moe_runner_config.activation == "silu"
         ), "Only SiLU activation is supported."

         weight_bits = self.quant_config.weight_bits
         has_zp = self.quant_config.has_zp

-        return fused_experts(
-            x,
-            layer.w13_qweight,
-            layer.w2_qweight,
-            topk_output=topk_output,
-            moe_runner_config=moe_runner_config,
+        quant_info = TritonMoeQuantInfo(
+            w13_weight=layer.w13_qweight,
+            w2_weight=layer.w2_qweight,
             use_int4_w4a16=weight_bits == 4,
             use_int8_w8a16=weight_bits == 8,
-            w1_scale=layer.w13_scales,
+            w13_scale=layer.w13_scales,
             w2_scale=layer.w2_scales,
-            w1_zp=layer.w13_qzeros if has_zp else None,
+            w13_zp=layer.w13_qzeros if has_zp else None,
             w2_zp=layer.w2_qzeros if has_zp else None,
             block_shape=[0, layer.group_size],
         )
+        return self.runner.run(dispatch_output, quant_info)

     @staticmethod
     def get_weight_loader(layer, weight_loader):
@@ -22,6 +22,8 @@ from typing import TYPE_CHECKING, List, Optional
 import torch
 from torch.nn.parameter import Parameter

+from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
+from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
 from sglang.srt.layers.moe.utils import get_moe_runner_backend
 from sglang.srt.layers.quantization.base_config import (
     FusedMoEMethodBase,

@@ -59,8 +61,10 @@ if is_flashinfer_available():
 logger = logging.getLogger(__name__)

 if TYPE_CHECKING:
-    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
-    from sglang.srt.layers.moe.topk import TopKOutput
+    from sglang.srt.layers.moe.token_dispatcher import (
+        CombineInput,
+        StandardDispatchOutput,
+    )

 _is_hip = is_hip()

@@ -283,7 +287,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         layer: torch.nn.Module,
         num_experts: int,
         hidden_size: int,
-        intermediate_size: int,
+        intermediate_size_per_partition: int,
         params_dtype: torch.dtype,
         with_bias: bool = False,
         **extra_weight_attrs,

@@ -296,26 +300,26 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         # pad the intermediate size to be a multiple of 2 * mxfp4_block
         # for to hold non-uniform sharded tensor as well as swizzling
-        intermediate_size_per_partition_after_pad = intermediate_size
+        intermediate_size_per_partition_after_pad = intermediate_size_per_partition
         if _is_sm100_supported:
             if self.use_flashinfer:
                 intermediate_size_per_partition_after_pad = round_up(
-                    intermediate_size, 256
+                    intermediate_size_per_partition, 256
                 )
                 hidden_size = round_up(hidden_size, 256)
             else:
                 intermediate_size_per_partition_after_pad = round_up(
-                    intermediate_size, 64
+                    intermediate_size_per_partition, 64
                 )
         elif has_triton_kernels:
             # TODO: this is a hack to make
             # intermediate_size_per_partition_after_pad the same as the
             # per_rank_intermediate_size during weight loading
             intermediate_size_per_partition_after_pad = round_up(
-                intermediate_size, mxfp4_block
+                intermediate_size_per_partition, mxfp4_block
             )

-        self.intermediate_size = intermediate_size_per_partition_after_pad
+        self.intermediate_size_per_partition = intermediate_size_per_partition_after_pad

         self.hidden_size = hidden_size
         # Fused gate_up_proj (column parallel)

@@ -410,31 +414,35 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
             assert (
                 layer.w13_weight.dim() == 3
                 and layer.w13_weight.shape[0] == self.num_experts
-                and layer.w13_weight.shape[1] == self.intermediate_size * 2
+                and layer.w13_weight.shape[1]
+                == self.intermediate_size_per_partition * 2
                 and layer.w13_weight.shape[2] == self.hidden_size // 2
             )
             assert (
                 layer.w13_weight_scale.dim() == 3
                 and layer.w13_weight_scale.shape[0] == self.num_experts
-                and layer.w13_weight_scale.shape[1] == self.intermediate_size * 2
+                and layer.w13_weight_scale.shape[1]
+                == self.intermediate_size_per_partition * 2
                 and layer.w13_weight_scale.shape[2] == self.hidden_size // sf_block_size
             )
             assert (
                 layer.w2_weight.dim() == 3
                 and layer.w2_weight.shape[0] == self.num_experts
                 and layer.w2_weight.shape[1] == self.hidden_size
-                and layer.w2_weight.shape[2] == self.intermediate_size // 2
+                and layer.w2_weight.shape[2]
+                == self.intermediate_size_per_partition // 2
             )
             assert (
                 layer.w2_weight_scale.dim() == 3
                 and layer.w2_weight_scale.shape[1] == self.hidden_size
                 and layer.w2_weight_scale.shape[2]
-                == self.intermediate_size // sf_block_size
+                == self.intermediate_size_per_partition // sf_block_size
             )
             assert (
                 layer.w13_weight_bias.dim() == 2
                 and layer.w13_weight_bias.shape[0] == self.num_experts
-                and layer.w13_weight_bias.shape[1] == self.intermediate_size * 2
+                and layer.w13_weight_bias.shape[1]
+                == self.intermediate_size_per_partition * 2
             )
             assert (
                 layer.w2_weight_bias.dim() == 2

@@ -511,7 +519,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                 torch.stack(gemm1_scales_mxfp4_shuffled)
                 .reshape(
                     self.num_experts,
-                    2 * self.intermediate_size,
+                    2 * self.intermediate_size_per_partition,
                     self.hidden_size // sf_block_size,
                 )
                 .view(torch.float8_e4m3fn)

@@ -523,7 +531,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                 .reshape(
                     self.num_experts,
                     self.hidden_size,
-                    self.intermediate_size // sf_block_size,
+                    self.intermediate_size_per_partition // sf_block_size,
                 )
                 .view(torch.float8_e4m3fn)
             )

@@ -613,16 +621,26 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         return tile_tokens_dim

+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
+
     def apply(
         self,
         layer: torch.nn.Module,
-        x: torch.Tensor,
-        topk_output: TopKOutput,
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
         from sglang.srt.layers.moe.topk import TopKOutputChecker

+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
+        moe_runner_config = self.moe_runner_config
+
         if self.use_flashinfer:
             # When bf16 mode is enabled, we don't need to quantize the input,
             # TRT-LLM automatically handles quantization in the kernel implementation and pipelines it with GEMM operations,

@@ -674,7 +692,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                 top_k,
                 None,  # n_group # TODO: support n_group
                 None,  # topk_group # TODO: support topk_group
-                self.intermediate_size,  # padded to multiple of 256
+                self.intermediate_size_per_partition,  # padded to multiple of 256
                 layer.moe_ep_rank * layer.num_local_experts,  # local_expert_offset
                 layer.num_local_experts,  # local num experts
                 None,

@@ -682,14 +700,14 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                 1,  # routing_method_type, renormalize
                 True,  # do finalize
             )[0]
-            return trtllm_gen_output
+            return StandardCombineInput(hidden_states=trtllm_gen_output)

         if self.use_triton_kernels:
             assert (
                 layer.moe_ep_size == 1
             ), "Expert parallel is not supported when using triton kernels"
             if self.with_bias:
-                return self.triton_kernel_moe_with_bias_forward(
+                output = self.triton_kernel_moe_with_bias_forward(
                     hidden_states=x,
                     w1=self.w13_weight_triton_tensor,
                     w1_pcg=self.w13_precision_config,

@@ -701,25 +719,22 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                     moe_runner_config=moe_runner_config,
                 )
             else:
-                return self.triton_kernel_moe_forward(
+                output = self.triton_kernel_moe_forward(
                     hidden_states=x,
                     w1=layer.w13_weight,
                     w2=layer.w2_weight,
                     topk_output=topk_output,
                     moe_runner_config=moe_runner_config,
                 )
+            return StandardCombineInput(hidden_states=output)
         else:
-            from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
-
-            return fused_experts(
-                hidden_states=x,
-                w1=layer.w13_weight,
-                w2=layer.w2_weight,
-                topk_output=topk_output,
-                moe_runner_config=moe_runner_config,
-                b1=layer.w13_weight_bias,
-                b2=layer.w2_weight_bias,
+            quant_info = TritonMoeQuantInfo(
+                w13_weight=layer.w13_weight,
+                w2_weight=layer.w2_weight,
+                w13_weight_bias=layer.w13_weight_bias,
+                w2_weight_bias=layer.w2_weight_bias,
             )
+            return self.runner.run(dispatch_output, quant_info)


 class Mxfp4DynamicQuantMoEMethod(FusedMoEMethodBase):

@@ -798,7 +813,7 @@ class Mxfp4DynamicQuantMoEMethod(FusedMoEMethodBase):
         return w, mx_scales

-    def process_weights_after_loading(self, layer: Module) -> None:
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         w13, w13_mx_scales = self.mxfp4_quantize(layer.w13_weight.data)
         w2, w2_mx_scales = self.mxfp4_quantize(layer.w2_weight.data)

@@ -808,19 +823,27 @@ class Mxfp4DynamicQuantMoEMethod(FusedMoEMethodBase):
         layer.w2_weight = torch.nn.Parameter(w2, requires_grad=False)
         layer.w2_weight_scale = torch.nn.Parameter(w2_mx_scales, requires_grad=False)

+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+
     def apply(
         self,
         layer: torch.nn.Module,
-        x: torch.Tensor,
-        topk_output: TopKOutput,
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
+
         topk_weights, topk_ids, _ = topk_output
         if _is_hip:
             topk_weights = topk_weights.to(
                 torch.float32
             )  # aiter's moe_sorting requires topk_weights to be FP32
-        return fused_moe(
+        output = fused_moe(
             x,
             layer.w13_weight,
             layer.w2_weight,

@@ -831,8 +854,9 @@ class Mxfp4DynamicQuantMoEMethod(FusedMoEMethodBase):
             w2_scale=layer.w2_weight_scale,
             activation=(
                 ActivationType.Silu
-                if moe_runner_config.activation == "silu"
+                if self.moe_runner_config.activation == "silu"
                 else ActivationType.Gelu
             ),
             doweight_stage1=False,
         )
+        return StandardCombineInput(hidden_states=output)
@@ -10,8 +10,17 @@ from aiter import ActivationType, QuantType, biased_grouped_topk
 from aiter.fused_moe import fused_moe
 from aiter.utility.fp4_utils import e8m0_shuffle

+from sglang.srt.layers.moe import MoeRunnerConfig
+from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase
 from sglang.srt.utils import get_bool_env_var, mxfp_supported, set_weight_attrs

+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.token_dispatcher import (
+        CombineInput,
+        StandardDispatchOutput,
+    )
+    from sglang.srt.layers.quantization.quark.quark import QuarkConfig
+
 logger = logging.getLogger(__name__)

 __all__ = ["QuarkMoEMethod", "QuarkW4A4MXFp4MoEMethod"]

@@ -19,31 +28,17 @@ __all__ = ["QuarkMoEMethod", "QuarkW4A4MXFp4MoEMethod"]
 OCP_MX_BLOCK_SIZE = 32

 if TYPE_CHECKING:
-    from sglang.srt.layers.moe.topk import TopKOutput
+    from sglang.srt.layers.quantization import QuarkConfig


-class QuarkMoEMethod:
-    def __new__(cls, *args, **kwargs):
-        from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase
-
-        if not hasattr(cls, "_initialized"):
-            original_init = cls.__init__
-            new_cls = type(
-                cls.__name__,
-                (FusedMoEMethodBase,),
-                {
-                    "__init__": original_init,
-                    **{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
-                },
-            )
-            obj = super(new_cls, new_cls).__new__(new_cls)
-            obj.__init__(*args, **kwargs)
-            return obj
-        return super().__new__(cls)
+class QuarkMoEMethod(FusedMoEMethodBase):
+
+    def __init__(self, quant_config: QuarkConfig):
+        self.quant_config = quant_config

     @staticmethod
     def get_moe_method(
-        quant_config: "QuarkConfig",  # type: ignore # noqa E501 # noqa F821
+        quant_config: QuarkConfig,  # type: ignore # noqa E501 # noqa F821
         module: torch.nn.Module,
         layer_name: str,
     ) -> "QuarkMoEMethod":

@@ -170,16 +165,25 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
         # layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale, requires_grad=False)
         layer.w2_weight_scale.data = w2_weight_scale.view(s0, s1, -1)

+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+
     def apply(
         self,
         layer: torch.nn.Module,
-        x: torch.Tensor,
-        topk_output: TopKOutput,
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
+        moe_runner_config = self.moe_runner_config
+
         topk_weights, topk_ids, _ = topk_output
-        return fused_moe(
+        output = fused_moe(
             x,
             layer.w13_weight,
             layer.w2_weight,

@@ -195,3 +199,4 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
             ),
             doweight_stage1=False,
         )
+        return StandardCombineInput(hidden_states=output)
@@ -9,6 +9,8 @@ from torch.nn.parameter import Parameter

 from sglang.srt.custom_op import CustomOp
 from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
+from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
+from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
 from sglang.srt.layers.quantization.base_config import (
     FusedMoEMethodBase,
     LinearMethodBase,

@@ -24,8 +26,10 @@ from sglang.srt.utils import (
 )

 if TYPE_CHECKING:
-    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
-    from sglang.srt.layers.moe.topk import TopKOutput
+    from sglang.srt.layers.moe.token_dispatcher import (
+        CombineInput,
+        StandardDispatchOutput,
+    )

 has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None

@@ -155,7 +159,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         layer: torch.nn.Module,
         num_experts: int,
         hidden_size: int,
-        intermediate_size: int,
+        intermediate_size_per_partition: int,
         params_dtype: torch.dtype,
         with_bias: bool = False,
         **extra_weight_attrs,

@@ -163,7 +167,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         self.with_bias = with_bias

         # Fused gate_up_proj (column parallel)
-        w13_weight_n, w13_weight_k = 2 * intermediate_size, hidden_size
+        w13_weight_n, w13_weight_k = 2 * intermediate_size_per_partition, hidden_size
         if self.use_triton_kernels:
             w13_weight_n, w13_weight_k = w13_weight_k, w13_weight_n
         w13_weight = torch.nn.Parameter(

@@ -175,7 +179,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         if self.with_bias:
             w13_weight_bias = torch.nn.Parameter(
-                torch.empty(num_experts, 2 * intermediate_size, dtype=torch.float32),
+                torch.empty(
+                    num_experts,
+                    2 * intermediate_size_per_partition,
+                    dtype=torch.float32,
+                ),
                 requires_grad=False,
             )
             layer.register_parameter("w13_weight_bias", w13_weight_bias)

@@ -184,7 +192,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         # down_proj (row parallel)
         w2_weight_n, w2_weight_k = (
             hidden_size,
-            intermediate_size,
+            intermediate_size_per_partition,
         )
         if self.use_triton_kernels:
             w2_weight_n, w2_weight_k = w2_weight_k, w2_weight_n

@@ -222,33 +230,40 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             return

+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
+
     def apply(
         self,
         layer: torch.nn.Module,
-        x: torch.Tensor,
-        topk_output: TopKOutput,
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
         return self.forward(
-            x=x,
             layer=layer,
-            topk_output=topk_output,
-            moe_runner_config=moe_runner_config,
+            dispatch_output=dispatch_output,
         )

     def forward_cuda(
         self,
         layer: torch.nn.Module,
-        x: torch.Tensor,
-        topk_output: TopKOutput,
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
+        moe_runner_config = self.moe_runner_config
+
         if self.use_triton_kernels:
             if self.with_bias:
                 assert self.triton_kernel_moe_with_bias_forward is not None
-                return self.triton_kernel_moe_with_bias_forward(
+                output = self.triton_kernel_moe_with_bias_forward(
                     hidden_states=x,
                     w1=layer.w13_weight,
                     w2=layer.w2_weight,

@@ -261,13 +276,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                 )
             else:
                 assert self.triton_kernel_moe_forward is not None
-                return self.triton_kernel_moe_forward(
+                output = self.triton_kernel_moe_forward(
                     hidden_states=x,
                     w1=layer.w13_weight,
                     w2=layer.w2_weight,
                     topk_output=topk_output,
                     moe_runner_config=moe_runner_config,
                 )
+            return StandardCombineInput(hidden_states=output)
         else:
             if _use_aiter:
                 assert not moe_runner_config.no_combine, "unsupported"

@@ -284,7 +300,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                     topk_weights = torch.ones_like(
                         topk_weights, dtype=torch.float32
                     )  # topk_weights must be FP32 (float32)
-                return fused_moe(
+                output = fused_moe(
                     x,
                     layer.w13_weight,
                     layer.w2_weight,

@@ -296,28 +312,30 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                         else ActivationType.Gelu
                     ),
                 )
+                return StandardCombineInput(hidden_states=output)
             else:
-                from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
-                    fused_experts,
-                )
-
-                return fused_experts(
-                    hidden_states=x,
-                    w1=layer.w13_weight,
-                    w2=layer.w2_weight,
-                    b1=getattr(layer, "w13_weight_bias", None),
+                quant_info = TritonMoeQuantInfo(
+                    w13_weight=layer.w13_weight,
+                    w2_weight=layer.w2_weight,
+                    b13=getattr(layer, "w13_weight_bias", None),
                     b2=getattr(layer, "w2_weight_bias", None),
-                    topk_output=topk_output,
-                    moe_runner_config=moe_runner_config,
                 )
+                return self.runner.run(dispatch_output, quant_info)

     def forward_cpu(
         self,
         layer: torch.nn.Module,
-        x: torch.Tensor,
-        topk_output: TopKOutput,
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
+        moe_runner_config = self.moe_runner_config
+
         assert (
             moe_runner_config.activation == "silu"
         ), f"activation = {moe_runner_config.activation} is not supported."

@@ -332,7 +350,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             x, topk_weights = apply_topk_weights_cpu(
                 moe_runner_config.apply_router_weight_on_input, topk_weights, x
             )
-            return torch.ops.sgl_kernel.fused_experts_cpu(
+            output = torch.ops.sgl_kernel.fused_experts_cpu(
                 x,
                 layer.w13_weight,
                 layer.w2_weight,

@@ -348,33 +366,39 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                 None,  # a2_scale
                 True,  # is_vnni
             )
+            return StandardCombineInput(hidden_states=output)
         else:
             from sglang.srt.layers.moe.fused_moe_native import moe_forward_native

-            return moe_forward_native(
+            output = moe_forward_native(
                 layer,
                 x,
                 topk_output,
                 moe_runner_config,
             )
+            return StandardCombineInput(hidden_states=output)

     def forward_npu(
         self,
         layer: torch.nn.Module,
-        x: torch.Tensor,
-        topk_output: TopKOutput,
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+
         from sglang.srt.layers.moe.fused_moe_native import moe_forward_native
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output

-        return moe_forward_native(
+        output = moe_forward_native(
             layer,
             x,
             topk_output,
-            moe_runner_config,
+            self.moe_runner_config,
         )
+        return StandardCombineInput(hidden_states=output)

-    def forward_tpu(self, *args, **kwargs) -> torch.Tensor:
+    def forward_tpu(self, *args, **kwargs) -> CombineInput:
         raise NotImplementedError("The TPU backend currently does not support MoE.")

     forward_native = forward_cpu
@@ -9,6 +9,7 @@ from torch.nn.parameter import Parameter

 from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
 from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
+from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
 from sglang.srt.layers.quantization.base_config import (
     FusedMoEMethodBase,
     QuantizationConfig,

@@ -22,7 +23,10 @@ from sglang.srt.utils import set_weight_attrs
 if TYPE_CHECKING:
     from sglang.srt.layers.moe import MoeRunnerConfig
     from sglang.srt.layers.moe.ep_moe.layer import EPMoE
-    from sglang.srt.layers.moe.topk import StandardTopKOutput
+    from sglang.srt.layers.moe.token_dispatcher import (
+        CombineInput,
+        StandardDispatchOutput,
+    )

 ACTIVATION_SCHEMES = ["static", "dynamic"]

@@ -133,7 +137,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
         layer: EPMoE,
         num_experts: int,
         hidden_size: int,
-        intermediate_size: int,
+        intermediate_size_per_partition: int,
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):

@@ -145,7 +149,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
         w13_weight = torch.nn.Parameter(
             torch.empty(
                 num_experts,
-                intermediate_size * 2,
+                intermediate_size_per_partition * 2,
                 hidden_size // 2,
                 dtype=torch.int8,
             ),

@@ -159,7 +163,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
             torch.empty(
                 num_experts,
                 hidden_size,
-                intermediate_size // 2,
+                intermediate_size_per_partition // 2,
                 dtype=torch.int8,
             ),
             requires_grad=False,

@@ -173,7 +177,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
         w13_weight_scale = torch.nn.Parameter(
             torch.zeros(
                 num_experts,
-                2 * intermediate_size,
+                2 * intermediate_size_per_partition,
                 hidden_size // self.quant_config.group_size,
                 dtype=torch.float32,
             ),

@@ -186,7 +190,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
             torch.zeros(
                 num_experts,
                 hidden_size,
-                intermediate_size // self.quant_config.group_size,
+                intermediate_size_per_partition // self.quant_config.group_size,
                 dtype=torch.float32,
             ),
             requires_grad=False,

@@ -220,13 +224,13 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
         )
         self.c_strides1 = torch.full(
             (num_experts, 3),
-            2 * intermediate_size,
+            2 * intermediate_size_per_partition,
             device=device,
             dtype=torch.int64,
         )
         self.a_strides2 = torch.full(
             (num_experts, 3),
-            intermediate_size,
+            intermediate_size_per_partition,
             device=device,
             dtype=torch.int64,
         )

@@ -282,16 +286,21 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
         )
         layer.w2_input_scale = Parameter(new_w2_input_scale, requires_grad=False)

+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+
     def apply(
         self,
         layer: EPMoE,
-        x: torch.Tensor,
-        topk_output: StandardTopKOutput,
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
-        # TODO(ch-wan): move it out of this class
-        from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output

         topk_weights, topk_ids, _ = topk_output
         local_topk_ids = topk_ids

@@ -328,6 +337,6 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
             layer.w13_input_scale,
             layer.w2_input_scale,
         )
-        if moe_runner_config.routed_scaling_factor is not None:
-            output *= moe_runner_config.routed_scaling_factor
-        return output
+        if self.moe_runner_config.routed_scaling_factor is not None:
+            output *= self.moe_runner_config.routed_scaling_factor
+        return StandardCombineInput(hidden_states=output)
@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional
 import torch
 from torch.nn.parameter import Parameter

+from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
 from sglang.srt.layers.parameter import ChannelQuantScaleParameter, ModelWeightParameter
 from sglang.srt.layers.quantization.base_config import (
     FusedMoEMethodBase,

@@ -26,8 +27,11 @@ from sglang.srt.layers.quantization.fp8_utils import (
 from sglang.srt.utils import set_weight_attrs

 if TYPE_CHECKING:
-    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
-    from sglang.srt.layers.moe.topk import StandardTopKOutput
+    from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
+    from sglang.srt.layers.moe.token_dispatcher import (
+        CombineInput,
+        StandardDispatchOutput,
+    )

 _is_fp8_fnuz = is_fp8_fnuz()

@@ -209,7 +213,7 @@ class W8A8FP8MoEMethod(FusedMoEMethodBase):
         layer: torch.nn.Module,
         num_experts: int,
         hidden_size: int,
-        intermediate_size: int,
+        intermediate_size_per_partition: int,
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):

@@ -218,7 +222,10 @@ class W8A8FP8MoEMethod(FusedMoEMethodBase):
         # WEIGHTS
         w13_weight = torch.nn.Parameter(
             torch.empty(
-                num_experts, 2 * intermediate_size, hidden_size, dtype=fp8_dtype
+                num_experts,
+                2 * intermediate_size_per_partition,
+                hidden_size,
+                dtype=fp8_dtype,
             ),
             requires_grad=False,
         )

@@ -226,14 +233,21 @@ class W8A8FP8MoEMethod(FusedMoEMethodBase):
         set_weight_attrs(w13_weight, extra_weight_attrs)

         w2_weight = torch.nn.Parameter(
-            torch.empty(num_experts, hidden_size, intermediate_size, dtype=fp8_dtype),
+            torch.empty(
+                num_experts,
+                hidden_size,
+                intermediate_size_per_partition,
+                dtype=fp8_dtype,
+            ),
             requires_grad=False,
         )
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)

         w13_weight_scale = torch.nn.Parameter(
-            torch.ones(num_experts, 2 * intermediate_size, 1, dtype=torch.float32),
+            torch.ones(
+                num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32
+            ),
             requires_grad=False,
         )
         w2_weight_scale = torch.nn.Parameter(

@@ -266,25 +280,26 @@ class W8A8FP8MoEMethod(FusedMoEMethodBase):
             layer.w2_weight_scale.data, requires_grad=False
         )

+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
+
     def apply(
         self,
         layer: torch.nn.Module,
-        x: torch.Tensor,
-        topk_output: StandardTopKOutput,
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
-        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
-
-        return fused_experts(
-            x,
-            layer.w13_weight,
-            layer.w2_weight,
-            topk_output=topk_output,
-            moe_runner_config=moe_runner_config,
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+
+        quant_info = TritonMoeQuantInfo(
+            w13_weight=layer.w13_weight,
+            w2_weight=layer.w2_weight,
             use_fp8_w8a8=True,
             per_channel_quant=True,
-            w1_scale=(layer.w13_weight_scale),
-            w2_scale=(layer.w2_weight_scale),
-            a1_scale=layer.w13_input_scale,
+            w13_scale=layer.w13_weight_scale,
+            w2_scale=layer.w2_weight_scale,
+            a13_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
         )
+        return self.runner.run(dispatch_output, quant_info)
...@@ -24,6 +24,8 @@ from sglang.srt.distributed import ( ...@@ -24,6 +24,8 @@ from sglang.srt.distributed import (
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
) )
from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
from sglang.srt.layers.parameter import ( from sglang.srt.layers.parameter import (
ChannelQuantScaleParameter, ChannelQuantScaleParameter,
ModelWeightParameter, ModelWeightParameter,
...@@ -49,8 +51,10 @@ from sglang.srt.utils import ( ...@@ -49,8 +51,10 @@ from sglang.srt.utils import (
) )
if TYPE_CHECKING: if TYPE_CHECKING:
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig from sglang.srt.layers.moe.token_dispatcher import (
from sglang.srt.layers.moe.topk import TopKOutput CombineInput,
StandardDispatchOutput,
)
_is_cuda = is_cuda() _is_cuda = is_cuda()
_is_cpu_amx_available = cpu_has_amx_support() _is_cpu_amx_available = cpu_has_amx_support()
...@@ -417,7 +421,7 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase): ...@@ -417,7 +421,7 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
layer: torch.nn.Module, layer: torch.nn.Module,
num_experts: int, num_experts: int,
hidden_size: int, hidden_size: int,
intermediate_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, params_dtype: torch.dtype,
**extra_weight_attrs, **extra_weight_attrs,
): ):
...@@ -428,7 +432,10 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase): ...@@ -428,7 +432,10 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
# WEIGHTS # WEIGHTS
w13_weight = torch.nn.Parameter( w13_weight = torch.nn.Parameter(
torch.empty( torch.empty(
num_experts, 2 * intermediate_size, hidden_size, dtype=torch.int8 num_experts,
2 * intermediate_size_per_partition,
hidden_size,
dtype=torch.int8,
), ),
requires_grad=False, requires_grad=False,
) )
...@@ -436,14 +443,21 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase): ...@@ -436,14 +443,21 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
set_weight_attrs(w13_weight, extra_weight_attrs) set_weight_attrs(w13_weight, extra_weight_attrs)
w2_weight = torch.nn.Parameter( w2_weight = torch.nn.Parameter(
torch.empty(num_experts, hidden_size, intermediate_size, dtype=torch.int8), torch.empty(
num_experts,
hidden_size,
intermediate_size_per_partition,
dtype=torch.int8,
),
requires_grad=False, requires_grad=False,
) )
layer.register_parameter("w2_weight", w2_weight) layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs) set_weight_attrs(w2_weight, extra_weight_attrs)
w13_weight_scale = torch.nn.Parameter( w13_weight_scale = torch.nn.Parameter(
torch.ones(num_experts, 2 * intermediate_size, 1, dtype=torch.float32), torch.ones(
num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32
),
requires_grad=False, requires_grad=False,
) )
w2_weight_scale = torch.nn.Parameter( w2_weight_scale = torch.nn.Parameter(
...@@ -483,23 +497,30 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase): ...@@ -483,23 +497,30 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
layer.w2_weight_scale.data, requires_grad=False layer.w2_weight_scale.data, requires_grad=False
) )
def create_moe_runner(
self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
):
self.moe_runner_config = moe_runner_config
self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
def apply( def apply(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, dispatch_output: StandardDispatchOutput,
topk_output: TopKOutput,
moe_runner_config: MoeRunnerConfig,
) -> torch.Tensor: ) -> CombineInput:
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
x = dispatch_output.hidden_states
topk_output = dispatch_output.topk_output
if use_intel_amx_backend(layer): if use_intel_amx_backend(layer):
from sglang.srt.layers.moe.topk import apply_topk_weights_cpu from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
topk_weights, topk_ids, _ = topk_output topk_weights, topk_ids, _ = topk_output
x, topk_weights = apply_topk_weights_cpu( x, topk_weights = apply_topk_weights_cpu(
moe_runner_config.apply_router_weight_on_input, topk_weights, x self.moe_runner_config.apply_router_weight_on_input, topk_weights, x
) )
return torch.ops.sgl_kernel.fused_experts_cpu( output = torch.ops.sgl_kernel.fused_experts_cpu(
x, x,
layer.w13_weight, layer.w13_weight,
layer.w2_weight, layer.w2_weight,
...@@ -515,20 +536,19 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase): ...@@ -515,20 +536,19 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
layer.w2_input_scale, # a2_scale layer.w2_input_scale, # a2_scale
True, # is_vnni True, # is_vnni
) )
return StandardCombineInput(hidden_states=output)
return fused_experts( quant_info = TritonMoeQuantInfo(
x, w13_weight=layer.w13_weight,
layer.w13_weight, w2_weight=layer.w2_weight,
layer.w2_weight,
topk_output=topk_output,
moe_runner_config=moe_runner_config,
use_int8_w8a8=True, use_int8_w8a8=True,
per_channel_quant=True, per_channel_quant=True,
w1_scale=(layer.w13_weight_scale), w13_scale=layer.w13_weight_scale,
w2_scale=(layer.w2_weight_scale), w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale, a13_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale, a2_scale=layer.w2_input_scale,
) )
return self.runner.run(dispatch_output, quant_info)
class NPU_W8A8LinearMethodImpl: class NPU_W8A8LinearMethodImpl:
...@@ -900,7 +920,7 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase): ...@@ -900,7 +920,7 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase):
layer: torch.nn.Module, layer: torch.nn.Module,
num_experts: int, num_experts: int,
hidden_size: int, hidden_size: int,
intermediate_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, params_dtype: torch.dtype,
**extra_weight_attrs, **extra_weight_attrs,
) -> None: ) -> None:
...@@ -914,21 +934,31 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase): ...@@ -914,21 +934,31 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase):
# weight # weight
w13_weight = torch.nn.Parameter( w13_weight = torch.nn.Parameter(
torch.empty( torch.empty(
num_experts, 2 * intermediate_size, hidden_size, dtype=torch.int8 num_experts,
2 * intermediate_size_per_partition,
hidden_size,
dtype=torch.int8,
), ),
requires_grad=False, requires_grad=False,
) )
layer.register_parameter("w13_weight", w13_weight) layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs) set_weight_attrs(w13_weight, extra_weight_attrs)
w2_weight = torch.nn.Parameter( w2_weight = torch.nn.Parameter(
torch.empty(num_experts, hidden_size, intermediate_size, dtype=torch.int8), torch.empty(
num_experts,
hidden_size,
intermediate_size_per_partition,
dtype=torch.int8,
),
requires_grad=False, requires_grad=False,
) )
layer.register_parameter("w2_weight", w2_weight) layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs) set_weight_attrs(w2_weight, extra_weight_attrs)
# scale # scale
w13_weight_scale = torch.nn.Parameter( w13_weight_scale = torch.nn.Parameter(
torch.empty(num_experts, 2 * intermediate_size, 1, dtype=torch.float32), torch.empty(
num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32
),
requires_grad=False, requires_grad=False,
) )
layer.register_parameter("w13_weight_scale", w13_weight_scale) layer.register_parameter("w13_weight_scale", w13_weight_scale)
...@@ -941,7 +971,9 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase): ...@@ -941,7 +971,9 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase):
set_weight_attrs(w2_weight_scale, extra_weight_attrs) set_weight_attrs(w2_weight_scale, extra_weight_attrs)
# offset # offset
w13_weight_offset = torch.nn.Parameter( w13_weight_offset = torch.nn.Parameter(
torch.empty(num_experts, 2 * intermediate_size, 1, dtype=torch.float32), torch.empty(
num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32
),
requires_grad=False, requires_grad=False,
) )
layer.register_parameter("w13_weight_offset", w13_weight_offset) layer.register_parameter("w13_weight_offset", w13_weight_offset)
...@@ -973,18 +1005,25 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase): ...@@ -973,18 +1005,25 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase):
layer.w2_weight_offset.data.squeeze(-1).contiguous(), requires_grad=False layer.w2_weight_offset.data.squeeze(-1).contiguous(), requires_grad=False
) )
def create_moe_runner(
self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
):
self.moe_runner_config = moe_runner_config
def apply( def apply(
self, self,
layer, layer,
x, dispatch_output: StandardDispatchOutput,
topk_output: TopKOutput, ) -> CombineInput:
moe_runner_config: MoeRunnerConfig, from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
) -> torch.Tensor:
x = dispatch_output.hidden_states
topk_output = dispatch_output.topk_output
topk_weights, topk_ids, _ = topk_output topk_weights, topk_ids, _ = topk_output
topk_ids = topk_ids.to(torch.int32) topk_ids = topk_ids.to(torch.int32)
topk_weights = topk_weights.to(x.dtype) topk_weights = topk_weights.to(x.dtype)
return npu_fused_experts( output = npu_fused_experts(
hidden_states=x, hidden_states=x,
w13=layer.w13_weight, w13=layer.w13_weight,
w13_scale=layer.w13_weight_scale, w13_scale=layer.w13_weight_scale,
...@@ -994,3 +1033,4 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase): ...@@ -994,3 +1033,4 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase):
topk_ids=topk_ids, topk_ids=topk_ids,
top_k=topk_ids.shape[1], top_k=topk_ids.shape[1],
) )
return StandardCombineInput(hidden_states=output)
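As with the CPU-AMX branch earlier in this diff, the NPU method bypasses the shared `MoeRunner`: it unpacks the dispatch output, calls a device-specific fused kernel, and wraps the kernel output in `StandardCombineInput` so the dispatcher-facing interface stays uniform across backends. A compressed sketch of that pattern follows; `device_fused_experts` is a placeholder standing in for `npu_fused_experts` or `torch.ops.sgl_kernel.fused_experts_cpu`, not a real kernel.

```python
# Sketch of the "direct kernel" variant of apply(); the placeholder kernel
# below stands in for npu_fused_experts / fused_experts_cpu from this diff.
import torch

from sglang.srt.layers.moe.token_dispatcher import (
    CombineInput,
    StandardCombineInput,
    StandardDispatchOutput,
)


def device_fused_experts(x, w13, w2, topk_weights, topk_ids):
    # Placeholder for a device-specific fused MoE kernel.
    raise NotImplementedError


class ExampleDeviceMoEMethod:
    def apply(
        self, layer: torch.nn.Module, dispatch_output: StandardDispatchOutput
    ) -> CombineInput:
        x = dispatch_output.hidden_states
        topk_weights, topk_ids, _ = dispatch_output.topk_output
        output = device_fused_experts(
            x, layer.w13_weight, layer.w2_weight, topk_weights, topk_ids
        )
        # Wrap the raw tensor so the dispatcher's combine step sees the same
        # CombineInput type regardless of backend.
        return StandardCombineInput(hidden_states=output)
```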
...@@ -52,7 +52,6 @@ from sglang.srt.disaggregation.decode_schedule_batch_mixin import ( ...@@ -52,7 +52,6 @@ from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
ScheduleBatchDisaggregationDecodeMixin, ScheduleBatchDisaggregationDecodeMixin,
) )
from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
from sglang.srt.layers.moe import is_tbo_enabled
from sglang.srt.mem_cache.allocator import ( from sglang.srt.mem_cache.allocator import (
BaseTokenToKVPoolAllocator, BaseTokenToKVPoolAllocator,
SWATokenToKVPoolAllocator, SWATokenToKVPoolAllocator,
......
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/model_loader/__init__.py # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/model_loader/__init__.py
from __future__ import annotations
from typing import TYPE_CHECKING
from torch import nn from torch import nn
from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.model_loader.loader import BaseModelLoader, get_model_loader from sglang.srt.model_loader.loader import BaseModelLoader, get_model_loader
from sglang.srt.model_loader.utils import ( from sglang.srt.model_loader.utils import (
get_architecture_class_name, get_architecture_class_name,
get_model_architecture, get_model_architecture,
) )
if TYPE_CHECKING:
from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.configs.model_config import ModelConfig
def get_model( def get_model(
*, *,
......
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.3.post1/vllm/model_executor/model_loader/loader.py # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.3.post1/vllm/model_executor/model_loader/loader.py
from __future__ import annotations
# ruff: noqa: SIM117 # ruff: noqa: SIM117
import collections import collections
import concurrent import concurrent
...@@ -14,7 +16,17 @@ import time ...@@ -14,7 +16,17 @@ import time
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager from contextlib import contextmanager
from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, cast from typing import (
TYPE_CHECKING,
Any,
Dict,
Generator,
Iterable,
List,
Optional,
Tuple,
cast,
)
import huggingface_hub import huggingface_hub
import numpy as np import numpy as np
...@@ -26,9 +38,7 @@ from tqdm.auto import tqdm ...@@ -26,9 +38,7 @@ from tqdm.auto import tqdm
from transformers import AutoModelForCausalLM from transformers import AutoModelForCausalLM
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.load_config import LoadConfig, LoadFormat from sglang.srt.configs.load_config import LoadConfig, LoadFormat
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.connector import ( from sglang.srt.connector import (
ConnectorType, ConnectorType,
create_remote_connector, create_remote_connector,
...@@ -39,7 +49,6 @@ from sglang.srt.distributed import ( ...@@ -39,7 +49,6 @@ from sglang.srt.distributed import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
) )
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.model_loader.utils import ( from sglang.srt.model_loader.utils import (
get_model_architecture, get_model_architecture,
post_load_weights, post_load_weights,
...@@ -70,6 +79,11 @@ from sglang.srt.utils import ( ...@@ -70,6 +79,11 @@ from sglang.srt.utils import (
set_weight_attrs, set_weight_attrs,
) )
if TYPE_CHECKING:
from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.layers.quantization.base_config import QuantizationConfig
_is_npu = is_npu() _is_npu = is_npu()
......
...@@ -9,6 +9,7 @@ from transformers import AutoConfig ...@@ -9,6 +9,7 @@ from transformers import AutoConfig
from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts_fp8 from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts_fp8
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
from sglang.srt.layers.moe.moe_runner.base import MoeRunnerConfig from sglang.srt.layers.moe.moe_runner.base import MoeRunnerConfig
from sglang.srt.layers.moe.topk import StandardTopKOutput
# Copy from: https://github.com/deepseek-ai/DeepGEMM/blob/main/deep_gemm/utils.py # Copy from: https://github.com/deepseek-ai/DeepGEMM/blob/main/deep_gemm/utils.py
...@@ -152,14 +153,32 @@ def run_test(tp_size, batch_size, model_config, check=False): ...@@ -152,14 +153,32 @@ def run_test(tp_size, batch_size, model_config, check=False):
problem_sizes2, problem_sizes2,
) )
topk_output = StandardTopKOutput(
topk_weights=topk_weights,
topk_ids=topk_ids,
router_logits=torch.randn(
(batch_size, topk), device=topk_weights.device, dtype=dtype
),
)
moe_runner_config = MoeRunnerConfig(
num_experts=E,
topk=topk,
hidden_size=H,
shard_intermediate_size=I,
dtype=dtype,
block_shape=block_shape,
activation="silu",
inplace=False,
)
# Note: Triton expects non-transposed weights # Note: Triton expects non-transposed weights
moe_config = MoeRunnerConfig(inplace=False)
triton_lambda = lambda: fused_experts( triton_lambda = lambda: fused_experts(
x, x,
w1, w1,
w2, w2,
(topk_weights, topk_ids, "dummy"), topk_output,
moe_config, moe_runner_config,
use_fp8_w8a8=True, use_fp8_w8a8=True,
w1_scale=w1_scale, w1_scale=w1_scale,
w2_scale=w2_scale, w2_scale=w2_scale,
...@@ -224,8 +243,8 @@ def run_test(tp_size, batch_size, model_config, check=False): ...@@ -224,8 +243,8 @@ def run_test(tp_size, batch_size, model_config, check=False):
x, x,
w1, # Original shape w1, # Original shape
w2, # Original shape w2, # Original shape
(topk_weights, topk_ids, "dummy"), topk_output,
moe_config, moe_runner_config,
use_fp8_w8a8=True, use_fp8_w8a8=True,
w1_scale=w1_scale, w1_scale=w1_scale,
w2_scale=w2_scale, w2_scale=w2_scale,
......
import os
import unittest import unittest
from types import SimpleNamespace from types import SimpleNamespace
...@@ -49,6 +50,42 @@ class TestMLADeepseekV3(CustomTestCase): ...@@ -49,6 +50,42 @@ class TestMLADeepseekV3(CustomTestCase):
self.assertGreater(metrics["accuracy"], 0.62) self.assertGreater(metrics["accuracy"], 0.62)
class TestMLADeepseekV3DisableFusedFunc(CustomTestCase):
@classmethod
def setUpClass(cls):
os.environ["SGLANG_CI_DISABLE_MOE_FUSED_FUNC"] = "1"
cls.model = "lmsys/sglang-ci-dsv3-test"
cls.base_url = DEFAULT_URL_FOR_TEST
other_args = ["--trust-remote-code", "--chunked-prefill-size", "256"]
if is_cuda():
other_args.extend(["--cuda-graph-max-bs", "2"])
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=other_args,
)
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def test_gsm8k(self):
args = SimpleNamespace(
num_shots=5,
data_path=None,
num_questions=200,
max_new_tokens=512,
parallel=128,
host="http://127.0.0.1",
port=int(self.base_url.split(":")[-1]),
)
metrics = run_eval_few_shot_gsm8k(args)
print(metrics)
self.assertGreater(metrics["accuracy"], 0.62)
@unittest.skipIf(is_hip(), "FA is not available.") @unittest.skipIf(is_hip(), "FA is not available.")
class TestMLADeepseekV3Fa3Fp8Kvcache(CustomTestCase): class TestMLADeepseekV3Fa3Fp8Kvcache(CustomTestCase):
@classmethod @classmethod
......