Unverified Commit 9045cc1e authored by Xiaoyu Zhang's avatar Xiaoyu Zhang Committed by GitHub
Browse files

[torch.compile bug] avoid biased_grouped_topk_impl func repeatedly triggering...

[torch.compile bug] avoid biased_grouped_topk_impl func repeatedly triggering `torch.compile` in forward pass (#8353)
parent 70e37b97
...@@ -5,4 +5,4 @@ Hardware Supports ...@@ -5,4 +5,4 @@ Hardware Supports
amd.md amd.md
nvidia_jetson.md nvidia_jetson.md
cpu.md cpu.md
\ No newline at end of file
...@@ -387,6 +387,7 @@ def grouped_topk_cpu( ...@@ -387,6 +387,7 @@ def grouped_topk_cpu(
) )
@torch.compile(dynamic=True, backend=get_compiler_backend(), disable=_is_npu)
def biased_grouped_topk_impl( def biased_grouped_topk_impl(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
gating_output: torch.Tensor, gating_output: torch.Tensor,
...@@ -482,7 +483,6 @@ def biased_grouped_topk_gpu( ...@@ -482,7 +483,6 @@ def biased_grouped_topk_gpu(
renormalize: bool, renormalize: bool,
num_expert_group: int = 0, num_expert_group: int = 0,
topk_group: int = 0, topk_group: int = 0,
compiled: bool = not _is_npu,
num_fused_shared_experts: int = 0, num_fused_shared_experts: int = 0,
routed_scaling_factor: Optional[float] = None, routed_scaling_factor: Optional[float] = None,
num_token_non_padded: Optional[torch.Tensor] = None, num_token_non_padded: Optional[torch.Tensor] = None,
...@@ -535,14 +535,7 @@ def biased_grouped_topk_gpu( ...@@ -535,14 +535,7 @@ def biased_grouped_topk_gpu(
) )
return topk_weights, topk_ids return topk_weights, topk_ids
else: else:
biased_grouped_topk_fn = ( return biased_grouped_topk_impl(
torch.compile(
biased_grouped_topk_impl, dynamic=True, backend=get_compiler_backend()
)
if compiled
else biased_grouped_topk_impl
)
return biased_grouped_topk_fn(
hidden_states, hidden_states,
gating_output, gating_output,
correction_bias, correction_bias,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment