Unverified Commit 38076dea authored by Xiaoyu Zhang's avatar Xiaoyu Zhang Committed by GitHub
Browse files

apply fused moe gate in ds v3/r1 (#5371)


Co-authored-by: default avatarYineng Zhang <me@zhyncs.com>
parent 5e0a9b09
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
import math
import os import os
from typing import Callable, Optional from typing import Callable, Optional
...@@ -25,6 +26,8 @@ from sglang.srt.utils import get_compiler_backend, is_cuda, is_hip ...@@ -25,6 +26,8 @@ from sglang.srt.utils import get_compiler_backend, is_cuda, is_hip
_is_cuda = is_cuda() _is_cuda = is_cuda()
_is_hip = is_hip() _is_hip = is_hip()
if _is_cuda:
from sgl_kernel import moe_fused_gate
expert_distribution_recorder = ExpertDistributionRecorder() expert_distribution_recorder = ExpertDistributionRecorder()
...@@ -209,6 +212,10 @@ def biased_grouped_topk_impl( ...@@ -209,6 +212,10 @@ def biased_grouped_topk_impl(
return topk_weights.to(torch.float32), topk_ids.to(torch.int32) return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
def is_power_of_two(n):
    """Return True if ``n`` is a positive integer power of two.

    Uses the exact bitwise test ``(n & (n - 1)) == 0`` instead of
    ``math.log2(n).is_integer()``: the float-based check can round to an
    integer for very large non-powers of two (e.g. ``2**80 + 1``), while
    the bitwise form is exact for arbitrarily large ints and avoids the
    int->float conversion entirely.
    """
    return n > 0 and (n & (n - 1)) == 0
def biased_grouped_topk( def biased_grouped_topk(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
gating_output: torch.Tensor, gating_output: torch.Tensor,
...@@ -220,6 +227,20 @@ def biased_grouped_topk( ...@@ -220,6 +227,20 @@ def biased_grouped_topk(
compiled: bool = True, compiled: bool = True,
n_share_experts_fusion: int = 0, n_share_experts_fusion: int = 0,
): ):
# TODO: moe_fused_gate kernel is not supported for n_share_experts_fusion > 0 now.
if (
_is_cuda
and n_share_experts_fusion == 0
and is_power_of_two(correction_bias.shape[0])
):
return moe_fused_gate(
gating_output,
correction_bias,
num_expert_group,
topk_group,
topk,
)
else:
biased_grouped_topk_fn = ( biased_grouped_topk_fn = (
torch.compile( torch.compile(
biased_grouped_topk_impl, dynamic=True, backend=get_compiler_backend() biased_grouped_topk_impl, dynamic=True, backend=get_compiler_backend()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment