Unverified Commit 38473363 authored by Trevor Morris's avatar Trevor Morris Committed by GitHub
Browse files

[DSv32] Use torch.compile for _get_logits_head_gate (#11565)

parent aaf7af1b
...@@ -205,6 +205,7 @@ class Indexer(CustomOp): ...@@ -205,6 +205,7 @@ class Indexer(CustomOp):
return ans return ans
@torch.compile(dynamic=True)
def _get_logits_head_gate(self, x: torch.Tensor, q_scale: torch.Tensor): def _get_logits_head_gate(self, x: torch.Tensor, q_scale: torch.Tensor):
weights, _ = self.weights_proj(x) weights, _ = self.weights_proj(x)
weights = weights * self.n_heads**-0.5 weights = weights * self.n_heads**-0.5
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment