"...pytorch/graphsage/advanced/train_sampling_unsupervised.py" did not exist on "44089c8b4d4db4ca71e816e0de50dca972dbabdb"
Unverified Commit 38473363 authored by Trevor Morris's avatar Trevor Morris Committed by GitHub
Browse files

[DSv32] Use torch.compile for _get_logits_head_gate (#11565)

parent aaf7af1b
......@@ -205,6 +205,7 @@ class Indexer(CustomOp):
return ans
@torch.compile(dynamic=True)
def _get_logits_head_gate(self, x: torch.Tensor, q_scale: torch.Tensor):
weights, _ = self.weights_proj(x)
weights = weights * self.n_heads**-0.5
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment