"git@developer.sourcefind.cn:change/sglang.git" did not exist on "3d40794fcf3678a713c0054ae9d59dafab979bcf"
Unverified commit b3eac168, authored by Yuan Luo, committed by GitHub

Support triton kernels v3.4.0 for fused_moe (#8258)


Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
Co-authored-by: Cheng Wan <cwan@x.ai>
Co-authored-by: Cheng Wan <54331508+ch-wan@users.noreply.github.com>
parent 10ee8955
 # Adapted from https://github.com/vllm-project/vllm/pull/18595/files#diff-f426a6de78c82ffec568eff6811bfbf0043dab5f87f1a8c0cffdbdcb8a81e035
-from typing import Optional
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Optional
 
 import torch
 from sgl_kernel import gelu_and_mul, silu_and_mul
 from triton_kernels.matmul_ogs import matmul_ogs
-from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx, routing
+from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx
 
 from sglang.srt.utils import direct_register_custom_op
 
+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.topk import TopKOutput
+
 
 def triton_kernel_moe_forward(
     hidden_states: torch.Tensor,
     w1: torch.Tensor,
     w2: torch.Tensor,
-    gating_output: torch.Tensor,
-    topk: int,
-    renormalize: bool,
+    topk_output: TopKOutput,
     inplace: bool = False,
     activation: str = "silu",
     apply_router_weight_on_input: bool = False,
@@ -30,9 +34,8 @@ def triton_kernel_moe_forward(
     block_shape: Optional[list[int]] = None,
 ) -> torch.Tensor:
 
-    if not renormalize:
-        gating_output = torch.softmax(gating_output, dim=-1)
-    routing_data, gather_idx, scatter_idx = routing(gating_output, topk, renormalize)
+    assert topk_output.format.is_triton_kernel()
+    routing_data, gather_idx, scatter_idx = topk_output
 
     return triton_kernel_fused_experts(
         hidden_states,
...
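Note how the kernel entry point above no longer recomputes softmax and routing from raw gating logits; it receives a pre-routed topk_output and simply destructures it. A tiny, self-contained illustration of that destructuring follows (field values are plain strings here; in sglang they are RoutingData, GatherIndx, and ScatterIndx objects from triton_kernels).

from typing import NamedTuple


class TritonKernelTopKOutput(NamedTuple):
    routing_data: object   # stand-in for triton_kernels RoutingData
    gather_indx: object    # stand-in for GatherIndx
    scatter_indx: object   # stand-in for ScatterIndx


topk_output = TritonKernelTopKOutput("routing", "gather", "scatter")
# A NamedTuple unpacks positionally like a plain tuple, so the kernel body
# can destructure it in one line, as in the diff above.
routing_data, gather_idx, scatter_idx = topk_output
print(routing_data, gather_idx, scatter_idx)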
@@ -15,7 +15,8 @@
 from __future__ import annotations
 
 import math
-from typing import Callable, NamedTuple, Optional
+from enum import Enum, auto
+from typing import Callable, NamedTuple, Optional, Protocol, runtime_checkable
 
 import torch
 import torch.nn.functional as F
@@ -27,6 +28,7 @@ from sglang.srt.eplb.expert_location_dispatch import (
     ExpertLocationDispatchInfo,
     topk_ids_logical_to_physical,
 )
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.utils import (
     cpu_has_amx_support,
     get_bool_env_var,
@@ -37,6 +39,12 @@ from sglang.srt.utils import (
     is_npu,
 )
 
+try:
+    from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx, routing
+except ImportError:
+    pass
+
 _is_cuda = is_cuda()
 _is_hip = is_hip()
 _is_cpu = is_cpu()
@@ -58,15 +66,58 @@ if _is_npu:
     import torch_npu
 
 
-class TopKOutput(NamedTuple):
+# -------------------------------- TopKOutput ---------------------------------------
+
+
+class TopKOutputFormat(Enum):
+    STANDARD = auto()
+    TRITON_KERNEL = auto()
+
+    def is_standard(self) -> bool:
+        return self == TopKOutputFormat.STANDARD
+
+    def is_triton_kernel(self) -> bool:
+        return self == TopKOutputFormat.TRITON_KERNEL
+
+
+@runtime_checkable
+class TopKOutput(Protocol):
+    """Protocol for top-k outputs in different formats."""
+
+    @property
+    def format(self) -> TopKOutputFormat:
+        """The format of the output."""
+        ...
+
+
+class StandardTopKOutput(NamedTuple):
+    """Standard top-k output format."""
+
     topk_weights: torch.Tensor
     topk_ids: torch.Tensor
     router_logits: torch.Tensor
 
+    @property
+    def format(self) -> TopKOutputFormat:
+        return TopKOutputFormat.STANDARD
+
+
+class TritonKernelTopKOutput(NamedTuple):
+    """Triton kernel top-k output format."""
+
+    routing_data: RoutingData
+    gather_indx: GatherIndx
+    scatter_indx: ScatterIndx
+
+    @property
+    def format(self) -> TopKOutputFormat:
+        return TopKOutputFormat.TRITON_KERNEL
+
+
+# -------------------------------- TopK ---------------------------------------
+
 
 class TopK(CustomOp):
-    # TODO(ch-wan): support triton_kernels
 
     def __init__(
         self,
@@ -97,6 +148,8 @@ class TopK(CustomOp):
         self.correction_bias = correction_bias
         self.routed_scaling_factor = routed_scaling_factor
 
+        self.use_triton_kernels = global_server_args_dict["enable_triton_kernel_moe"]
+
     def forward_native(
         self,
         hidden_states: torch.Tensor,
@@ -131,23 +184,29 @@ class TopK(CustomOp):
         num_token_non_padded: Optional[torch.Tensor] = None,
         expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
     ) -> TopKOutput:
-        torch_native = False
-        return select_experts(
-            hidden_states=hidden_states,
-            router_logits=router_logits,
-            top_k=self.top_k,
-            use_grouped_topk=self.use_grouped_topk,
-            renormalize=self.renormalize,
-            topk_group=self.topk_group,
-            num_expert_group=self.num_expert_group,
-            num_fused_shared_experts=self.num_fused_shared_experts,
-            custom_routing_function=self.custom_routing_function,
-            correction_bias=self.correction_bias,
-            torch_native=torch_native,
-            routed_scaling_factor=self.routed_scaling_factor,
-            num_token_non_padded=num_token_non_padded,
-            expert_location_dispatch_info=expert_location_dispatch_info,
-        )
+        if self.use_triton_kernels:
+            routing_data, gather_idx, scatter_idx = routing(
+                router_logits, self.top_k, self.renormalize
+            )
+            return TritonKernelTopKOutput(routing_data, gather_idx, scatter_idx)
+        else:
+            torch_native = False
+            return select_experts(
+                hidden_states=hidden_states,
+                router_logits=router_logits,
+                top_k=self.top_k,
+                use_grouped_topk=self.use_grouped_topk,
+                renormalize=self.renormalize,
+                topk_group=self.topk_group,
+                num_expert_group=self.num_expert_group,
+                num_fused_shared_experts=self.num_fused_shared_experts,
+                custom_routing_function=self.custom_routing_function,
+                correction_bias=self.correction_bias,
+                torch_native=torch_native,
+                routed_scaling_factor=self.routed_scaling_factor,
+                num_token_non_padded=num_token_non_padded,
+                expert_location_dispatch_info=expert_location_dispatch_info,
+            )
 
     def forward_cpu(
         self,
@@ -217,6 +276,9 @@ class TopK(CustomOp):
     )
 
 
+# ------------------------------- TopK implementation -------------------------------------
+
+
 def fused_topk_torch_native(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
@@ -680,4 +742,4 @@ def select_experts(
     get_global_expert_distribution_recorder().on_select_experts(topk_ids=topk_ids)
 
-    return TopKOutput(topk_weights, topk_ids, router_logits)
+    return StandardTopKOutput(topk_weights, topk_ids, router_logits)
...
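With this change TopKOutput is a runtime-checkable Protocol rather than a single NamedTuple, so the standard and Triton routing paths can return different concrete types while consumers branch on a shared format tag. Below is a small, self-contained sketch of that dispatch pattern; it mirrors the classes added above but uses plain Python lists instead of tensors, and run_experts is a hypothetical consumer, not an sglang function.

from enum import Enum, auto
from typing import NamedTuple, Protocol, runtime_checkable


class TopKOutputFormat(Enum):
    STANDARD = auto()
    TRITON_KERNEL = auto()

    def is_triton_kernel(self) -> bool:
        return self == TopKOutputFormat.TRITON_KERNEL


@runtime_checkable
class TopKOutput(Protocol):
    @property
    def format(self) -> TopKOutputFormat: ...


class StandardTopKOutput(NamedTuple):
    topk_weights: list
    topk_ids: list
    router_logits: list

    @property
    def format(self) -> TopKOutputFormat:
        return TopKOutputFormat.STANDARD


def run_experts(topk_output: TopKOutput) -> str:
    # Consumers branch on the format tag, as the fused-MoE layer does below.
    if topk_output.format.is_triton_kernel():
        return "triton kernel path"
    return "standard fused_experts path"


out = StandardTopKOutput([0.6, 0.4], [1, 3], [0.1, 0.9])
assert isinstance(out, TopKOutput)  # runtime_checkable structural check
print(run_experts(out))  # -> standard fused_experts path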
@@ -130,6 +130,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         super().__init__()
         self.use_triton_kernels = use_triton_kernels
 
+        self.triton_kernel_moe_forward = None
+        if torch.cuda.is_available() and has_triton_kernels:
+            from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
+                triton_kernel_moe_forward as _tk_forward,
+            )
+
+            self.triton_kernel_moe_forward = _tk_forward
+
     def create_weights(
         self,
         layer: torch.nn.Module,
@@ -229,16 +237,12 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
     ) -> torch.Tensor:
 
         if self.use_triton_kernels:
-            # TODO(ch-wan): re-enable the Triton kernel
-            raise NotImplementedError("The Triton kernel is temporarily disabled.")
-            # return triton_kernel_moe_forward(
-            #     hidden_states=x,
-            #     w1=layer.w13_weight,
-            #     w2=layer.w2_weight,
-            #     gating_output=router_logits,
-            #     topk=top_k,
-            #     renormalize=renormalize,
-            # )
+            return self.triton_kernel_moe_forward(
+                hidden_states=x,
+                w1=layer.w13_weight,
+                w2=layer.w2_weight,
+                topk_output=topk_output,
+            )
         else:
             if _use_aiter:
                 assert not no_combine, "unsupported"
...
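UnquantizedFusedMoEMethod now binds triton_kernel_moe_forward lazily at construction time, only when CUDA and the optional triton_kernels package are actually usable, instead of keeping the call commented out. A self-contained sketch of that guard-and-bind pattern follows; has_triton_kernels, the placeholder forwards, and the extra None check are local stand-ins, not sglang code.

import importlib.util

import torch

# Stand-in availability flag for the optional triton_kernels dependency.
has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None


def standard_moe_forward(x: torch.Tensor) -> torch.Tensor:
    # Placeholder for the default fused_experts path.
    return x


def triton_kernel_moe_forward(x: torch.Tensor) -> torch.Tensor:
    # Placeholder for the Triton kernel entry point imported in the diff.
    return x


class FusedMoEMethodSketch:
    def __init__(self, use_triton_kernels: bool = False):
        self.use_triton_kernels = use_triton_kernels
        # Bind the optional kernel once, only if it can actually run.
        self.triton_kernel_moe_forward = None
        if torch.cuda.is_available() and has_triton_kernels:
            self.triton_kernel_moe_forward = triton_kernel_moe_forward

    def apply(self, x: torch.Tensor) -> torch.Tensor:
        if self.use_triton_kernels and self.triton_kernel_moe_forward is not None:
            return self.triton_kernel_moe_forward(x)
        return standard_moe_forward(x)


print(FusedMoEMethodSketch().apply(torch.zeros(2)))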