Unverified Commit 38076dea authored by Xiaoyu Zhang, committed by GitHub

apply fused moe gate in ds v3/r1 (#5371)


Co-authored-by: Yineng Zhang <me@zhyncs.com>
parent 5e0a9b09
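This commit routes DeepSeek-V3/R1's biased grouped top-k gating through the fused moe_fused_gate CUDA kernel from sgl_kernel when three conditions hold: the run is on CUDA, no shared experts are fused into routing (n_share_experts_fusion == 0), and the number of routed experts is a power of two. Otherwise it keeps the existing torch.compile'd reference path. Below is a minimal sketch, not part of the commit, of how the fused kernel could be checked against the reference implementation on random inputs; the import path of the reference function, the tensor dtypes, and the DeepSeek-V3-style routing hyperparameters (256 routed experts, top-8, 8 expert groups, 4 groups selected) are assumptions for illustration.

# Illustrative sketch only: compare the fused kernel with the reference path.
# Requires a CUDA build of sgl_kernel. The import path below is assumed and
# may differ across sglang versions; dtypes and shapes are also assumptions.
import torch
from sgl_kernel import moe_fused_gate
from sglang.srt.layers.moe.topk import biased_grouped_topk_impl

num_tokens, hidden_size, num_experts = 4, 7168, 256  # assumed DeepSeek-V3-like sizes
topk, num_expert_group, topk_group = 8, 8, 4         # assumed routing hyperparameters

hidden_states = torch.randn(num_tokens, hidden_size, device="cuda", dtype=torch.bfloat16)
gating_output = torch.randn(num_tokens, num_experts, device="cuda", dtype=torch.float32)
correction_bias = torch.randn(num_experts, device="cuda", dtype=torch.float32)

# Reference (eager) biased grouped top-k, as called by the fallback branch.
ref_weights, ref_ids = biased_grouped_topk_impl(
    hidden_states, gating_output, correction_bias,
    topk, True, num_expert_group, topk_group,
    n_share_experts_fusion=0,
)

# Fused kernel path added by this commit (argument order taken from the diff).
fused_weights, fused_ids = moe_fused_gate(
    gating_output, correction_bias, num_expert_group, topk_group, topk
)

# The two paths may order a token's selected experts differently, so compare
# the per-token sets of expert ids rather than the raw tensors.
for t in range(num_tokens):
    print(sorted(ref_ids[t].tolist()) == sorted(fused_ids[t].tolist()))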
@@ -12,6 +12,7 @@
# limitations under the License.
# ==============================================================================
import math
import os
from typing import Callable, Optional
@@ -25,6 +26,8 @@ from sglang.srt.utils import get_compiler_backend, is_cuda, is_hip
_is_cuda = is_cuda()
_is_hip = is_hip()
if _is_cuda:
    from sgl_kernel import moe_fused_gate
expert_distribution_recorder = ExpertDistributionRecorder()
@@ -209,6 +212,10 @@ def biased_grouped_topk_impl(
    return topk_weights.to(torch.float32), topk_ids.to(torch.int32)


def is_power_of_two(n):
    return n > 0 and math.log2(n).is_integer()
def biased_grouped_topk(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
@@ -220,23 +227,37 @@ def biased_grouped_topk(
    compiled: bool = True,
    n_share_experts_fusion: int = 0,
):
    # TODO: moe_fused_gate kernel is not supported for n_share_experts_fusion > 0 now.
    if (
        _is_cuda
        and n_share_experts_fusion == 0
        and is_power_of_two(correction_bias.shape[0])
    ):
        return moe_fused_gate(
            gating_output,
            correction_bias,
            num_expert_group,
            topk_group,
            topk,
        )
    else:
        biased_grouped_topk_fn = (
            torch.compile(
                biased_grouped_topk_impl, dynamic=True, backend=get_compiler_backend()
            )
            if compiled
            else biased_grouped_topk_impl
        )
        return biased_grouped_topk_fn(
            hidden_states,
            gating_output,
            correction_bias,
            topk,
            renormalize,
            num_expert_group,
            topk_group,
            n_share_experts_fusion=n_share_experts_fusion,
        )
def select_experts(
......
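A note on the new guard: the fused path is only taken when the length of correction_bias (the number of routed experts) is a power of two, which holds for DeepSeek-V3/R1's 256 routed experts; other expert counts, non-CUDA backends, or n_share_experts_fusion > 0 fall back to the torch.compile'd reference implementation. The snippet below just illustrates the check; the bitwise variant is an equivalent alternative shown for comparison, not what the commit uses.

import math

def is_power_of_two(n):
    # Check used in this commit: exact integer log2.
    return n > 0 and math.log2(n).is_integer()

def is_power_of_two_bitwise(n):
    # Equivalent integer-only variant: a power of two has exactly one set bit.
    return n > 0 and (n & (n - 1)) == 0

for n in (256, 384, 257, 1):
    assert is_power_of_two(n) == is_power_of_two_bitwise(n)
    print(n, is_power_of_two(n))  # 256 -> True, 384 -> False, 257 -> False, 1 -> True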