Unverified Commit 9045cc1e authored by Xiaoyu Zhang's avatar Xiaoyu Zhang Committed by GitHub
Browse files

[torch.compile bug] avoid biased_grouped_topk_impl func repeatedly triggering...

[torch.compile bug] avoid biased_grouped_topk_impl func repeatedly triggering `torch.compile` in forward pass (#8353)
parent 70e37b97
......@@ -5,4 +5,4 @@ Hardware Supports
amd.md
nvidia_jetson.md
cpu.md
\ No newline at end of file
cpu.md
......@@ -387,6 +387,7 @@ def grouped_topk_cpu(
)
@torch.compile(dynamic=True, backend=get_compiler_backend(), disable=_is_npu)
def biased_grouped_topk_impl(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
......@@ -482,7 +483,6 @@ def biased_grouped_topk_gpu(
renormalize: bool,
num_expert_group: int = 0,
topk_group: int = 0,
compiled: bool = not _is_npu,
num_fused_shared_experts: int = 0,
routed_scaling_factor: Optional[float] = None,
num_token_non_padded: Optional[torch.Tensor] = None,
......@@ -535,14 +535,7 @@ def biased_grouped_topk_gpu(
)
return topk_weights, topk_ids
else:
biased_grouped_topk_fn = (
torch.compile(
biased_grouped_topk_impl, dynamic=True, backend=get_compiler_backend()
)
if compiled
else biased_grouped_topk_impl
)
return biased_grouped_topk_fn(
return biased_grouped_topk_impl(
hidden_states,
gating_output,
correction_bias,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment