Unverified Commit 62eff37b authored by Jonah Bernard, committed by GitHub

Refactor Triton-kernel MoE runner integration (#11795)

parent 47e12e08
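
This change replaces the ad-hoc `use_triton_kernels` forward paths in the quantization methods with a first-class `MoeRunnerBackend.TRITON_KERNELS` backend: a new `TritonKernelsRunnerCore` plus `TritonKernelsQuantInfo` carry weights, biases, and precision configs through the generic `MoeRunner` pipeline, and the `is_triton_kernel()` helpers are renamed to `is_triton_kernels()` to match the external `triton_kernels` package. A minimal sketch of the resulting call flow, using only names that appear in the diff below; the construction details (`w13`, `w2`, `dispatch_output`) are illustrative placeholders:

from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
from sglang.srt.layers.moe.moe_runner.triton_kernels import TritonKernelsQuantInfo

# Build a runner once per layer; AUTO resolves to TRITON_KERNELS when the layer opted in.
runner = MoeRunner(MoeRunnerBackend.TRITON_KERNELS, MoeRunnerConfig(inplace=False))
quant_info = TritonKernelsQuantInfo(w13_weight=w13, w2_weight=w2)  # weights assumed already prepared
combine_input = runner.run(dispatch_output, quant_info)  # dispatch_output: StandardDispatchOutput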
@@ -172,7 +172,7 @@ class FusedMoE(torch.nn.Module):
         self.reduce_results = reduce_results
         self.use_presharded_weights = use_presharded_weights
-        self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()
+        self.use_triton_kernels = get_moe_runner_backend().is_triton_kernels()
         self.quant_config = quant_config
         self.use_flashinfer_mxfp4_moe = get_moe_runner_backend().is_flashinfer_mxfp4()
...
@@ -47,7 +47,7 @@ def triton_kernel_moe_forward(
     from sglang.srt.layers.moe.topk import TopKOutputChecker

-    assert TopKOutputChecker.format_is_triton_kernel(topk_output)
+    assert TopKOutputChecker.format_is_triton_kernels(topk_output)

     routing_data, gather_idx, scatter_idx = topk_output
@@ -172,6 +172,7 @@ def triton_kernel_moe_with_bias_forward(
     b2: torch.Tensor,
     topk_output: TopKOutput,
     moe_runner_config: MoeRunnerConfig,
+    apply_router_weight_on_input: bool = False,
     use_fp8_w8a8: bool = False,
     per_channel_quant: bool = False,
     global_num_experts: int = -1,
@@ -184,7 +185,7 @@ def triton_kernel_moe_with_bias_forward(
 ) -> torch.Tensor:
     from sglang.srt.layers.moe.topk import TopKOutputChecker

-    assert TopKOutputChecker.format_is_triton_kernel(topk_output)
+    assert TopKOutputChecker.format_is_triton_kernels(topk_output)

     routing_data, gather_idx, scatter_idx = topk_output
@@ -201,6 +202,7 @@ def triton_kernel_moe_with_bias_forward(
         scatter_indx=scatter_idx,
         inplace=False,  # triton kernel doesn't support inplace
         activation=moe_runner_config.activation,
+        apply_router_weight_on_input=apply_router_weight_on_input,
         use_fp8_w8a8=use_fp8_w8a8,
         per_channel_quant=per_channel_quant,
         global_num_experts=global_num_experts,
@@ -228,6 +230,7 @@ def triton_kernel_fused_experts_with_bias(
     scatter_indx: ScatterIndx,
     inplace: bool = False,
     activation: str = "silu",
+    apply_router_weight_on_input: bool = False,
     use_fp8_w8a8: bool = False,
     per_channel_quant: bool = False,
     global_num_experts: int = -1,
@@ -296,7 +299,7 @@ def triton_kernel_fused_experts_with_bias(
         routing_data,
         gather_indx=gather_indx,
         precision_config=w1_pcg,
-        gammas=None,
+        gammas=routing_data.gate_scal if apply_router_weight_on_input else None,
         fused_activation=act,
     )
@@ -307,5 +310,5 @@ def triton_kernel_fused_experts_with_bias(
         routing_data,
         scatter_indx=scatter_indx,
         precision_config=w2_pcg,
-        gammas=routing_data.gate_scal,
+        gammas=None if apply_router_weight_on_input else routing_data.gate_scal,
     )
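
The two `gammas` changes above are two halves of one invariant: the router weights (`routing_data.gate_scal`) scale each expert contribution exactly once, on the first GEMM when `apply_router_weight_on_input` is set and on the final scatter GEMM otherwise. A minimal sketch of that invariant (illustrative, not library code):

# Sketch: routing weights must be applied exactly once per expert contribution.
gammas_gemm1 = routing_data.gate_scal if apply_router_weight_on_input else None
gammas_gemm2 = None if apply_router_weight_on_input else routing_data.gate_scal
assert (gammas_gemm1 is None) != (gammas_gemm2 is None)  # exactly one side is set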
@@ -11,6 +11,7 @@ from sglang.srt.layers.moe.moe_runner.base import (
 )
 from sglang.srt.layers.moe.moe_runner.deep_gemm import DeepGemmRunnerCore
 from sglang.srt.layers.moe.moe_runner.triton import TritonRunnerCore
+from sglang.srt.layers.moe.moe_runner.triton_kernels import TritonKernelsRunnerCore
 from sglang.srt.layers.moe.utils import get_moe_a2a_backend

 if TYPE_CHECKING:
@@ -31,6 +32,8 @@ class MoeRunner:
         if runner_backend.is_triton():
             self.runner_core = TritonRunnerCore(config)
+        elif runner_backend.is_triton_kernels():
+            self.runner_core = TritonKernelsRunnerCore(config)
         elif runner_backend.is_deep_gemm():
             self.runner_core = DeepGemmRunnerCore(config)
         else:
...
"""Triton kernels MoE runner backend skeleton."""
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional
import torch
from sglang.srt.layers.moe.moe_runner.base import (
MoeQuantInfo,
MoeRunnerConfig,
MoeRunnerCore,
RunnerInput,
RunnerOutput,
register_post_permute,
register_pre_permute,
)
from sglang.srt.layers.moe.utils import MoeRunnerBackend
if TYPE_CHECKING:
from triton_kernels.matmul_ogs import PrecisionConfig
from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx
from sglang.srt.layers.moe.token_dispatcher.standard import (
StandardCombineInput,
StandardDispatchOutput,
)
# ---------------------------------------------------------------------------
# Runner IO dataclasses
# ---------------------------------------------------------------------------
@dataclass
class TritonKernelsRunnerInput(RunnerInput):
"""Input bundle passed to the triton-kernels runner core."""
hidden_states: torch.Tensor
routing_data: "RoutingData"
gather_indx: "GatherIndx"
scatter_indx: "ScatterIndx"
@property
def runner_backend(self) -> MoeRunnerBackend:
return MoeRunnerBackend.TRITON_KERNELS
@dataclass
class TritonKernelsRunnerOutput(RunnerOutput):
"""Output bundle returned from the triton-kernels runner core."""
hidden_states: torch.Tensor
@property
def runner_backend(self) -> MoeRunnerBackend:
return MoeRunnerBackend.TRITON_KERNELS
@dataclass
class TritonKernelsQuantInfo(MoeQuantInfo):
"""Quantization payload consumed by the triton-kernels backend."""
w13_weight: torch.Tensor
w2_weight: torch.Tensor
w13_bias: Optional[torch.Tensor] = None
w2_bias: Optional[torch.Tensor] = None
w13_precision_config: Optional[PrecisionConfig] = None
w2_precision_config: Optional[PrecisionConfig] = None
global_num_experts: int = -1
# ---------------------------------------------------------------------------
# Runner core
# ---------------------------------------------------------------------------
class TritonKernelsRunnerCore(MoeRunnerCore):
"""Execute MoE experts via the external triton_kernels package."""
def run(
self,
runner_input: TritonKernelsRunnerInput,
quant_info: TritonKernelsQuantInfo,
running_state: dict,
) -> TritonKernelsRunnerOutput:
from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
triton_kernel_fused_experts,
triton_kernel_fused_experts_with_bias,
)
hidden_states = runner_input.hidden_states
common_kwargs = dict(
routing_data=runner_input.routing_data,
gather_indx=runner_input.gather_indx,
scatter_indx=None if self.config.no_combine else runner_input.scatter_indx,
inplace=False,
activation=self.config.activation,
apply_router_weight_on_input=self.config.apply_router_weight_on_input,
global_num_experts=quant_info.global_num_experts,
)
has_bias = quant_info.w13_bias is not None or quant_info.w2_bias is not None
if has_bias:
assert (
quant_info.w13_bias is not None and quant_info.w2_bias is not None
), "Bias execution requires both w13_bias and w2_bias"
output = triton_kernel_fused_experts_with_bias(
hidden_states=hidden_states,
w1=quant_info.w13_weight,
w1_pcg=quant_info.w13_precision_config,
b1=quant_info.w13_bias,
w2=quant_info.w2_weight,
w2_pcg=quant_info.w2_precision_config,
b2=quant_info.w2_bias,
gemm1_alpha=self.config.gemm1_alpha,
gemm1_clamp_limit=self.config.gemm1_clamp_limit,
**common_kwargs,
)
else:
output = triton_kernel_fused_experts(
hidden_states=hidden_states,
w1=quant_info.w13_weight,
w2=quant_info.w2_weight,
**common_kwargs,
)
if self.config.no_combine:
tokens = runner_input.hidden_states.shape[0]
hidden = runner_input.hidden_states.shape[-1]
total_rows = output.shape[0]
top_k = total_rows // tokens
output = output.view(tokens, top_k, hidden)
return TritonKernelsRunnerOutput(hidden_states=output)
@property
def runner_backend(self) -> MoeRunnerBackend:
return MoeRunnerBackend.TRITON_KERNELS
# ---------------------------------------------------------------------------
# Permute / fused hooks
# ---------------------------------------------------------------------------
@register_pre_permute("standard", "triton_kernel")
def pre_permute_standard_to_triton_kernels(
dispatch_output: "StandardDispatchOutput",
quant_info: TritonKernelsQuantInfo,
runner_config: MoeRunnerConfig,
running_state: dict,
) -> TritonKernelsRunnerInput:
from sglang.srt.layers.moe.topk import TopKOutputChecker
hidden_states = dispatch_output.hidden_states
topk_output = dispatch_output.topk_output
assert TopKOutputChecker.format_is_triton_kernels(
topk_output
), "Triton-kernel runner expects TritonKernelTopKOutput"
routing_data, gather_indx, scatter_indx = topk_output
return TritonKernelsRunnerInput(
hidden_states=hidden_states,
routing_data=routing_data,
gather_indx=gather_indx,
scatter_indx=scatter_indx,
)
@register_post_permute("triton_kernel", "standard")
def post_permute_triton_kernels_to_standard(
runner_output: TritonKernelsRunnerOutput,
quant_info: TritonKernelsQuantInfo,
runner_config: MoeRunnerConfig,
running_state: dict,
) -> StandardCombineInput:
from sglang.srt.layers.moe.token_dispatcher.standard import StandardCombineInput
hidden_states = runner_output.hidden_states
if (
runner_config.routed_scaling_factor is not None
and runner_config.routed_scaling_factor != 1.0
and not runner_config.no_combine
):
hidden_states.mul_(runner_config.routed_scaling_factor)
return StandardCombineInput(hidden_states=hidden_states)
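
One subtlety in `run` above: when `no_combine` is set, the scatter step is skipped, so the fused kernel returns one row per (token, expert) pair and `top_k` is recovered from the row count. A small shape sketch with illustrative values:

import torch

tokens, hidden, top_k = 4, 8, 2
output = torch.randn(tokens * top_k, hidden)      # kernel output without the combine step
per_expert = output.view(tokens, top_k, hidden)   # reshaped to (tokens, top_k, hidden)
assert per_expert.shape == (4, 2, 8)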
@@ -28,6 +28,12 @@ class DispatchOutputChecker:
     ) -> TypeGuard[StandardDispatchOutput]:
         return dispatch_output.format.is_standard()

+    @staticmethod
+    def format_is_triton_kernels(
+        dispatch_output: DispatchOutput,
+    ) -> TypeGuard[StandardDispatchOutput]:
+        return dispatch_output.format.is_standard()
+
     @staticmethod
     def format_is_deepep_normal(
         dispatch_output: DispatchOutput,
...
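
Note that the dispatch-side checker narrows to `StandardDispatchOutput`: the triton-kernels runner consumes ordinary standard dispatch output, and only the TopK output format differs. A sketch of that relationship, assuming a `dispatch_output` instance and the checker from this file:

# Triton-kernels dispatch is standard dispatch under a second name.
if DispatchOutputChecker.format_is_triton_kernels(dispatch_output):
    assert dispatch_output.format.is_standard()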
@@ -88,7 +88,7 @@ class StandardDispatcher(BaseDispatcher):
             topk_output = topk_output._replace(
                 topk_ids=self.local_expert_mapping[topk_output.topk_ids]
             )
-        elif TopKOutputChecker.format_is_triton_kernel(topk_output):
+        elif TopKOutputChecker.format_is_triton_kernels(topk_output):
             raise NotImplementedError()
         return StandardDispatchOutput(
...
@@ -111,10 +111,10 @@ class TopKOutputChecker:
         return topk_output.format.is_standard()

     @staticmethod
-    def format_is_triton_kernel(
+    def format_is_triton_kernels(
         topk_output: TopKOutput,
     ) -> TypeGuard[TritonKernelTopKOutput]:
-        return topk_output.format.is_triton_kernel()
+        return topk_output.format.is_triton_kernels()

     @staticmethod
     def format_is_bypassed(topk_output: TopKOutput) -> TypeGuard[BypassedTopKOutput]:
@@ -129,7 +129,7 @@ class TopKOutputFormat(Enum):
     def is_standard(self) -> bool:
         return self == TopKOutputFormat.STANDARD

-    def is_triton_kernel(self) -> bool:
+    def is_triton_kernels(self) -> bool:
         return self == TopKOutputFormat.TRITON_KERNEL

     def is_bypassed(self) -> bool:
@@ -254,7 +254,7 @@ class TopK(CustomOp):
     ) -> TopKOutput:
         if self.topk_config.output_format is not None:
             output_format = self.topk_config.output_format
-        elif get_moe_runner_backend().is_triton_kernel():
+        elif get_moe_runner_backend().is_triton_kernels():
             output_format = TopKOutputFormat.TRITON_KERNEL
         elif (
             should_use_flashinfer_trtllm_moe()
...
@@ -51,7 +51,7 @@ class MoeRunnerBackend(Enum):
     AUTO = "auto"
     DEEP_GEMM = "deep_gemm"
     TRITON = "triton"
-    TRITON_KERNEL = "triton_kernel"
+    TRITON_KERNELS = "triton_kernel"
     FLASHINFER_TRTLLM = "flashinfer_trtllm"
     FLASHINFER_CUTLASS = "flashinfer_cutlass"
     FLASHINFER_MXFP4 = "flashinfer_mxfp4"
@@ -67,8 +67,8 @@ class MoeRunnerBackend(Enum):
     def is_triton(self):
         return self == MoeRunnerBackend.TRITON

-    def is_triton_kernel(self):
-        return self == MoeRunnerBackend.TRITON_KERNEL
+    def is_triton_kernels(self):
+        return self == MoeRunnerBackend.TRITON_KERNELS

     def is_flashinfer_trtllm(self):
         return self == MoeRunnerBackend.FLASHINFER_TRTLLM
...
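
The enum member is renamed, but its string value stays `"triton_kernel"`, so any existing configuration that selects the backend by value keeps working. A sketch of that property (lookup-by-value is standard `Enum` behavior):

assert MoeRunnerBackend.TRITON_KERNELS.value == "triton_kernel"
assert MoeRunnerBackend("triton_kernel") is MoeRunnerBackend.TRITON_KERNELS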
@@ -261,26 +261,13 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         self.prefix = prefix
         self.topk_indices_dtype = None
-        self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()
+        self.use_triton_kernels = get_moe_runner_backend().is_triton_kernels()
         self.with_bias = False
         self.use_flashinfer = get_moe_runner_backend().is_flashinfer_mxfp4()
         self.flashinfer_mxfp4_moe_precision = (
             get_global_server_args().flashinfer_mxfp4_moe_precision
         )

-        self.triton_kernel_moe_forward = None
-        self.triton_kernel_moe_with_bias_forward = None
-        if torch.cuda.is_available() and has_triton_kernels:
-            from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
-                triton_kernel_moe_forward as _tk_forward,
-            )
-            from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
-                triton_kernel_moe_with_bias_forward as _tk_with_bias_forward,
-            )
-
-            self.triton_kernel_moe_forward = _tk_forward
-            self.triton_kernel_moe_with_bias_forward = _tk_with_bias_forward
-
     def create_weights(
         self,
         layer: torch.nn.Module,
@@ -600,7 +587,14 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
     ):
         self.moe_runner_config = moe_runner_config
-        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
+        backend = get_moe_runner_backend()
+        if backend.is_auto():
+            backend = (
+                MoeRunnerBackend.TRITON_KERNELS
+                if self.use_triton_kernels
+                else MoeRunnerBackend.TRITON
+            )
+        self.runner = MoeRunner(backend, moe_runner_config)

     def apply(
         self,
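
This AUTO-resolution pattern (repeated below in `UnquantizedFusedMoEMethod.create_moe_runner`) is what makes the explicit forward paths removable: the backend choice moves from per-call branching into runner construction. A minimal sketch of the resolution, assuming `AUTO` is the default server setting and `use_triton_kernels` stands in for the layer's flag:

backend = get_moe_runner_backend()  # e.g. MoeRunnerBackend.AUTO by default
if backend.is_auto():
    # Fall back to triton-kernels only when the layer opted in.
    backend = MoeRunnerBackend.TRITON_KERNELS if use_triton_kernels else MoeRunnerBackend.TRITON
runner = MoeRunner(backend, moe_runner_config)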
@@ -677,31 +671,31 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
             )[0]
             return StandardCombineInput(hidden_states=trtllm_gen_output)

-        if self.use_triton_kernels:
+        backend = self.runner.runner_backend
+        if backend.is_triton_kernels():
+            from sglang.srt.layers.moe.moe_runner.triton_kernels import (
+                TritonKernelsQuantInfo,
+            )
+
             assert (
                 layer.moe_ep_size == 1
             ), "Expert parallel is not supported when using triton kernels"
-            if self.with_bias:
-                output = self.triton_kernel_moe_with_bias_forward(
-                    hidden_states=x,
-                    w1=self.w13_weight_triton_tensor,
-                    w1_pcg=self.w13_precision_config,
-                    w2=self.w2_weight_triton_tensor,
-                    w2_pcg=self.w2_precision_config,
-                    b1=layer.w13_weight_bias,
-                    b2=layer.w2_weight_bias,
-                    topk_output=topk_output,
-                    moe_runner_config=moe_runner_config,
-                )
-            else:
-                output = self.triton_kernel_moe_forward(
-                    hidden_states=x,
-                    w1=layer.w13_weight,
-                    w2=layer.w2_weight,
-                    topk_output=topk_output,
-                    moe_runner_config=moe_runner_config,
-                )
-            return StandardCombineInput(hidden_states=output)
+            quant_info = TritonKernelsQuantInfo(
+                w13_weight=(
+                    self.w13_weight_triton_tensor
+                    if self.w13_weight_triton_tensor is not None
+                    else layer.w13_weight
+                ),
+                w2_weight=(
+                    self.w2_weight_triton_tensor
+                    if self.w2_weight_triton_tensor is not None
+                    else layer.w2_weight
+                ),
+                w13_bias=getattr(layer, "w13_weight_bias", None),
+                w2_bias=getattr(layer, "w2_weight_bias", None),
+                w13_precision_config=getattr(self, "w13_precision_config", None),
+                w2_precision_config=getattr(self, "w2_precision_config", None),
+            )
         else:
             quant_info = TritonMoeQuantInfo(
                 w13_weight=layer.w13_weight,
@@ -709,7 +703,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                 b13=getattr(layer, "w13_weight_bias", None),
                 b2=getattr(layer, "w2_weight_bias", None),
             )
-            return self.runner.run(dispatch_output, quant_info)
+        return self.runner.run(dispatch_output, quant_info)

 class Mxfp4DynamicQuantMoEMethod(FusedMoEMethodBase):
...
@@ -8,7 +8,12 @@ from torch.nn.parameter import Parameter

 from sglang.srt.custom_op import CustomOp
 from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
-from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
+from sglang.srt.layers.moe import (
+    MoeRunner,
+    MoeRunnerBackend,
+    MoeRunnerConfig,
+    get_moe_runner_backend,
+)
 from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
 from sglang.srt.layers.quantization.base_config import (
     FusedMoEMethodBase,
@@ -115,13 +120,15 @@ class UnquantizedLinearMethod(LinearMethodBase):
         x: torch.Tensor,
         bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-
         if use_intel_amx_backend(layer):
             x_shapes = x.shape
             if len(x_shapes) == 3:
                 x = x.view(-1, x.shape[-1])
             output = torch.ops.sgl_kernel.weight_packed_linear(
-                x, layer.weight, bias, True  # is_vnni
+                x,
+                layer.weight,
+                bias,
+                True,  # is_vnni
             )
             if len(x_shapes) == 3:
                 output = output.view(x_shapes[0], x_shapes[1], -1)
@@ -138,19 +145,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         self.use_triton_kernels = use_triton_kernels
         self.with_bias = False

-        self.triton_kernel_moe_forward = None
-        self.triton_kernel_moe_with_bias_forward = None
-        if torch.cuda.is_available() and use_triton_kernels:
-            from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
-                triton_kernel_moe_forward as _tk_forward,
-            )
-            from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
-                triton_kernel_moe_with_bias_forward as _tk_with_bias_forward,
-            )
-
-            self.triton_kernel_moe_forward = _tk_forward
-            self.triton_kernel_moe_with_bias_forward = _tk_with_bias_forward
-
     def create_weights(
         self,
         layer: torch.nn.Module,
@@ -231,14 +225,20 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
     ):
         self.moe_runner_config = moe_runner_config
-        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
+        backend = get_moe_runner_backend()
+        if backend.is_auto():
+            backend = (
+                MoeRunnerBackend.TRITON_KERNELS
+                if self.use_triton_kernels
+                else MoeRunnerBackend.TRITON
+            )
+        self.runner = MoeRunner(backend, moe_runner_config)

     def apply(
         self,
         layer: torch.nn.Module,
         dispatch_output: StandardDispatchOutput,
     ) -> CombineInput:
-
         return self.forward(
             layer=layer,
             dispatch_output=dispatch_output,
@@ -249,7 +249,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         layer: torch.nn.Module,
         dispatch_output: StandardDispatchOutput,
     ) -> CombineInput:
-
         from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput

         x = dispatch_output.hidden_states
@@ -257,30 +256,19 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         moe_runner_config = self.moe_runner_config

-        if self.use_triton_kernels:
-            if self.with_bias:
-                assert self.triton_kernel_moe_with_bias_forward is not None
-                output = self.triton_kernel_moe_with_bias_forward(
-                    hidden_states=x,
-                    w1=layer.w13_weight,
-                    w2=layer.w2_weight,
-                    b1=layer.w13_weight_bias,
-                    b2=layer.w2_weight_bias,
-                    topk_output=topk_output,
-                    moe_runner_config=moe_runner_config,
-                    w1_pcg=None,
-                    w2_pcg=None,
-                )
-            else:
-                assert self.triton_kernel_moe_forward is not None
-                output = self.triton_kernel_moe_forward(
-                    hidden_states=x,
-                    w1=layer.w13_weight,
-                    w2=layer.w2_weight,
-                    topk_output=topk_output,
-                    moe_runner_config=moe_runner_config,
-                )
-            return StandardCombineInput(hidden_states=output)
+        backend = self.runner.runner_backend
+        if backend.is_triton_kernels():
+            from sglang.srt.layers.moe.moe_runner.triton_kernels import (
+                TritonKernelsQuantInfo,
+            )
+
+            quant_info = TritonKernelsQuantInfo(
+                w13_weight=layer.w13_weight,
+                w2_weight=layer.w2_weight,
+                w13_bias=getattr(layer, "w13_weight_bias", None),
+                w2_bias=getattr(layer, "w2_weight_bias", None),
+            )
+            return self.runner.run(dispatch_output, quant_info)
         else:
             if _use_aiter:
                 assert not moe_runner_config.no_combine, "unsupported"
@@ -311,7 +299,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             )
             return StandardCombineInput(hidden_states=output)
         else:
-
             quant_info = TritonMoeQuantInfo(
                 w13_weight=layer.w13_weight,
                 w2_weight=layer.w2_weight,
@@ -325,7 +312,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         layer: torch.nn.Module,
         dispatch_output: StandardDispatchOutput,
     ) -> CombineInput:
-
         from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput

         x = dispatch_output.hidden_states
@@ -380,7 +366,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         layer: torch.nn.Module,
         dispatch_output: StandardDispatchOutput,
     ) -> CombineInput:
-
         import torch_npu

         from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
...
@@ -5,11 +5,10 @@ import torch.nn.functional as F
 from tqdm import tqdm

 from sglang.srt.layers.activation import SiluAndMul
-from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
-    triton_kernel_moe_forward,
-)
-from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
-from sglang.srt.layers.moe.topk import TopK
+from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
+from sglang.srt.layers.moe.moe_runner.triton_kernels import TritonKernelsQuantInfo
+from sglang.srt.layers.moe.token_dispatcher.standard import StandardDispatchOutput
+from sglang.srt.layers.moe.topk import TopK, TopKOutputFormat
 from sglang.test.test_utils import CustomTestCase
@@ -55,6 +54,7 @@ class TestFusedMOE(CustomTestCase):
         w2,
         score,
         topk,
+        return_per_expert: bool = False,
     ):
         B, D = a.shape
         a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
@@ -78,9 +78,14 @@ class TestFusedMOE(CustomTestCase):
                 a[mask] @ w1_compute[i].transpose(0, 1)
             ) @ w2_compute[i].transpose(0, 1)

-        return (
-            out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
-        ).sum(dim=1)
+        weighted = out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(
+            out.dtype
+        )
+        if return_per_expert:
+            return weighted
+        return weighted.sum(dim=1)

     def _test_case(self, m, n, k, e, topk, dtype):
         rtol, atol = self.get_tolerance(dtype)
@@ -99,20 +104,43 @@ class TestFusedMOE(CustomTestCase):
             renormalize=False,
             use_grouped_topk=False,
         )
-        topk_op.use_triton_kernels = True
+        topk_op.topk_config.output_format = TopKOutputFormat.TRITON_KERNEL
         triton_topk_output = topk_op.forward_cuda(
             hidden_states=a,
             router_logits=score,
         )
-        moe_runner_config = MoeRunnerConfig(
-            inplace=False,
-        )
-        triton_output = triton_kernel_moe_forward(
-            a, w1_tri, w2_tri, triton_topk_output, moe_runner_config
-        )
-        torch_output = self.torch_naive_moe(a, w1, w2, score, topk)
-        torch.testing.assert_close(triton_output, torch_output, rtol=rtol, atol=atol)
+        quant_info = TritonKernelsQuantInfo(w13_weight=w1_tri, w2_weight=w2_tri)
+        dispatch_output = StandardDispatchOutput(
+            hidden_states=a, topk_output=triton_topk_output
+        )
+
+        torch_per_expert = self.torch_naive_moe(
+            a, w1, w2, score, topk, return_per_expert=True
+        )
+        torch_combined = torch_per_expert.sum(dim=1)
+
+        def run_runner(config):
+            runner = MoeRunner(MoeRunnerBackend.TRITON_KERNELS, config)
+            result = runner.run(dispatch_output, quant_info)
+            return result.hidden_states
+
+        # Combined output (no_combine=False)
+        non_fused_config = MoeRunnerConfig(inplace=False)
+        non_fused_output = run_runner(non_fused_config)
+        torch.testing.assert_close(
+            non_fused_output, torch_combined, rtol=rtol, atol=atol
+        )
+
+        # Per-expert output (no_combine=True)
+        non_fused_no_combine_config = MoeRunnerConfig(
+            inplace=False, no_combine=True, top_k=topk
+        )
+        non_fused_no_combine_output = run_runner(non_fused_no_combine_config)
+        torch.testing.assert_close(
+            non_fused_no_combine_output, torch_per_expert, rtol=rtol, atol=atol
+        )

     def test_various_configurations(self):
         m_values = [1, 32, 64, 256]
...
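
Taken together, the two assertions in `_test_case` pin down one identity: the combined runner output equals the per-expert runner output summed over the top-k axis. If desired, that could be checked directly as well (a sketch using the local names from `_test_case`; tolerances as above):

torch.testing.assert_close(
    non_fused_no_combine_output.sum(dim=1), non_fused_output, rtol=rtol, atol=atol
)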