Unverified Commit 044c3159 authored by Qingquan Song's avatar Qingquan Song Committed by GitHub
Browse files

Make torch compile configurable for biased_grouped_topk (#4749)

parent 4db29e82
...@@ -129,8 +129,7 @@ def grouped_topk( ...@@ -129,8 +129,7 @@ def grouped_topk(
return topk_weights.to(torch.float32), topk_ids.to(torch.int32) return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
@torch.compile(dynamic=True, backend=get_compiler_backend()) def biased_grouped_topk_impl(
def biased_grouped_topk(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
gating_output: torch.Tensor, gating_output: torch.Tensor,
correction_bias: torch.Tensor, correction_bias: torch.Tensor,
...@@ -171,6 +170,34 @@ def biased_grouped_topk( ...@@ -171,6 +170,34 @@ def biased_grouped_topk(
return topk_weights.to(torch.float32), topk_ids.to(torch.int32) return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
def biased_grouped_topk(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
correction_bias: torch.Tensor,
topk: int,
renormalize: bool,
num_expert_group: int = 0,
topk_group: int = 0,
compiled: bool = True,
):
biased_grouped_topk_fn = (
torch.compile(
biased_grouped_topk_impl, dynamic=True, backend=get_compiler_backend()
)
if compiled
else biased_grouped_topk_impl
)
return biased_grouped_topk_fn(
hidden_states,
gating_output,
correction_bias,
topk,
renormalize,
num_expert_group,
topk_group,
)
def select_experts( def select_experts(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment