Unverified Commit b3eac168 authored by Yuan Luo, committed by GitHub

Support triton kernels v3.4.0 for fused_moe (#8258)


Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
Co-authored-by: Cheng Wan <cwan@x.ai>
Co-authored-by: Cheng Wan <54331508+ch-wan@users.noreply.github.com>
parent 10ee8955
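In short, the Triton MoE path no longer receives raw router logits (gating_output, topk, renormalize); the TopK layer now returns a single topk_output object (StandardTopKOutput or TritonKernelTopKOutput) and downstream code dispatches on its format. A minimal sketch of the resulting calling convention, assuming the types introduced below; the dispatch helper and its standard_path callback are hypothetical, only the field and attribute names come from the diff:

import torch

# Hypothetical dispatch helper illustrating the new TopKOutput-based flow.
def moe_forward_dispatch(x: torch.Tensor, topk_output, layer, moe_method, standard_path):
    if topk_output.format.is_triton_kernel():
        # Triton path: routing_data/gather_indx/scatter_indx were already computed
        # by triton_kernels.routing() inside the TopK layer's forward.
        return moe_method.triton_kernel_moe_forward(
            hidden_states=x,
            w1=layer.w13_weight,
            w2=layer.w2_weight,
            topk_output=topk_output,
        )
    # Standard path is unchanged: the output still unpacks like the old tuple.
    topk_weights, topk_ids, router_logits = topk_output
    return standard_path(x, topk_weights, topk_ids, router_logits)
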
# Adapted from https://github.com/vllm-project/vllm/pull/18595/files#diff-f426a6de78c82ffec568eff6811bfbf0043dab5f87f1a8c0cffdbdcb8a81e035
from typing import Optional
from __future__ import annotations
from typing import TYPE_CHECKING, Optional
import torch
from sgl_kernel import gelu_and_mul, silu_and_mul
from triton_kernels.matmul_ogs import matmul_ogs
from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx, routing
from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx
from sglang.srt.utils import direct_register_custom_op
if TYPE_CHECKING:
from sglang.srt.layers.moe.topk import TopKOutput
def triton_kernel_moe_forward(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
topk_output: TopKOutput,
inplace: bool = False,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
@@ -30,9 +34,8 @@ def triton_kernel_moe_forward(
block_shape: Optional[list[int]] = None,
) -> torch.Tensor:
if not renormalize:
gating_output = torch.softmax(gating_output, dim=-1)
routing_data, gather_idx, scatter_idx = routing(gating_output, topk, renormalize)
assert topk_output.format.is_triton_kernel()
routing_data, gather_idx, scatter_idx = topk_output
return triton_kernel_fused_experts(
hidden_states,
......
@@ -15,7 +15,8 @@
from __future__ import annotations
import math
from typing import Callable, NamedTuple, Optional
from enum import Enum, auto
from typing import Callable, NamedTuple, Optional, Protocol, runtime_checkable
import torch
import torch.nn.functional as F
@@ -27,6 +28,7 @@ from sglang.srt.eplb.expert_location_dispatch import (
ExpertLocationDispatchInfo,
topk_ids_logical_to_physical,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.utils import (
cpu_has_amx_support,
get_bool_env_var,
@@ -37,6 +39,12 @@ from sglang.srt.utils import (
is_npu,
)
try:
from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx, routing
except ImportError:
pass
_is_cuda = is_cuda()
_is_hip = is_hip()
_is_cpu = is_cpu()
@@ -58,15 +66,58 @@ if _is_npu:
import torch_npu
class TopKOutput(NamedTuple):
# -------------------------------- TopKOutput ---------------------------------------
class TopKOutputFormat(Enum):
STANDARD = auto()
TRITON_KERNEL = auto()
def is_standard(self) -> bool:
return self == TopKOutputFormat.STANDARD
def is_triton_kernel(self) -> bool:
return self == TopKOutputFormat.TRITON_KERNEL
@runtime_checkable
class TopKOutput(Protocol):
"""Protocol for top-k outputs in different formats."""
@property
def format(self) -> TopKOutputFormat:
"""The format of the output."""
...
class StandardTopKOutput(NamedTuple):
"""Standard top-k output format."""
topk_weights: torch.Tensor
topk_ids: torch.Tensor
router_logits: torch.Tensor
@property
def format(self) -> TopKOutputFormat:
return TopKOutputFormat.STANDARD
class TopK(CustomOp):
# TODO(ch-wan): support triton_kernels
class TritonKernelTopKOutput(NamedTuple):
"""Triton kernel top-k output format."""
routing_data: RoutingData
gather_indx: GatherIndx
scatter_indx: ScatterIndx
@property
def format(self) -> TopKOutputFormat:
return TopKOutputFormat.TRITON_KERNEL
# -------------------------------- TopK ---------------------------------------
class TopK(CustomOp):
def __init__(
self,
@@ -97,6 +148,8 @@ class TopK(CustomOp):
self.correction_bias = correction_bias
self.routed_scaling_factor = routed_scaling_factor
self.use_triton_kernels = global_server_args_dict["enable_triton_kernel_moe"]
def forward_native(
self,
hidden_states: torch.Tensor,
@@ -131,23 +184,29 @@ class TopK(CustomOp):
num_token_non_padded: Optional[torch.Tensor] = None,
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
) -> TopKOutput:
torch_native = False
return select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=self.top_k,
use_grouped_topk=self.use_grouped_topk,
renormalize=self.renormalize,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
num_fused_shared_experts=self.num_fused_shared_experts,
custom_routing_function=self.custom_routing_function,
correction_bias=self.correction_bias,
torch_native=torch_native,
routed_scaling_factor=self.routed_scaling_factor,
num_token_non_padded=num_token_non_padded,
expert_location_dispatch_info=expert_location_dispatch_info,
)
if self.use_triton_kernels:
routing_data, gather_idx, scatter_idx = routing(
router_logits, self.top_k, self.renormalize
)
return TritonKernelTopKOutput(routing_data, gather_idx, scatter_idx)
else:
torch_native = False
return select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=self.top_k,
use_grouped_topk=self.use_grouped_topk,
renormalize=self.renormalize,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
num_fused_shared_experts=self.num_fused_shared_experts,
custom_routing_function=self.custom_routing_function,
correction_bias=self.correction_bias,
torch_native=torch_native,
routed_scaling_factor=self.routed_scaling_factor,
num_token_non_padded=num_token_non_padded,
expert_location_dispatch_info=expert_location_dispatch_info,
)
def forward_cpu(
self,
@@ -217,6 +276,9 @@ class TopK(CustomOp):
)
# ------------------------------- TopK implementation -------------------------------------
def fused_topk_torch_native(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
@@ -680,4 +742,4 @@ def select_experts(
get_global_expert_distribution_recorder().on_select_experts(topk_ids=topk_ids)
return TopKOutput(topk_weights, topk_ids, router_logits)
return StandardTopKOutput(topk_weights, topk_ids, router_logits)
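Since TopKOutput is a runtime_checkable Protocol keyed on the format property, both NamedTuple variants above satisfy it structurally. A small illustrative check, assuming these classes are imported from sglang.srt.layers.moe.topk (dummy tensors, not part of the PR):

import torch

from sglang.srt.layers.moe.topk import StandardTopKOutput, TopKOutput

out = StandardTopKOutput(
    topk_weights=torch.rand(4, 2),
    topk_ids=torch.zeros(4, 2, dtype=torch.int64),
    router_logits=torch.rand(4, 8),
)
assert isinstance(out, TopKOutput)       # structural check enabled by @runtime_checkable
assert out.format.is_standard()          # the dispatch key used by the MoE backends
topk_weights, topk_ids, router_logits = out  # still unpacks like the previous TopKOutput tuple
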
@@ -130,6 +130,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
super().__init__()
self.use_triton_kernels = use_triton_kernels
self.triton_kernel_moe_forward = None
if torch.cuda.is_available() and has_triton_kernels:
from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
triton_kernel_moe_forward as _tk_forward,
)
self.triton_kernel_moe_forward = _tk_forward
def create_weights(
self,
layer: torch.nn.Module,
@@ -229,16 +237,12 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
) -> torch.Tensor:
if self.use_triton_kernels:
# TODO(ch-wan): re-enable the Triton kernel
raise NotImplementedError("The Triton kernel is temporarily disabled.")
# return triton_kernel_moe_forward(
# hidden_states=x,
# w1=layer.w13_weight,
# w2=layer.w2_weight,
# gating_output=router_logits,
# topk=top_k,
# renormalize=renormalize,
# )
return self.triton_kernel_moe_forward(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_output=topk_output,
)
else:
if _use_aiter:
assert not no_combine, "unsupported"
......
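Throughout the change, the triton_kernels dependency stays optional: topk.py wraps its import in try/except, and UnquantizedFusedMoEMethod only binds triton_kernel_moe_forward when CUDA and the kernels are available. A condensed sketch of that guard; the import probe for has_triton_kernels is an assumption here, the real flag is defined elsewhere in layer.py:

import torch

# Assumed probe for the optional triton_kernels package (v3.4.0) targeted by this PR.
try:
    import triton_kernels  # noqa: F401
    has_triton_kernels = True
except ImportError:
    has_triton_kernels = False

triton_kernel_moe_forward = None
if torch.cuda.is_available() and has_triton_kernels:
    # Import lazily so environments without triton_kernels can still load the module.
    from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
        triton_kernel_moe_forward,
    )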