Unverified commit 1ebe1d6d authored by Ke Bao, committed by GitHub
Browse files

Optimize MoE topk with torch compile (#3236)

parent 7811bfda
...@@ -17,6 +17,8 @@ from typing import Callable, Optional ...@@ -17,6 +17,8 @@ from typing import Callable, Optional
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from sglang.srt.utils import get_compiler_backend
def fused_topk_native( def fused_topk_native(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
...@@ -74,6 +76,7 @@ def fused_topk( ...@@ -74,6 +76,7 @@ def fused_topk(
# This is used by the Deepseek-V2 model # This is used by the Deepseek-V2 model
@torch.compile(dynamic=True, backend=get_compiler_backend())
def grouped_topk( def grouped_topk(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
gating_output: torch.Tensor, gating_output: torch.Tensor,
...@@ -108,6 +111,7 @@ def grouped_topk( ...@@ -108,6 +111,7 @@ def grouped_topk(
return topk_weights.to(torch.float32), topk_ids.to(torch.int32) return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
@torch.compile(dynamic=True, backend=get_compiler_backend())
def biased_grouped_topk( def biased_grouped_topk(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
gating_output: torch.Tensor, gating_output: torch.Tensor,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment