Unverified Commit 62eff37b authored by Jonah Bernard, committed by GitHub

Refactor Triton-kernel MoE runner integration (#11795)

parent 47e12e08
@@ -172,7 +172,7 @@ class FusedMoE(torch.nn.Module):
         self.reduce_results = reduce_results
         self.use_presharded_weights = use_presharded_weights

-        self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()
+        self.use_triton_kernels = get_moe_runner_backend().is_triton_kernels()
         self.quant_config = quant_config
         self.use_flashinfer_mxfp4_moe = get_moe_runner_backend().is_flashinfer_mxfp4()
...
@@ -47,7 +47,7 @@ def triton_kernel_moe_forward(
     from sglang.srt.layers.moe.topk import TopKOutputChecker

-    assert TopKOutputChecker.format_is_triton_kernel(topk_output)
+    assert TopKOutputChecker.format_is_triton_kernels(topk_output)

     routing_data, gather_idx, scatter_idx = topk_output
@@ -172,6 +172,7 @@ def triton_kernel_moe_with_bias_forward(
     b2: torch.Tensor,
     topk_output: TopKOutput,
     moe_runner_config: MoeRunnerConfig,
+    apply_router_weight_on_input: bool = False,
     use_fp8_w8a8: bool = False,
     per_channel_quant: bool = False,
     global_num_experts: int = -1,
@@ -184,7 +185,7 @@ def triton_kernel_moe_with_bias_forward(
 ) -> torch.Tensor:
     from sglang.srt.layers.moe.topk import TopKOutputChecker

-    assert TopKOutputChecker.format_is_triton_kernel(topk_output)
+    assert TopKOutputChecker.format_is_triton_kernels(topk_output)

     routing_data, gather_idx, scatter_idx = topk_output
@@ -201,6 +202,7 @@ def triton_kernel_moe_with_bias_forward(
         scatter_indx=scatter_idx,
         inplace=False,  # triton kernel doesn't support inplace
         activation=moe_runner_config.activation,
+        apply_router_weight_on_input=apply_router_weight_on_input,
         use_fp8_w8a8=use_fp8_w8a8,
         per_channel_quant=per_channel_quant,
         global_num_experts=global_num_experts,
@@ -228,6 +230,7 @@ def triton_kernel_fused_experts_with_bias(
     scatter_indx: ScatterIndx,
     inplace: bool = False,
     activation: str = "silu",
+    apply_router_weight_on_input: bool = False,
     use_fp8_w8a8: bool = False,
     per_channel_quant: bool = False,
     global_num_experts: int = -1,
@@ -296,7 +299,7 @@ def triton_kernel_fused_experts_with_bias(
         routing_data,
         gather_indx=gather_indx,
         precision_config=w1_pcg,
-        gammas=None,
+        gammas=routing_data.gate_scal if apply_router_weight_on_input else None,
         fused_activation=act,
     )
@@ -307,5 +310,5 @@ def triton_kernel_fused_experts_with_bias(
         routing_data,
         scatter_indx=scatter_indx,
         precision_config=w2_pcg,
-        gammas=routing_data.gate_scal,
+        gammas=None if apply_router_weight_on_input else routing_data.gate_scal,
     )
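
The two hunks above are the heart of the apply_router_weight_on_input change: the router weights (routing_data.gate_scal) are handed to exactly one of the two matmul_ogs calls as gammas, scaling either the rows entering the first GEMM or the rows leaving the second, never both. A minimal standalone sketch of that selection rule (the tensor shape is an assumption; the real scaling happens inside triton_kernels.matmul_ogs):

import torch

def pick_gammas(gate_scal: torch.Tensor, apply_router_weight_on_input: bool):
    # Mirrors the gammas= arguments above: the first GEMM scales by the router
    # weights only when they are applied on the input side; otherwise the
    # second GEMM applies them after the experts run.
    gammas_gemm1 = gate_scal if apply_router_weight_on_input else None
    gammas_gemm2 = None if apply_router_weight_on_input else gate_scal
    return gammas_gemm1, gammas_gemm2

gate_scal = torch.rand(8)  # assumed: one scale per routed (token, expert) row
g1, g2 = pick_gammas(gate_scal, apply_router_weight_on_input=False)
assert g1 is None and g2 is gate_scal  # default: weights applied after the experts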
@@ -11,6 +11,7 @@ from sglang.srt.layers.moe.moe_runner.base import (
 )
 from sglang.srt.layers.moe.moe_runner.deep_gemm import DeepGemmRunnerCore
 from sglang.srt.layers.moe.moe_runner.triton import TritonRunnerCore
+from sglang.srt.layers.moe.moe_runner.triton_kernels import TritonKernelsRunnerCore
 from sglang.srt.layers.moe.utils import get_moe_a2a_backend

 if TYPE_CHECKING:
@@ -31,6 +32,8 @@ class MoeRunner:
         if runner_backend.is_triton():
             self.runner_core = TritonRunnerCore(config)
+        elif runner_backend.is_triton_kernels():
+            self.runner_core = TritonKernelsRunnerCore(config)
         elif runner_backend.is_deep_gemm():
             self.runner_core = DeepGemmRunnerCore(config)
         else:
...
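
With the import and branch above in place, MoeRunner maps the backend enum to a runner core once at construction, and every later call goes through runner.run. A hedged usage sketch (not a runnable unit on its own: it assumes CUDA, the external triton_kernels package, and real weight tensors for the quant info defined in the new module below):

from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig

config = MoeRunnerConfig(inplace=False)
runner = MoeRunner(MoeRunnerBackend.TRITON_KERNELS, config)
# runner.runner_core is now a TritonKernelsRunnerCore; runner.run(...) takes a
# StandardDispatchOutput plus a TritonKernelsQuantInfo, exactly as the updated
# test at the bottom of this diff exercises it.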
"""Triton kernels MoE runner backend skeleton."""
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional
import torch
from sglang.srt.layers.moe.moe_runner.base import (
MoeQuantInfo,
MoeRunnerConfig,
MoeRunnerCore,
RunnerInput,
RunnerOutput,
register_post_permute,
register_pre_permute,
)
from sglang.srt.layers.moe.utils import MoeRunnerBackend
if TYPE_CHECKING:
from triton_kernels.matmul_ogs import PrecisionConfig
from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx
from sglang.srt.layers.moe.token_dispatcher.standard import (
StandardCombineInput,
StandardDispatchOutput,
)
# ---------------------------------------------------------------------------
# Runner IO dataclasses
# ---------------------------------------------------------------------------
@dataclass
class TritonKernelsRunnerInput(RunnerInput):
"""Input bundle passed to the triton-kernels runner core."""
hidden_states: torch.Tensor
routing_data: "RoutingData"
gather_indx: "GatherIndx"
scatter_indx: "ScatterIndx"
@property
def runner_backend(self) -> MoeRunnerBackend:
return MoeRunnerBackend.TRITON_KERNELS
@dataclass
class TritonKernelsRunnerOutput(RunnerOutput):
"""Output bundle returned from the triton-kernels runner core."""
hidden_states: torch.Tensor
@property
def runner_backend(self) -> MoeRunnerBackend:
return MoeRunnerBackend.TRITON_KERNELS
@dataclass
class TritonKernelsQuantInfo(MoeQuantInfo):
"""Quantization payload consumed by the triton-kernels backend."""
w13_weight: torch.Tensor
w2_weight: torch.Tensor
w13_bias: Optional[torch.Tensor] = None
w2_bias: Optional[torch.Tensor] = None
w13_precision_config: Optional[PrecisionConfig] = None
w2_precision_config: Optional[PrecisionConfig] = None
global_num_experts: int = -1
# ---------------------------------------------------------------------------
# Runner core
# ---------------------------------------------------------------------------
class TritonKernelsRunnerCore(MoeRunnerCore):
"""Execute MoE experts via the external triton_kernels package."""
def run(
self,
runner_input: TritonKernelsRunnerInput,
quant_info: TritonKernelsQuantInfo,
running_state: dict,
) -> TritonKernelsRunnerOutput:
from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
triton_kernel_fused_experts,
triton_kernel_fused_experts_with_bias,
)
hidden_states = runner_input.hidden_states
common_kwargs = dict(
routing_data=runner_input.routing_data,
gather_indx=runner_input.gather_indx,
scatter_indx=None if self.config.no_combine else runner_input.scatter_indx,
inplace=False,
activation=self.config.activation,
apply_router_weight_on_input=self.config.apply_router_weight_on_input,
global_num_experts=quant_info.global_num_experts,
)
has_bias = quant_info.w13_bias is not None or quant_info.w2_bias is not None
if has_bias:
assert (
quant_info.w13_bias is not None and quant_info.w2_bias is not None
), "Bias execution requires both w13_bias and w2_bias"
output = triton_kernel_fused_experts_with_bias(
hidden_states=hidden_states,
w1=quant_info.w13_weight,
w1_pcg=quant_info.w13_precision_config,
b1=quant_info.w13_bias,
w2=quant_info.w2_weight,
w2_pcg=quant_info.w2_precision_config,
b2=quant_info.w2_bias,
gemm1_alpha=self.config.gemm1_alpha,
gemm1_clamp_limit=self.config.gemm1_clamp_limit,
**common_kwargs,
)
else:
output = triton_kernel_fused_experts(
hidden_states=hidden_states,
w1=quant_info.w13_weight,
w2=quant_info.w2_weight,
**common_kwargs,
)
if self.config.no_combine:
tokens = runner_input.hidden_states.shape[0]
hidden = runner_input.hidden_states.shape[-1]
total_rows = output.shape[0]
top_k = total_rows // tokens
output = output.view(tokens, top_k, hidden)
return TritonKernelsRunnerOutput(hidden_states=output)
@property
def runner_backend(self) -> MoeRunnerBackend:
return MoeRunnerBackend.TRITON_KERNELS
# ---------------------------------------------------------------------------
# Permute / fused hooks
# ---------------------------------------------------------------------------
@register_pre_permute("standard", "triton_kernel")
def pre_permute_standard_to_triton_kernels(
dispatch_output: "StandardDispatchOutput",
quant_info: TritonKernelsQuantInfo,
runner_config: MoeRunnerConfig,
running_state: dict,
) -> TritonKernelsRunnerInput:
from sglang.srt.layers.moe.topk import TopKOutputChecker
hidden_states = dispatch_output.hidden_states
topk_output = dispatch_output.topk_output
assert TopKOutputChecker.format_is_triton_kernels(
topk_output
), "Triton-kernel runner expects TritonKernelTopKOutput"
routing_data, gather_indx, scatter_indx = topk_output
return TritonKernelsRunnerInput(
hidden_states=hidden_states,
routing_data=routing_data,
gather_indx=gather_indx,
scatter_indx=scatter_indx,
)
@register_post_permute("triton_kernel", "standard")
def post_permute_triton_kernels_to_standard(
runner_output: TritonKernelsRunnerOutput,
quant_info: TritonKernelsQuantInfo,
runner_config: MoeRunnerConfig,
running_state: dict,
) -> StandardCombineInput:
from sglang.srt.layers.moe.token_dispatcher.standard import StandardCombineInput
hidden_states = runner_output.hidden_states
if (
runner_config.routed_scaling_factor is not None
and runner_config.routed_scaling_factor != 1.0
and not runner_config.no_combine
):
hidden_states.mul_(runner_config.routed_scaling_factor)
return StandardCombineInput(hidden_states=hidden_states)
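
When no_combine is set, run() drops the scatter index, so the kernel returns a flat (tokens * top_k, hidden) matrix of per-(token, expert) rows, and the final view folds it back to (tokens, top_k, hidden). A small self-contained check of that reshape arithmetic:

import torch

tokens, top_k, hidden = 4, 2, 8
flat = torch.randn(tokens * top_k, hidden)  # what the kernel hands back

# Same recovery as in TritonKernelsRunnerCore.run(): infer top_k from the
# row count and regroup per-expert rows under their originating token.
total_rows = flat.shape[0]
assert total_rows % tokens == 0
per_expert = flat.view(tokens, total_rows // tokens, hidden)
assert per_expert.shape == (tokens, top_k, hidden)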
@@ -28,6 +28,12 @@ class DispatchOutputChecker:
     ) -> TypeGuard[StandardDispatchOutput]:
         return dispatch_output.format.is_standard()

+    @staticmethod
+    def format_is_triton_kernels(
+        dispatch_output: DispatchOutput,
+    ) -> TypeGuard[StandardDispatchOutput]:
+        return dispatch_output.format.is_standard()
+
     @staticmethod
     def format_is_deepep_normal(
         dispatch_output: DispatchOutput,
...
@@ -88,7 +88,7 @@ class StandardDispatcher(BaseDispatcher):
             topk_output = topk_output._replace(
                 topk_ids=self.local_expert_mapping[topk_output.topk_ids]
             )
-        elif TopKOutputChecker.format_is_triton_kernel(topk_output):
+        elif TopKOutputChecker.format_is_triton_kernels(topk_output):
             raise NotImplementedError()

         return StandardDispatchOutput(
...
@@ -111,10 +111,10 @@ class TopKOutputChecker:
         return topk_output.format.is_standard()

     @staticmethod
-    def format_is_triton_kernel(
+    def format_is_triton_kernels(
         topk_output: TopKOutput,
     ) -> TypeGuard[TritonKernelTopKOutput]:
-        return topk_output.format.is_triton_kernel()
+        return topk_output.format.is_triton_kernels()

     @staticmethod
     def format_is_bypassed(topk_output: TopKOutput) -> TypeGuard[BypassedTopKOutput]:
@@ -129,7 +129,7 @@ class TopKOutputFormat(Enum):
     def is_standard(self) -> bool:
         return self == TopKOutputFormat.STANDARD

-    def is_triton_kernel(self) -> bool:
+    def is_triton_kernels(self) -> bool:
         return self == TopKOutputFormat.TRITON_KERNEL

     def is_bypassed(self) -> bool:
@@ -254,7 +254,7 @@ class TopK(CustomOp):
     ) -> TopKOutput:
         if self.topk_config.output_format is not None:
             output_format = self.topk_config.output_format
-        elif get_moe_runner_backend().is_triton_kernel():
+        elif get_moe_runner_backend().is_triton_kernels():
             output_format = TopKOutputFormat.TRITON_KERNEL
         elif (
             should_use_flashinfer_trtllm_moe()
...
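
The last hunk above also shows the precedence that the updated test relies on: an explicit topk_config.output_format beats any backend-derived default, so the test can pin TRITON_KERNEL without touching private state. A condensed sketch of that ladder (the later flashinfer rungs are elided, and the STANDARD fallback is a simplification of code not shown in this diff):

from sglang.srt.layers.moe.topk import TopKOutputFormat

def resolve_output_format(topk_config, runner_backend):
    if topk_config.output_format is not None:  # explicit override always wins
        return topk_config.output_format
    if runner_backend.is_triton_kernels():  # backend-driven default
        return TopKOutputFormat.TRITON_KERNEL
    return TopKOutputFormat.STANDARD  # simplified; real code checks more backends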
@@ -51,7 +51,7 @@ class MoeRunnerBackend(Enum):
     AUTO = "auto"
     DEEP_GEMM = "deep_gemm"
     TRITON = "triton"
-    TRITON_KERNEL = "triton_kernel"
+    TRITON_KERNELS = "triton_kernel"
     FLASHINFER_TRTLLM = "flashinfer_trtllm"
     FLASHINFER_CUTLASS = "flashinfer_cutlass"
     FLASHINFER_MXFP4 = "flashinfer_mxfp4"
@@ -67,8 +67,8 @@ class MoeRunnerBackend(Enum):
     def is_triton(self):
         return self == MoeRunnerBackend.TRITON

-    def is_triton_kernel(self):
-        return self == MoeRunnerBackend.TRITON_KERNEL
+    def is_triton_kernels(self):
+        return self == MoeRunnerBackend.TRITON_KERNELS

     def is_flashinfer_trtllm(self):
         return self == MoeRunnerBackend.FLASHINFER_TRTLLM
...
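
Note that only the enum member is renamed; its value stays "triton_kernel", so existing server arguments and config strings that select this backend keep working, and value-based Enum lookup still resolves to the renamed member. A quick check with a trimmed copy of the enum:

from enum import Enum

class MoeRunnerBackend(Enum):  # trimmed to the two members relevant here
    TRITON = "triton"
    TRITON_KERNELS = "triton_kernel"

# Lookup by value is unaffected by renaming the member itself.
assert MoeRunnerBackend("triton_kernel") is MoeRunnerBackend.TRITON_KERNELS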
@@ -261,26 +261,13 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         self.prefix = prefix
         self.topk_indices_dtype = None
-        self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()
+        self.use_triton_kernels = get_moe_runner_backend().is_triton_kernels()
         self.with_bias = False
         self.use_flashinfer = get_moe_runner_backend().is_flashinfer_mxfp4()
         self.flashinfer_mxfp4_moe_precision = (
             get_global_server_args().flashinfer_mxfp4_moe_precision
         )

-        self.triton_kernel_moe_forward = None
-        self.triton_kernel_moe_with_bias_forward = None
-        if torch.cuda.is_available() and has_triton_kernels:
-            from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
-                triton_kernel_moe_forward as _tk_forward,
-            )
-            from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
-                triton_kernel_moe_with_bias_forward as _tk_with_bias_forward,
-            )
-
-            self.triton_kernel_moe_forward = _tk_forward
-            self.triton_kernel_moe_with_bias_forward = _tk_with_bias_forward
-
     def create_weights(
         self,
         layer: torch.nn.Module,
@@ -600,7 +587,14 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
     ):
         self.moe_runner_config = moe_runner_config
-        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
+        backend = get_moe_runner_backend()
+        if backend.is_auto():
+            backend = (
+                MoeRunnerBackend.TRITON_KERNELS
+                if self.use_triton_kernels
+                else MoeRunnerBackend.TRITON
+            )
+        self.runner = MoeRunner(backend, moe_runner_config)

     def apply(
         self,
@@ -677,31 +671,31 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
             )[0]
             return StandardCombineInput(hidden_states=trtllm_gen_output)

-        if self.use_triton_kernels:
+        backend = self.runner.runner_backend
+        if backend.is_triton_kernels():
+            from sglang.srt.layers.moe.moe_runner.triton_kernels import (
+                TritonKernelsQuantInfo,
+            )
+
             assert (
                 layer.moe_ep_size == 1
             ), "Expert parallel is not supported when using triton kernels"
-            if self.with_bias:
-                output = self.triton_kernel_moe_with_bias_forward(
-                    hidden_states=x,
-                    w1=self.w13_weight_triton_tensor,
-                    w1_pcg=self.w13_precision_config,
-                    w2=self.w2_weight_triton_tensor,
-                    w2_pcg=self.w2_precision_config,
-                    b1=layer.w13_weight_bias,
-                    b2=layer.w2_weight_bias,
-                    topk_output=topk_output,
-                    moe_runner_config=moe_runner_config,
-                )
-            else:
-                output = self.triton_kernel_moe_forward(
-                    hidden_states=x,
-                    w1=layer.w13_weight,
-                    w2=layer.w2_weight,
-                    topk_output=topk_output,
-                    moe_runner_config=moe_runner_config,
-                )
-            return StandardCombineInput(hidden_states=output)
+            quant_info = TritonKernelsQuantInfo(
+                w13_weight=(
+                    self.w13_weight_triton_tensor
+                    if self.w13_weight_triton_tensor is not None
+                    else layer.w13_weight
+                ),
+                w2_weight=(
+                    self.w2_weight_triton_tensor
+                    if self.w2_weight_triton_tensor is not None
+                    else layer.w2_weight
+                ),
+                w13_bias=getattr(layer, "w13_weight_bias", None),
+                w2_bias=getattr(layer, "w2_weight_bias", None),
+                w13_precision_config=getattr(self, "w13_precision_config", None),
+                w2_precision_config=getattr(self, "w2_precision_config", None),
+            )
         else:
             quant_info = TritonMoeQuantInfo(
                 w13_weight=layer.w13_weight,
@@ -709,7 +703,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                 b13=getattr(layer, "w13_weight_bias", None),
                 b2=getattr(layer, "w2_weight_bias", None),
             )
-            return self.runner.run(dispatch_output, quant_info)
+        return self.runner.run(dispatch_output, quant_info)

 class Mxfp4DynamicQuantMoEMethod(FusedMoEMethodBase):
...
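
The Mxfp4 path now builds a single TritonKernelsQuantInfo for both the bias and bias-free cases: getattr(..., None) maps absent biases and precision configs to None, and the runner core selects the with-bias kernel only when both biases are present, asserting that a partial bias never slips through. A standalone sketch of that selection rule (mirrors the has_bias check in TritonKernelsRunnerCore.run()):

from typing import Optional

import torch

def select_kernel(
    w13_bias: Optional[torch.Tensor], w2_bias: Optional[torch.Tensor]
) -> str:
    has_bias = w13_bias is not None or w2_bias is not None
    if has_bias:
        # Matches the runner-core assertion: partial bias is rejected.
        assert (
            w13_bias is not None and w2_bias is not None
        ), "Bias execution requires both w13_bias and w2_bias"
        return "triton_kernel_fused_experts_with_bias"
    return "triton_kernel_fused_experts"

assert select_kernel(None, None) == "triton_kernel_fused_experts"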
@@ -8,7 +8,12 @@ from torch.nn.parameter import Parameter

 from sglang.srt.custom_op import CustomOp
 from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
-from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
+from sglang.srt.layers.moe import (
+    MoeRunner,
+    MoeRunnerBackend,
+    MoeRunnerConfig,
+    get_moe_runner_backend,
+)
 from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
 from sglang.srt.layers.quantization.base_config import (
     FusedMoEMethodBase,
@@ -115,13 +120,15 @@ class UnquantizedLinearMethod(LinearMethodBase):
         x: torch.Tensor,
         bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         if use_intel_amx_backend(layer):
             x_shapes = x.shape
             if len(x_shapes) == 3:
                 x = x.view(-1, x.shape[-1])
             output = torch.ops.sgl_kernel.weight_packed_linear(
-                x, layer.weight, bias, True  # is_vnni
+                x,
+                layer.weight,
+                bias,
+                True,  # is_vnni
             )
             if len(x_shapes) == 3:
                 output = output.view(x_shapes[0], x_shapes[1], -1)
@@ -138,19 +145,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         self.use_triton_kernels = use_triton_kernels
         self.with_bias = False

-        self.triton_kernel_moe_forward = None
-        self.triton_kernel_moe_with_bias_forward = None
-        if torch.cuda.is_available() and use_triton_kernels:
-            from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
-                triton_kernel_moe_forward as _tk_forward,
-            )
-            from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
-                triton_kernel_moe_with_bias_forward as _tk_with_bias_forward,
-            )
-
-            self.triton_kernel_moe_forward = _tk_forward
-            self.triton_kernel_moe_with_bias_forward = _tk_with_bias_forward
-
     def create_weights(
         self,
         layer: torch.nn.Module,
@@ -231,14 +225,20 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
     ):
         self.moe_runner_config = moe_runner_config
-        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
+        backend = get_moe_runner_backend()
+        if backend.is_auto():
+            backend = (
+                MoeRunnerBackend.TRITON_KERNELS
+                if self.use_triton_kernels
+                else MoeRunnerBackend.TRITON
+            )
+        self.runner = MoeRunner(backend, moe_runner_config)

     def apply(
         self,
         layer: torch.nn.Module,
         dispatch_output: StandardDispatchOutput,
     ) -> CombineInput:
         return self.forward(
             layer=layer,
             dispatch_output=dispatch_output,
@@ -249,7 +249,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         layer: torch.nn.Module,
         dispatch_output: StandardDispatchOutput,
     ) -> CombineInput:
-
         from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput

         x = dispatch_output.hidden_states
@@ -257,30 +256,19 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         moe_runner_config = self.moe_runner_config

-        if self.use_triton_kernels:
-            if self.with_bias:
-                assert self.triton_kernel_moe_with_bias_forward is not None
-                output = self.triton_kernel_moe_with_bias_forward(
-                    hidden_states=x,
-                    w1=layer.w13_weight,
-                    w2=layer.w2_weight,
-                    b1=layer.w13_weight_bias,
-                    b2=layer.w2_weight_bias,
-                    topk_output=topk_output,
-                    moe_runner_config=moe_runner_config,
-                    w1_pcg=None,
-                    w2_pcg=None,
-                )
-            else:
-                assert self.triton_kernel_moe_forward is not None
-                output = self.triton_kernel_moe_forward(
-                    hidden_states=x,
-                    w1=layer.w13_weight,
-                    w2=layer.w2_weight,
-                    topk_output=topk_output,
-                    moe_runner_config=moe_runner_config,
-                )
-            return StandardCombineInput(hidden_states=output)
+        backend = self.runner.runner_backend
+        if backend.is_triton_kernels():
+            from sglang.srt.layers.moe.moe_runner.triton_kernels import (
+                TritonKernelsQuantInfo,
+            )
+
+            quant_info = TritonKernelsQuantInfo(
+                w13_weight=layer.w13_weight,
+                w2_weight=layer.w2_weight,
+                w13_bias=getattr(layer, "w13_weight_bias", None),
+                w2_bias=getattr(layer, "w2_weight_bias", None),
+            )
+            return self.runner.run(dispatch_output, quant_info)
         else:
             if _use_aiter:
                 assert not moe_runner_config.no_combine, "unsupported"
@@ -311,7 +299,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                 )
                 return StandardCombineInput(hidden_states=output)
             else:
-
                 quant_info = TritonMoeQuantInfo(
                     w13_weight=layer.w13_weight,
                     w2_weight=layer.w2_weight,
@@ -325,7 +312,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         layer: torch.nn.Module,
         dispatch_output: StandardDispatchOutput,
     ) -> CombineInput:
-
         from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput

         x = dispatch_output.hidden_states
@@ -380,7 +366,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         layer: torch.nn.Module,
         dispatch_output: StandardDispatchOutput,
     ) -> CombineInput:
-
         import torch_npu

         from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
...
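
Both quant methods now resolve the runner backend identically: the globally configured backend is used as-is unless it is AUTO, in which case the method falls back to TRITON_KERNELS when the layer opted into triton kernels and plain TRITON otherwise. A condensed sketch of that rule (is_auto() is taken on trust from the pattern above; its body is not part of this diff):

from sglang.srt.layers.moe import MoeRunnerBackend

def resolve_backend(server_backend, use_triton_kernels: bool):
    # Same shape as create_moe_runner in Mxfp4MoEMethod and
    # UnquantizedFusedMoEMethod above.
    if server_backend.is_auto():
        return (
            MoeRunnerBackend.TRITON_KERNELS
            if use_triton_kernels
            else MoeRunnerBackend.TRITON
        )
    return server_backend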
@@ -5,11 +5,10 @@ import torch.nn.functional as F
 from tqdm import tqdm

 from sglang.srt.layers.activation import SiluAndMul
-from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
-    triton_kernel_moe_forward,
-)
-from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
-from sglang.srt.layers.moe.topk import TopK
+from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
+from sglang.srt.layers.moe.moe_runner.triton_kernels import TritonKernelsQuantInfo
+from sglang.srt.layers.moe.token_dispatcher.standard import StandardDispatchOutput
+from sglang.srt.layers.moe.topk import TopK, TopKOutputFormat
 from sglang.test.test_utils import CustomTestCase
@@ -55,6 +54,7 @@ class TestFusedMOE(CustomTestCase):
         w2,
         score,
         topk,
+        return_per_expert: bool = False,
     ):
         B, D = a.shape
         a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
@@ -78,9 +78,14 @@ class TestFusedMOE(CustomTestCase):
                 a[mask] @ w1_compute[i].transpose(0, 1)
             ) @ w2_compute[i].transpose(0, 1)

-        return (
-            out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
-        ).sum(dim=1)
+        weighted = out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(
+            out.dtype
+        )
+        if return_per_expert:
+            return weighted
+        return weighted.sum(dim=1)

     def _test_case(self, m, n, k, e, topk, dtype):
         rtol, atol = self.get_tolerance(dtype)
@@ -99,20 +104,43 @@ class TestFusedMOE(CustomTestCase):
             renormalize=False,
             use_grouped_topk=False,
         )
-        topk_op.use_triton_kernels = True
+        topk_op.topk_config.output_format = TopKOutputFormat.TRITON_KERNEL
         triton_topk_output = topk_op.forward_cuda(
             hidden_states=a,
             router_logits=score,
         )

-        moe_runner_config = MoeRunnerConfig(
-            inplace=False,
-        )
-        triton_output = triton_kernel_moe_forward(
-            a, w1_tri, w2_tri, triton_topk_output, moe_runner_config
-        )
-        torch_output = self.torch_naive_moe(a, w1, w2, score, topk)
-        torch.testing.assert_close(triton_output, torch_output, rtol=rtol, atol=atol)
+        quant_info = TritonKernelsQuantInfo(w13_weight=w1_tri, w2_weight=w2_tri)
+        dispatch_output = StandardDispatchOutput(
+            hidden_states=a, topk_output=triton_topk_output
+        )
+
+        torch_per_expert = self.torch_naive_moe(
+            a, w1, w2, score, topk, return_per_expert=True
+        )
+        torch_combined = torch_per_expert.sum(dim=1)
+
+        def run_runner(config):
+            runner = MoeRunner(MoeRunnerBackend.TRITON_KERNELS, config)
+            result = runner.run(dispatch_output, quant_info)
+            return result.hidden_states
+
+        # Combined output (no_combine=False)
+        non_fused_config = MoeRunnerConfig(inplace=False)
+        non_fused_output = run_runner(non_fused_config)
+        torch.testing.assert_close(
+            non_fused_output, torch_combined, rtol=rtol, atol=atol
+        )
+
+        # Per-expert output (no_combine=True)
+        non_fused_no_combine_config = MoeRunnerConfig(
+            inplace=False, no_combine=True, top_k=topk
+        )
+        non_fused_no_combine_output = run_runner(non_fused_no_combine_config)
+        torch.testing.assert_close(
+            non_fused_no_combine_output, torch_per_expert, rtol=rtol, atol=atol
+        )

     def test_various_configurations(self):
         m_values = [1, 32, 64, 256]
...
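
The reworked reference model returns the per-expert contributions before the top_k reduction, so one torch_naive_moe call can check both the combined output and the no_combine output against the runner. A toy check of the reduction it performs:

import torch

B, top_k, hidden = 2, 3, 4
out = torch.randn(B * top_k, hidden)  # per-(token, expert) expert outputs
topk_weight = torch.rand(B, top_k)  # router weights from the top-k step

weighted = out.view(B, -1, hidden) * topk_weight.view(B, -1, 1)  # per-expert
combined = weighted.sum(dim=1)  # what return_per_expert=False yields

assert weighted.shape == (B, top_k, hidden)
assert combined.shape == (B, hidden)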