Unverified Commit 22fa63cf authored by Lucas Kabela's avatar Lucas Kabela Committed by GitHub
Browse files

[Bugfix][Torch 2.12] Fix batch_invariant test with allow_override for torch 2.12 upgrade (#40562)


Signed-off-by: default avatarLucas Kabela <lucaskabela@meta.com>
Co-authored-by: default avatarmergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
parent 8f87eb46
......@@ -963,8 +963,12 @@ def enable_batch_invariant_mode():
_batch_invariant_LIB.impl("aten::_softmax", softmax_batch_invariant, "CUDA")
_batch_invariant_LIB.impl("aten::mean.dim", mean_batch_invariant, "CUDA")
# Also monkeypatch torch.bmm directly as a fallback
_batch_invariant_LIB.impl("aten::bmm", bmm_batch_invariant, "CUDA")
# torch 2.12+ registers a built-in Triton bmm kernel for CUDA
# (torch._native.ops.bmm_outer_product), so we need allow_override
# to replace it at the dispatcher level.
_batch_invariant_LIB.impl(
"aten::bmm", bmm_batch_invariant, "CUDA", allow_override=True
)
_original_torch_bmm = torch.bmm
torch.bmm = bmm_batch_invariant
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment