Unverified Commit 3fa62da7 authored by Cheng Wan, committed by GitHub

[7/N] MoE Refactor: the implementation of new framework (#9269)
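
The recurring change across the quantization backends in this diff: each FusedMoEMethodBase subclass now exposes a `create_moe_runner(layer, moe_runner_config)` hook, and `apply()` consumes a `StandardDispatchOutput` and returns a `CombineInput` instead of taking `(x, topk_output, moe_runner_config)` and returning a tensor. Below is a minimal, self-contained sketch of that contract; the `Toy*` dataclasses are simplified stand-ins for the real sglang types (`MoeRunnerConfig`, `StandardDispatchOutput`, `StandardCombineInput`), not the actual implementations.

```python
from dataclasses import dataclass
from typing import Tuple

import torch


@dataclass
class ToyRunnerConfig:
    """Stand-in for sglang's MoeRunnerConfig."""
    activation: str = "silu"


@dataclass
class ToyDispatchOutput:
    """Stand-in for StandardDispatchOutput: hidden states plus top-k routing."""
    hidden_states: torch.Tensor
    topk_output: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]


@dataclass
class ToyCombineInput:
    """Stand-in for StandardCombineInput."""
    hidden_states: torch.Tensor


class ToyMoEMethod:
    """Sketch of the refactored FusedMoEMethodBase contract."""

    def create_moe_runner(
        self, layer: torch.nn.Module, moe_runner_config: ToyRunnerConfig
    ) -> None:
        # The config is bound once here; Triton-backed methods also build their
        # MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config) at this point.
        self.moe_runner_config = moe_runner_config

    def apply(
        self, layer: torch.nn.Module, dispatch_output: ToyDispatchOutput
    ) -> ToyCombineInput:
        # apply() no longer receives (x, topk_output, moe_runner_config) directly.
        x = dispatch_output.hidden_states
        topk_weights, topk_ids, _ = dispatch_output.topk_output
        assert (
            self.moe_runner_config.activation == "silu"
        ), "Only SiLU activation is supported."
        # Placeholder expert computation; real methods invoke their fused kernels here.
        output = x * topk_weights.sum(dim=-1, keepdim=True).to(x.dtype)
        return ToyCombineInput(hidden_states=output)
```

Binding the config once in `create_moe_runner` is what lets `apply()` drop the per-call `moe_runner_config` argument; Triton-backed methods then forward `self.runner.run(dispatch_output, TritonMoeQuantInfo(...))`, which is what most of the per-backend hunks below amount to.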

parent dbb1235d
......@@ -45,7 +45,10 @@ from sglang.srt.layers.quantization.utils import (
if TYPE_CHECKING:
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import TopKOutput
from sglang.srt.layers.moe.token_dispatcher import (
StandardDispatchOutput,
CombineInput,
)
from sglang.srt.utils import is_cuda
......@@ -838,19 +841,14 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
from sglang.srt.layers.linear import set_weight_attrs
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
intermediate_size = extra_weight_attrs.pop("intermediate_size")
self.is_k_full = (not self.quant_config.desc_act) or (
intermediate_size_per_partition == intermediate_size
)
self.is_k_full = (not self.quant_config.desc_act) or layer.moe_tp_size == 1
if self.quant_config.group_size != -1:
scales_size13 = hidden_size // self.quant_config.group_size
w2_scales_size = (
intermediate_size
if self.quant_config.desc_act
else intermediate_size_per_partition
)
if self.quant_config.desc_act:
w2_scales_size = intermediate_size_per_partition
else:
w2_scales_size = intermediate_size_per_partition * layer.moe_tp_size
scales_size2 = w2_scales_size // self.quant_config.group_size
strategy = FusedMoeWeightScaleSupported.GROUP.value
else:
......@@ -1052,17 +1050,26 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
)
replace_parameter(layer, "w2_scales", marlin_w2_scales)
def create_moe_runner(
self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
):
self.moe_runner_config = moe_runner_config
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
topk_output: TopKOutput,
moe_runner_config: MoeRunnerConfig,
) -> torch.Tensor:
dispatch_output: StandardDispatchOutput,
) -> CombineInput:
from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
x = dispatch_output.hidden_states
topk_output = dispatch_output.topk_output
# Delay the import to avoid circular dependency
assert (
moe_runner_config.activation == "silu"
self.moe_runner_config.activation == "silu"
), "Only SiLU activation is supported."
# The input must currently be float16
......@@ -1071,7 +1078,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
topk_weights, topk_ids, router_logits = topk_output
return fused_marlin_moe(
output = fused_marlin_moe(
x,
layer.w13_qweight,
layer.w2_qweight,
......@@ -1087,3 +1094,4 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
num_bits=self.quant_config.weight_bits,
is_k_full=self.is_k_full,
).to(orig_dtype)
return StandardCombineInput(hidden_states=output)
......@@ -10,10 +10,14 @@ from torch.nn.parameter import Parameter
from sglang.srt.distributed import get_tp_group
from sglang.srt.layers.dp_attention import get_dp_global_num_tokens, get_local_dp_buffer
from sglang.srt.layers.moe import (
MoeRunner,
MoeRunnerBackend,
MoeRunnerConfig,
should_use_flashinfer_cutlass_moe_fp4_allgather,
should_use_flashinfer_trtllm_moe,
)
from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType
from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
from sglang.srt.layers.quantization.base_config import (
FusedMoEMethodBase,
......@@ -39,8 +43,10 @@ from sglang.srt.utils import is_cuda, next_power_of_2
if TYPE_CHECKING:
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import TopKOutput
from sglang.srt.layers.moe.token_dispatcher import (
CombineInput,
StandardDispatchOutput,
)
if is_cuda():
from sgl_kernel import scaled_fp4_quant
......@@ -322,7 +328,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
layer: torch.nn.Module,
num_experts: int,
hidden_size: int,
intermediate_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
......@@ -338,7 +344,10 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
w13_weight = ModelWeightParameter(
data=torch.empty(
num_experts, 2 * intermediate_size, hidden_size, dtype=weight_dtype
num_experts,
2 * intermediate_size_per_partition,
hidden_size,
dtype=weight_dtype,
),
input_dim=2,
output_dim=1,
......@@ -348,7 +357,10 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
w2_weight = ModelWeightParameter(
data=torch.empty(
num_experts, hidden_size, intermediate_size, dtype=weight_dtype
num_experts,
hidden_size,
intermediate_size_per_partition,
dtype=weight_dtype,
),
input_dim=2,
output_dim=1,
......@@ -414,28 +426,28 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
max_w13_scales = layer.w13_weight_scale.max(dim=1).values
# Requantize each expert's weights using the combined scale
# w13_weight has shape (num_experts, 2 * intermediate_size, hidden_size)
# where the first intermediate_size rows are w1, the next are w3
intermediate_size = layer.w13_weight.shape[1] // 2
# w13_weight has shape (num_experts, 2 * intermediate_size_per_partition, hidden_size)
# where the first intermediate_size_per_partition rows are w1, the next are w3
intermediate_size_per_partition = layer.w13_weight.shape[1] // 2
for expert_id in range(layer.w13_weight.shape[0]):
start = 0
for shard_id in range(2): # w1 and w3
# Dequantize using the original scale for this shard
dq_weight = per_tensor_dequantize(
layer.w13_weight[expert_id][
start : start + intermediate_size, :
start : start + intermediate_size_per_partition, :
],
layer.w13_weight_scale[expert_id][shard_id],
)
# Requantize using the combined max scale
(
layer.w13_weight[expert_id][
start : start + intermediate_size, :
start : start + intermediate_size_per_partition, :
],
_,
) = scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
start += intermediate_size
start += intermediate_size_per_partition
# Update the scale parameter to be per-expert instead of per-shard
layer.w13_weight_scale = Parameter(max_w13_scales, requires_grad=False)
......@@ -457,29 +469,31 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
layer.w2_input_scale.max(), requires_grad=False
)
def create_moe_runner(
self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
):
self.moe_runner_config = moe_runner_config
self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
topk_output: TopKOutput,
moe_runner_config: MoeRunnerConfig,
) -> torch.Tensor:
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
return fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_output=topk_output,
moe_runner_config=moe_runner_config,
dispatch_output: StandardDispatchOutput,
) -> CombineInput:
quant_info = TritonMoeQuantInfo(
w13_weight=layer.w13_weight,
w2_weight=layer.w2_weight,
use_fp8_w8a8=True,
per_channel_quant=False, # ModelOpt uses per-tensor quantization
w1_scale=layer.w13_weight_scale,
per_channel_quant=False,
w13_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a13_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
)
return self.runner.run(dispatch_output, quant_info)
class ModelOptFp4Config(QuantizationConfig):
"""Config class for FP4."""
......@@ -1278,21 +1292,32 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
# FlashInfer CUTLASS kernel assumes [Up, Gate] Proj as W13
return self.enable_flashinfer_cutlass_moe
def create_moe_runner(
self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
):
self.moe_runner_config = moe_runner_config
def apply(
self,
layer: FusedMoE,
x: torch.Tensor,
topk_output: TopKOutput,
moe_runner_config: MoeRunnerConfig,
) -> torch.Tensor:
dispatch_output: StandardDispatchOutput,
) -> CombineInput:
from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
x = dispatch_output.hidden_states
topk_output = dispatch_output.topk_output
assert (
moe_runner_config.activation == "silu"
self.moe_runner_config.activation == "silu"
), "Only SiLU activation is supported."
moe_runner_config = self.moe_runner_config
# Check if this is a FlashInferFP4MoE layer that should handle its own forward
if hasattr(layer, "gemm1_weights_fp4_shuffled"):
# This layer was processed with flashinfer TRTLLM - delegate to its own forward
return layer.forward(x, topk_output)
return StandardCombineInput(hidden_states=layer.forward(x, topk_output))
if self.enable_flashinfer_cutlass_moe:
assert (
......@@ -1345,13 +1370,12 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
tp_rank=layer.moe_tp_rank,
tune_max_num_tokens=next_power_of_2(x.shape[0]),
)[0]
# Scale by routed_scaling_factor is fused into select_experts.
if should_use_flashinfer_cutlass_moe_fp4_allgather():
output, global_output = get_local_dp_buffer(), output
get_tp_group().reduce_scatterv(
global_output, output=output, sizes=get_dp_global_num_tokens()
)
return output
return StandardCombineInput(hidden_states=output)
from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
......@@ -1372,4 +1396,5 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
apply_router_weight_on_input=moe_runner_config.apply_router_weight_on_input,
).to(x.dtype)
# Scale by routed_scaling_factor is fused into select_experts.
return output
return StandardCombineInput(hidden_states=output)
......@@ -9,6 +9,8 @@ import torch
from sglang.srt.distributed import get_tensor_model_parallel_rank
from sglang.srt.distributed.parallel_state import get_tp_group
from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
from sglang.srt.layers.quantization.awq import AWQConfig
from sglang.srt.layers.quantization.base_config import (
FusedMoEMethodBase,
......@@ -22,8 +24,10 @@ from sglang.srt.utils import get_device_capability, set_weight_attrs
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import TopKOutput
from sglang.srt.layers.moe.token_dispatcher import (
CombineInput,
StandardDispatchOutput,
)
def get_weight_perm(num_bits: int):
......@@ -349,37 +353,36 @@ class MoeWNA16Method(FusedMoEMethodBase):
layer.register_parameter(key, param)
set_weight_attrs(param, extra_weight_attrs)
def create_moe_runner(
self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
):
self.moe_runner_config = moe_runner_config
self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
topk_output: TopKOutput,
moe_runner_config: MoeRunnerConfig,
) -> torch.Tensor:
# avoid circular import
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
dispatch_output: StandardDispatchOutput,
) -> CombineInput:
assert (
moe_runner_config.activation == "silu"
self.moe_runner_config.activation == "silu"
), "Only SiLU activation is supported."
weight_bits = self.quant_config.weight_bits
has_zp = self.quant_config.has_zp
return fused_experts(
x,
layer.w13_qweight,
layer.w2_qweight,
topk_output=topk_output,
moe_runner_config=moe_runner_config,
quant_info = TritonMoeQuantInfo(
w13_weight=layer.w13_qweight,
w2_weight=layer.w2_qweight,
use_int4_w4a16=weight_bits == 4,
use_int8_w8a16=weight_bits == 8,
w1_scale=layer.w13_scales,
w13_scale=layer.w13_scales,
w2_scale=layer.w2_scales,
w1_zp=layer.w13_qzeros if has_zp else None,
w13_zp=layer.w13_qzeros if has_zp else None,
w2_zp=layer.w2_qzeros if has_zp else None,
block_shape=[0, layer.group_size],
)
return self.runner.run(dispatch_output, quant_info)
@staticmethod
def get_weight_loader(layer, weight_loader):
......
......@@ -22,6 +22,8 @@ from typing import TYPE_CHECKING, List, Optional
import torch
from torch.nn.parameter import Parameter
from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
from sglang.srt.layers.moe.utils import get_moe_runner_backend
from sglang.srt.layers.quantization.base_config import (
FusedMoEMethodBase,
......@@ -59,8 +61,10 @@ if is_flashinfer_available():
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import TopKOutput
from sglang.srt.layers.moe.token_dispatcher import (
CombineInput,
StandardDispatchOutput,
)
_is_hip = is_hip()
......@@ -283,7 +287,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
layer: torch.nn.Module,
num_experts: int,
hidden_size: int,
intermediate_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype,
with_bias: bool = False,
**extra_weight_attrs,
......@@ -296,26 +300,26 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
# pad the intermediate size to be a multiple of 2 * mxfp4_block
# to hold non-uniformly sharded tensors as well as swizzling
intermediate_size_per_partition_after_pad = intermediate_size
intermediate_size_per_partition_after_pad = intermediate_size_per_partition
if _is_sm100_supported:
if self.use_flashinfer:
intermediate_size_per_partition_after_pad = round_up(
intermediate_size, 256
intermediate_size_per_partition, 256
)
hidden_size = round_up(hidden_size, 256)
else:
intermediate_size_per_partition_after_pad = round_up(
intermediate_size, 64
intermediate_size_per_partition, 64
)
elif has_triton_kernels:
# TODO: this is a hack to make
# intermediate_size_per_partition_after_pad the same as the
# per_rank_intermediate_size during weight loading
intermediate_size_per_partition_after_pad = round_up(
intermediate_size, mxfp4_block
intermediate_size_per_partition, mxfp4_block
)
self.intermediate_size = intermediate_size_per_partition_after_pad
self.intermediate_size_per_partition = intermediate_size_per_partition_after_pad
self.hidden_size = hidden_size
# Fused gate_up_proj (column parallel)
......@@ -410,31 +414,35 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
assert (
layer.w13_weight.dim() == 3
and layer.w13_weight.shape[0] == self.num_experts
and layer.w13_weight.shape[1] == self.intermediate_size * 2
and layer.w13_weight.shape[1]
== self.intermediate_size_per_partition * 2
and layer.w13_weight.shape[2] == self.hidden_size // 2
)
assert (
layer.w13_weight_scale.dim() == 3
and layer.w13_weight_scale.shape[0] == self.num_experts
and layer.w13_weight_scale.shape[1] == self.intermediate_size * 2
and layer.w13_weight_scale.shape[1]
== self.intermediate_size_per_partition * 2
and layer.w13_weight_scale.shape[2] == self.hidden_size // sf_block_size
)
assert (
layer.w2_weight.dim() == 3
and layer.w2_weight.shape[0] == self.num_experts
and layer.w2_weight.shape[1] == self.hidden_size
and layer.w2_weight.shape[2] == self.intermediate_size // 2
and layer.w2_weight.shape[2]
== self.intermediate_size_per_partition // 2
)
assert (
layer.w2_weight_scale.dim() == 3
and layer.w2_weight_scale.shape[1] == self.hidden_size
and layer.w2_weight_scale.shape[2]
== self.intermediate_size // sf_block_size
== self.intermediate_size_per_partition // sf_block_size
)
assert (
layer.w13_weight_bias.dim() == 2
and layer.w13_weight_bias.shape[0] == self.num_experts
and layer.w13_weight_bias.shape[1] == self.intermediate_size * 2
and layer.w13_weight_bias.shape[1]
== self.intermediate_size_per_partition * 2
)
assert (
layer.w2_weight_bias.dim() == 2
......@@ -511,7 +519,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
torch.stack(gemm1_scales_mxfp4_shuffled)
.reshape(
self.num_experts,
2 * self.intermediate_size,
2 * self.intermediate_size_per_partition,
self.hidden_size // sf_block_size,
)
.view(torch.float8_e4m3fn)
......@@ -523,7 +531,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
.reshape(
self.num_experts,
self.hidden_size,
self.intermediate_size // sf_block_size,
self.intermediate_size_per_partition // sf_block_size,
)
.view(torch.float8_e4m3fn)
)
......@@ -613,16 +621,26 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
return tile_tokens_dim
def create_moe_runner(
self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
):
self.moe_runner_config = moe_runner_config
self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
topk_output: TopKOutput,
moe_runner_config: MoeRunnerConfig,
) -> torch.Tensor:
dispatch_output: StandardDispatchOutput,
) -> CombineInput:
from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
from sglang.srt.layers.moe.topk import TopKOutputChecker
x = dispatch_output.hidden_states
topk_output = dispatch_output.topk_output
moe_runner_config = self.moe_runner_config
if self.use_flashinfer:
# When bf16 mode is enabled, we don't need to quantize the input,
# TRT-LLM automatically handles quantization in the kernel implementation and pipelines it with GEMM operations,
......@@ -674,7 +692,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
top_k,
None, # n_group # TODO: support n_group
None, # topk_group # TODO: support topk_group
self.intermediate_size, # padded to multiple of 256
self.intermediate_size_per_partition, # padded to multiple of 256
layer.moe_ep_rank * layer.num_local_experts, # local_expert_offset
layer.num_local_experts, # local num experts
None,
......@@ -682,14 +700,14 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
1, # routing_method_type, renormalize
True, # do finalize
)[0]
return trtllm_gen_output
return StandardCombineInput(hidden_states=trtllm_gen_output)
if self.use_triton_kernels:
assert (
layer.moe_ep_size == 1
), "Expert parallel is not supported when using triton kernels"
if self.with_bias:
return self.triton_kernel_moe_with_bias_forward(
output = self.triton_kernel_moe_with_bias_forward(
hidden_states=x,
w1=self.w13_weight_triton_tensor,
w1_pcg=self.w13_precision_config,
......@@ -701,25 +719,22 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
moe_runner_config=moe_runner_config,
)
else:
return self.triton_kernel_moe_forward(
output = self.triton_kernel_moe_forward(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_output=topk_output,
moe_runner_config=moe_runner_config,
)
return StandardCombineInput(hidden_states=output)
else:
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
return fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_output=topk_output,
moe_runner_config=moe_runner_config,
b1=layer.w13_weight_bias,
b2=layer.w2_weight_bias,
quant_info = TritonMoeQuantInfo(
w13_weight=layer.w13_weight,
w2_weight=layer.w2_weight,
w13_weight_bias=layer.w13_weight_bias,
w2_weight_bias=layer.w2_weight_bias,
)
return self.runner.run(dispatch_output, quant_info)
class Mxfp4DynamicQuantMoEMethod(FusedMoEMethodBase):
......@@ -798,7 +813,7 @@ class Mxfp4DynamicQuantMoEMethod(FusedMoEMethodBase):
return w, mx_scales
def process_weights_after_loading(self, layer: Module) -> None:
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
w13, w13_mx_scales = self.mxfp4_quantize(layer.w13_weight.data)
w2, w2_mx_scales = self.mxfp4_quantize(layer.w2_weight.data)
......@@ -808,19 +823,27 @@ class Mxfp4DynamicQuantMoEMethod(FusedMoEMethodBase):
layer.w2_weight = torch.nn.Parameter(w2, requires_grad=False)
layer.w2_weight_scale = torch.nn.Parameter(w2_mx_scales, requires_grad=False)
def create_moe_runner(
self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
):
self.moe_runner_config = moe_runner_config
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
topk_output: TopKOutput,
moe_runner_config: MoeRunnerConfig,
) -> torch.Tensor:
dispatch_output: StandardDispatchOutput,
) -> CombineInput:
from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
x = dispatch_output.hidden_states
topk_output = dispatch_output.topk_output
topk_weights, topk_ids, _ = topk_output
if _is_hip:
topk_weights = topk_weights.to(
torch.float32
) # aiter's moe_sorting requires topk_weights to be FP32
return fused_moe(
output = fused_moe(
x,
layer.w13_weight,
layer.w2_weight,
......@@ -831,8 +854,9 @@ class Mxfp4DynamicQuantMoEMethod(FusedMoEMethodBase):
w2_scale=layer.w2_weight_scale,
activation=(
ActivationType.Silu
if moe_runner_config.activation == "silu"
if self.moe_runner_config.activation == "silu"
else ActivationType.Gelu
),
doweight_stage1=False,
)
return StandardCombineInput(hidden_states=output)
......@@ -10,8 +10,17 @@ from aiter import ActivationType, QuantType, biased_grouped_topk
from aiter.fused_moe import fused_moe
from aiter.utility.fp4_utils import e8m0_shuffle
from sglang.srt.layers.moe import MoeRunnerConfig
from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase
from sglang.srt.utils import get_bool_env_var, mxfp_supported, set_weight_attrs
if TYPE_CHECKING:
from sglang.srt.layers.moe.token_dispatcher import (
CombineInput,
StandardDispatchOutput,
)
from sglang.srt.layers.quantization.quark.quark import QuarkConfig
logger = logging.getLogger(__name__)
__all__ = ["QuarkMoEMethod", "QuarkW4A4MXFp4MoEMethod"]
......@@ -19,31 +28,17 @@ __all__ = ["QuarkMoEMethod", "QuarkW4A4MXFp4MoEMethod"]
OCP_MX_BLOCK_SIZE = 32
if TYPE_CHECKING:
from sglang.srt.layers.moe.topk import TopKOutput
class QuarkMoEMethod:
def __new__(cls, *args, **kwargs):
from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase
if not hasattr(cls, "_initialized"):
original_init = cls.__init__
new_cls = type(
cls.__name__,
(FusedMoEMethodBase,),
{
"__init__": original_init,
**{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
},
)
obj = super(new_cls, new_cls).__new__(new_cls)
obj.__init__(*args, **kwargs)
return obj
return super().__new__(cls)
from sglang.srt.layers.quantization import QuarkConfig
class QuarkMoEMethod(FusedMoEMethodBase):
def __init__(self, quant_config: QuarkConfig):
self.quant_config = quant_config
@staticmethod
def get_moe_method(
quant_config: "QuarkConfig", # type: ignore # noqa E501 # noqa F821
quant_config: QuarkConfig, # type: ignore # noqa E501 # noqa F821
module: torch.nn.Module,
layer_name: str,
) -> "QuarkMoEMethod":
......@@ -170,16 +165,25 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
# layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale, requires_grad=False)
layer.w2_weight_scale.data = w2_weight_scale.view(s0, s1, -1)
def create_moe_runner(
self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
):
self.moe_runner_config = moe_runner_config
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
topk_output: TopKOutput,
moe_runner_config: MoeRunnerConfig,
) -> torch.Tensor:
dispatch_output: StandardDispatchOutput,
) -> CombineInput:
from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
x = dispatch_output.hidden_states
topk_output = dispatch_output.topk_output
moe_runner_config = self.moe_runner_config
topk_weights, topk_ids, _ = topk_output
return fused_moe(
output = fused_moe(
x,
layer.w13_weight,
layer.w2_weight,
......@@ -195,3 +199,4 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
),
doweight_stage1=False,
)
return StandardCombineInput(hidden_states=output)
......@@ -9,6 +9,8 @@ from torch.nn.parameter import Parameter
from sglang.srt.custom_op import CustomOp
from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
from sglang.srt.layers.quantization.base_config import (
FusedMoEMethodBase,
LinearMethodBase,
......@@ -24,8 +26,10 @@ from sglang.srt.utils import (
)
if TYPE_CHECKING:
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import TopKOutput
from sglang.srt.layers.moe.token_dispatcher import (
CombineInput,
StandardDispatchOutput,
)
has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None
......@@ -155,7 +159,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
layer: torch.nn.Module,
num_experts: int,
hidden_size: int,
intermediate_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype,
with_bias: bool = False,
**extra_weight_attrs,
......@@ -163,7 +167,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
self.with_bias = with_bias
# Fused gate_up_proj (column parallel)
w13_weight_n, w13_weight_k = 2 * intermediate_size, hidden_size
w13_weight_n, w13_weight_k = 2 * intermediate_size_per_partition, hidden_size
if self.use_triton_kernels:
w13_weight_n, w13_weight_k = w13_weight_k, w13_weight_n
w13_weight = torch.nn.Parameter(
......@@ -175,7 +179,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
if self.with_bias:
w13_weight_bias = torch.nn.Parameter(
torch.empty(num_experts, 2 * intermediate_size, dtype=torch.float32),
torch.empty(
num_experts,
2 * intermediate_size_per_partition,
dtype=torch.float32,
),
requires_grad=False,
)
layer.register_parameter("w13_weight_bias", w13_weight_bias)
......@@ -184,7 +192,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
# down_proj (row parallel)
w2_weight_n, w2_weight_k = (
hidden_size,
intermediate_size,
intermediate_size_per_partition,
)
if self.use_triton_kernels:
w2_weight_n, w2_weight_k = w2_weight_k, w2_weight_n
......@@ -222,33 +230,40 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
return
def create_moe_runner(
self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
):
self.moe_runner_config = moe_runner_config
self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
topk_output: TopKOutput,
moe_runner_config: MoeRunnerConfig,
) -> torch.Tensor:
dispatch_output: StandardDispatchOutput,
) -> CombineInput:
return self.forward(
x=x,
layer=layer,
topk_output=topk_output,
moe_runner_config=moe_runner_config,
dispatch_output=dispatch_output,
)
def forward_cuda(
self,
layer: torch.nn.Module,
x: torch.Tensor,
topk_output: TopKOutput,
moe_runner_config: MoeRunnerConfig,
) -> torch.Tensor:
dispatch_output: StandardDispatchOutput,
) -> CombineInput:
from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
x = dispatch_output.hidden_states
topk_output = dispatch_output.topk_output
moe_runner_config = self.moe_runner_config
if self.use_triton_kernels:
if self.with_bias:
assert self.triton_kernel_moe_with_bias_forward is not None
return self.triton_kernel_moe_with_bias_forward(
output = self.triton_kernel_moe_with_bias_forward(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
......@@ -261,13 +276,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
)
else:
assert self.triton_kernel_moe_forward is not None
return self.triton_kernel_moe_forward(
output = self.triton_kernel_moe_forward(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_output=topk_output,
moe_runner_config=moe_runner_config,
)
return StandardCombineInput(hidden_states=output)
else:
if _use_aiter:
assert not moe_runner_config.no_combine, "unsupported"
......@@ -284,7 +300,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
topk_weights = torch.ones_like(
topk_weights, dtype=torch.float32
) # topk_weights must be FP32 (float32)
return fused_moe(
output = fused_moe(
x,
layer.w13_weight,
layer.w2_weight,
......@@ -296,28 +312,30 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
else ActivationType.Gelu
),
)
return StandardCombineInput(hidden_states=output)
else:
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
fused_experts,
)
return fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
b1=getattr(layer, "w13_weight_bias", None),
quant_info = TritonMoeQuantInfo(
w13_weight=layer.w13_weight,
w2_weight=layer.w2_weight,
b13=getattr(layer, "w13_weight_bias", None),
b2=getattr(layer, "w2_weight_bias", None),
topk_output=topk_output,
moe_runner_config=moe_runner_config,
)
return self.runner.run(dispatch_output, quant_info)
def forward_cpu(
self,
layer: torch.nn.Module,
x: torch.Tensor,
topk_output: TopKOutput,
moe_runner_config: MoeRunnerConfig,
) -> torch.Tensor:
dispatch_output: StandardDispatchOutput,
) -> CombineInput:
from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
x = dispatch_output.hidden_states
topk_output = dispatch_output.topk_output
moe_runner_config = self.moe_runner_config
assert (
moe_runner_config.activation == "silu"
), f"activation = {moe_runner_config.activation} is not supported."
......@@ -332,7 +350,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
x, topk_weights = apply_topk_weights_cpu(
moe_runner_config.apply_router_weight_on_input, topk_weights, x
)
return torch.ops.sgl_kernel.fused_experts_cpu(
output = torch.ops.sgl_kernel.fused_experts_cpu(
x,
layer.w13_weight,
layer.w2_weight,
......@@ -348,33 +366,39 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
None, # a2_scale
True, # is_vnni
)
return StandardCombineInput(hidden_states=output)
else:
from sglang.srt.layers.moe.fused_moe_native import moe_forward_native
return moe_forward_native(
output = moe_forward_native(
layer,
x,
topk_output,
moe_runner_config,
)
return StandardCombineInput(hidden_states=output)
def forward_npu(
self,
layer: torch.nn.Module,
x: torch.Tensor,
topk_output: TopKOutput,
moe_runner_config: MoeRunnerConfig,
) -> torch.Tensor:
dispatch_output: StandardDispatchOutput,
) -> CombineInput:
from sglang.srt.layers.moe.fused_moe_native import moe_forward_native
from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
x = dispatch_output.hidden_states
topk_output = dispatch_output.topk_output
return moe_forward_native(
output = moe_forward_native(
layer,
x,
topk_output,
moe_runner_config,
self.moe_runner_config,
)
return StandardCombineInput(hidden_states=output)
def forward_tpu(self, *args, **kwargs) -> torch.Tensor:
def forward_tpu(self, *args, **kwargs) -> CombineInput:
raise NotImplementedError("The TPU backend currently does not support MoE.")
forward_native = forward_cpu
......@@ -9,6 +9,7 @@ from torch.nn.parameter import Parameter
from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
from sglang.srt.layers.quantization.base_config import (
FusedMoEMethodBase,
QuantizationConfig,
......@@ -22,7 +23,10 @@ from sglang.srt.utils import set_weight_attrs
if TYPE_CHECKING:
from sglang.srt.layers.moe import MoeRunnerConfig
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
from sglang.srt.layers.moe.topk import StandardTopKOutput
from sglang.srt.layers.moe.token_dispatcher import (
CombineInput,
StandardDispatchOutput,
)
ACTIVATION_SCHEMES = ["static", "dynamic"]
......@@ -133,7 +137,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
layer: EPMoE,
num_experts: int,
hidden_size: int,
intermediate_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
......@@ -145,7 +149,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
w13_weight = torch.nn.Parameter(
torch.empty(
num_experts,
intermediate_size * 2,
intermediate_size_per_partition * 2,
hidden_size // 2,
dtype=torch.int8,
),
......@@ -159,7 +163,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
torch.empty(
num_experts,
hidden_size,
intermediate_size // 2,
intermediate_size_per_partition // 2,
dtype=torch.int8,
),
requires_grad=False,
......@@ -173,7 +177,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
w13_weight_scale = torch.nn.Parameter(
torch.zeros(
num_experts,
2 * intermediate_size,
2 * intermediate_size_per_partition,
hidden_size // self.quant_config.group_size,
dtype=torch.float32,
),
......@@ -186,7 +190,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
torch.zeros(
num_experts,
hidden_size,
intermediate_size // self.quant_config.group_size,
intermediate_size_per_partition // self.quant_config.group_size,
dtype=torch.float32,
),
requires_grad=False,
......@@ -220,13 +224,13 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
)
self.c_strides1 = torch.full(
(num_experts, 3),
2 * intermediate_size,
2 * intermediate_size_per_partition,
device=device,
dtype=torch.int64,
)
self.a_strides2 = torch.full(
(num_experts, 3),
intermediate_size,
intermediate_size_per_partition,
device=device,
dtype=torch.int64,
)
......@@ -282,16 +286,21 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
)
layer.w2_input_scale = Parameter(new_w2_input_scale, requires_grad=False)
def create_moe_runner(
self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
):
self.moe_runner_config = moe_runner_config
def apply(
self,
layer: EPMoE,
x: torch.Tensor,
topk_output: StandardTopKOutput,
moe_runner_config: MoeRunnerConfig,
) -> torch.Tensor:
dispatch_output: StandardDispatchOutput,
) -> CombineInput:
from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
# TODO(ch-wan): move it out of this class
from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
x = dispatch_output.hidden_states
topk_output = dispatch_output.topk_output
topk_weights, topk_ids, _ = topk_output
local_topk_ids = topk_ids
......@@ -328,6 +337,6 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
layer.w13_input_scale,
layer.w2_input_scale,
)
if moe_runner_config.routed_scaling_factor is not None:
output *= moe_runner_config.routed_scaling_factor
return output
if self.moe_runner_config.routed_scaling_factor is not None:
output *= self.moe_runner_config.routed_scaling_factor
return StandardCombineInput(hidden_states=output)
......@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional
import torch
from torch.nn.parameter import Parameter
from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
from sglang.srt.layers.parameter import ChannelQuantScaleParameter, ModelWeightParameter
from sglang.srt.layers.quantization.base_config import (
FusedMoEMethodBase,
......@@ -26,8 +27,11 @@ from sglang.srt.layers.quantization.fp8_utils import (
from sglang.srt.utils import set_weight_attrs
if TYPE_CHECKING:
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import StandardTopKOutput
from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
from sglang.srt.layers.moe.token_dispatcher import (
CombineInput,
StandardDispatchOutput,
)
_is_fp8_fnuz = is_fp8_fnuz()
......@@ -209,7 +213,7 @@ class W8A8FP8MoEMethod(FusedMoEMethodBase):
layer: torch.nn.Module,
num_experts: int,
hidden_size: int,
intermediate_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
......@@ -218,7 +222,10 @@ class W8A8FP8MoEMethod(FusedMoEMethodBase):
# WEIGHTS
w13_weight = torch.nn.Parameter(
torch.empty(
num_experts, 2 * intermediate_size, hidden_size, dtype=fp8_dtype
num_experts,
2 * intermediate_size_per_partition,
hidden_size,
dtype=fp8_dtype,
),
requires_grad=False,
)
......@@ -226,14 +233,21 @@ class W8A8FP8MoEMethod(FusedMoEMethodBase):
set_weight_attrs(w13_weight, extra_weight_attrs)
w2_weight = torch.nn.Parameter(
torch.empty(num_experts, hidden_size, intermediate_size, dtype=fp8_dtype),
torch.empty(
num_experts,
hidden_size,
intermediate_size_per_partition,
dtype=fp8_dtype,
),
requires_grad=False,
)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
w13_weight_scale = torch.nn.Parameter(
torch.ones(num_experts, 2 * intermediate_size, 1, dtype=torch.float32),
torch.ones(
num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32
),
requires_grad=False,
)
w2_weight_scale = torch.nn.Parameter(
......@@ -266,25 +280,26 @@ class W8A8FP8MoEMethod(FusedMoEMethodBase):
layer.w2_weight_scale.data, requires_grad=False
)
def create_moe_runner(
self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
):
self.moe_runner_config = moe_runner_config
self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
topk_output: StandardTopKOutput,
moe_runner_config: MoeRunnerConfig,
) -> torch.Tensor:
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
dispatch_output: StandardDispatchOutput,
) -> CombineInput:
return fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_output=topk_output,
moe_runner_config=moe_runner_config,
quant_info = TritonMoeQuantInfo(
w13_weight=layer.w13_weight,
w2_weight=layer.w2_weight,
use_fp8_w8a8=True,
per_channel_quant=True,
w1_scale=(layer.w13_weight_scale),
w2_scale=(layer.w2_weight_scale),
a1_scale=layer.w13_input_scale,
w13_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a13_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
)
return self.runner.run(dispatch_output, quant_info)
......@@ -24,6 +24,8 @@ from sglang.srt.distributed import (
get_tensor_model_parallel_world_size,
)
from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
from sglang.srt.layers.parameter import (
ChannelQuantScaleParameter,
ModelWeightParameter,
......@@ -49,8 +51,10 @@ from sglang.srt.utils import (
)
if TYPE_CHECKING:
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import TopKOutput
from sglang.srt.layers.moe.token_dispatcher import (
CombineInput,
StandardDispatchOutput,
)
_is_cuda = is_cuda()
_is_cpu_amx_available = cpu_has_amx_support()
......@@ -417,7 +421,7 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
layer: torch.nn.Module,
num_experts: int,
hidden_size: int,
intermediate_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
......@@ -428,7 +432,10 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
# WEIGHTS
w13_weight = torch.nn.Parameter(
torch.empty(
num_experts, 2 * intermediate_size, hidden_size, dtype=torch.int8
num_experts,
2 * intermediate_size_per_partition,
hidden_size,
dtype=torch.int8,
),
requires_grad=False,
)
......@@ -436,14 +443,21 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
set_weight_attrs(w13_weight, extra_weight_attrs)
w2_weight = torch.nn.Parameter(
torch.empty(num_experts, hidden_size, intermediate_size, dtype=torch.int8),
torch.empty(
num_experts,
hidden_size,
intermediate_size_per_partition,
dtype=torch.int8,
),
requires_grad=False,
)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
w13_weight_scale = torch.nn.Parameter(
torch.ones(num_experts, 2 * intermediate_size, 1, dtype=torch.float32),
torch.ones(
num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32
),
requires_grad=False,
)
w2_weight_scale = torch.nn.Parameter(
......@@ -483,23 +497,30 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
layer.w2_weight_scale.data, requires_grad=False
)
def create_moe_runner(
self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
):
self.moe_runner_config = moe_runner_config
self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
topk_output: TopKOutput,
moe_runner_config: MoeRunnerConfig,
dispatch_output: StandardDispatchOutput,
) -> torch.Tensor:
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
x = dispatch_output.hidden_states
topk_output = dispatch_output.topk_output
if use_intel_amx_backend(layer):
from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
topk_weights, topk_ids, _ = topk_output
x, topk_weights = apply_topk_weights_cpu(
moe_runner_config.apply_router_weight_on_input, topk_weights, x
self.moe_runner_config.apply_router_weight_on_input, topk_weights, x
)
return torch.ops.sgl_kernel.fused_experts_cpu(
output = torch.ops.sgl_kernel.fused_experts_cpu(
x,
layer.w13_weight,
layer.w2_weight,
......@@ -515,20 +536,19 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
layer.w2_input_scale, # a2_scale
True, # is_vnni
)
return StandardCombineInput(hidden_states=output)
return fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_output=topk_output,
moe_runner_config=moe_runner_config,
quant_info = TritonMoeQuantInfo(
w13_weight=layer.w13_weight,
w2_weight=layer.w2_weight,
use_int8_w8a8=True,
per_channel_quant=True,
w1_scale=(layer.w13_weight_scale),
w2_scale=(layer.w2_weight_scale),
a1_scale=layer.w13_input_scale,
w13_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a13_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
)
return self.runner.run(dispatch_output, quant_info)
class NPU_W8A8LinearMethodImpl:
......@@ -900,7 +920,7 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase):
layer: torch.nn.Module,
num_experts: int,
hidden_size: int,
intermediate_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
) -> None:
......@@ -914,21 +934,31 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase):
# weight
w13_weight = torch.nn.Parameter(
torch.empty(
num_experts, 2 * intermediate_size, hidden_size, dtype=torch.int8
num_experts,
2 * intermediate_size_per_partition,
hidden_size,
dtype=torch.int8,
),
requires_grad=False,
)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
w2_weight = torch.nn.Parameter(
torch.empty(num_experts, hidden_size, intermediate_size, dtype=torch.int8),
torch.empty(
num_experts,
hidden_size,
intermediate_size_per_partition,
dtype=torch.int8,
),
requires_grad=False,
)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
# scale
w13_weight_scale = torch.nn.Parameter(
torch.empty(num_experts, 2 * intermediate_size, 1, dtype=torch.float32),
torch.empty(
num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32
),
requires_grad=False,
)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
......@@ -941,7 +971,9 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase):
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
# offset
w13_weight_offset = torch.nn.Parameter(
torch.empty(num_experts, 2 * intermediate_size, 1, dtype=torch.float32),
torch.empty(
num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32
),
requires_grad=False,
)
layer.register_parameter("w13_weight_offset", w13_weight_offset)
......@@ -973,18 +1005,25 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase):
layer.w2_weight_offset.data.squeeze(-1).contiguous(), requires_grad=False
)
def create_moe_runner(
self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
):
self.moe_runner_config = moe_runner_config
def apply(
self,
layer,
x,
topk_output: TopKOutput,
moe_runner_config: MoeRunnerConfig,
) -> torch.Tensor:
dispatch_output: StandardDispatchOutput,
) -> CombineInput:
from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
x = dispatch_output.hidden_states
topk_output = dispatch_output.topk_output
topk_weights, topk_ids, _ = topk_output
topk_ids = topk_ids.to(torch.int32)
topk_weights = topk_weights.to(x.dtype)
return npu_fused_experts(
output = npu_fused_experts(
hidden_states=x,
w13=layer.w13_weight,
w13_scale=layer.w13_weight_scale,
......@@ -994,3 +1033,4 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase):
topk_ids=topk_ids,
top_k=topk_ids.shape[1],
)
return StandardCombineInput(hidden_states=output)
......@@ -52,7 +52,6 @@ from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
ScheduleBatchDisaggregationDecodeMixin,
)
from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
from sglang.srt.layers.moe import is_tbo_enabled
from sglang.srt.mem_cache.allocator import (
BaseTokenToKVPoolAllocator,
SWATokenToKVPoolAllocator,
......
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/model_loader/__init__.py
from __future__ import annotations
from typing import TYPE_CHECKING
from torch import nn
from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.model_loader.loader import BaseModelLoader, get_model_loader
from sglang.srt.model_loader.utils import (
get_architecture_class_name,
get_model_architecture,
)
if TYPE_CHECKING:
from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.configs.model_config import ModelConfig
def get_model(
*,
......
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.3.post1/vllm/model_executor/model_loader/loader.py
from __future__ import annotations
# ruff: noqa: SIM117
import collections
import concurrent
......@@ -14,7 +16,17 @@ import time
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, cast
from typing import (
TYPE_CHECKING,
Any,
Dict,
Generator,
Iterable,
List,
Optional,
Tuple,
cast,
)
import huggingface_hub
import numpy as np
......@@ -26,9 +38,7 @@ from tqdm.auto import tqdm
from transformers import AutoModelForCausalLM
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.load_config import LoadConfig, LoadFormat
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.connector import (
ConnectorType,
create_remote_connector,
......@@ -39,7 +49,6 @@ from sglang.srt.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.model_loader.utils import (
get_model_architecture,
post_load_weights,
......@@ -70,6 +79,11 @@ from sglang.srt.utils import (
set_weight_attrs,
)
if TYPE_CHECKING:
from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.layers.quantization.base_config import QuantizationConfig
_is_npu = is_npu()
......
......@@ -9,6 +9,7 @@ from transformers import AutoConfig
from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts_fp8
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
from sglang.srt.layers.moe.moe_runner.base import MoeRunnerConfig
from sglang.srt.layers.moe.topk import StandardTopKOutput
# Copy from: https://github.com/deepseek-ai/DeepGEMM/blob/main/deep_gemm/utils.py
......@@ -152,14 +153,32 @@ def run_test(tp_size, batch_size, model_config, check=False):
problem_sizes2,
)
topk_output = StandardTopKOutput(
topk_weights=topk_weights,
topk_ids=topk_ids,
router_logits=torch.randn(
(batch_size, topk), device=topk_weights.device, dtype=dtype
),
)
moe_runner_config = MoeRunnerConfig(
num_experts=E,
topk=topk,
hidden_size=H,
shard_intermediate_size=I,
dtype=dtype,
block_shape=block_shape,
activation="silu",
inplace=False,
)
# Note: Triton expects non-transposed weights
moe_config = MoeRunnerConfig(inplace=False)
triton_lambda = lambda: fused_experts(
x,
w1,
w2,
(topk_weights, topk_ids, "dummy"),
moe_config,
topk_output,
moe_runner_config,
use_fp8_w8a8=True,
w1_scale=w1_scale,
w2_scale=w2_scale,
......@@ -224,8 +243,8 @@ def run_test(tp_size, batch_size, model_config, check=False):
x,
w1, # Original shape
w2, # Original shape
(topk_weights, topk_ids, "dummy"),
moe_config,
topk_output,
moe_runner_config,
use_fp8_w8a8=True,
w1_scale=w1_scale,
w2_scale=w2_scale,
......
import os
import unittest
from types import SimpleNamespace
......@@ -49,6 +50,42 @@ class TestMLADeepseekV3(CustomTestCase):
self.assertGreater(metrics["accuracy"], 0.62)
class TestMLADeepseekV3DisableFusedFunc(CustomTestCase):
@classmethod
def setUpClass(cls):
os.environ["SGLANG_CI_DISABLE_MOE_FUSED_FUNC"] = "1"
cls.model = "lmsys/sglang-ci-dsv3-test"
cls.base_url = DEFAULT_URL_FOR_TEST
other_args = ["--trust-remote-code", "--chunked-prefill-size", "256"]
if is_cuda():
other_args.extend(["--cuda-graph-max-bs", "2"])
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=other_args,
)
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def test_gsm8k(self):
args = SimpleNamespace(
num_shots=5,
data_path=None,
num_questions=200,
max_new_tokens=512,
parallel=128,
host="http://127.0.0.1",
port=int(self.base_url.split(":")[-1]),
)
metrics = run_eval_few_shot_gsm8k(args)
print(metrics)
self.assertGreater(metrics["accuracy"], 0.62)
@unittest.skipIf(is_hip(), "FA is not available.")
class TestMLADeepseekV3Fa3Fp8Kvcache(CustomTestCase):
@classmethod
......