Unverified Commit c001deba authored by fzyzcjy's avatar fzyzcjy Committed by GitHub
Browse files

Make bmm batch invariant injection optional (#12118)

parent b4d2da10
......@@ -524,7 +524,9 @@ def is_batch_invariant_mode_enabled():
return _batch_invariant_MODE
def enable_batch_invariant_mode():
def enable_batch_invariant_mode(
enable_bmm: bool = True,
):
global _batch_invariant_MODE, _batch_invariant_LIB, _original_torch_bmm
if _batch_invariant_MODE:
return
......@@ -537,11 +539,13 @@ def enable_batch_invariant_mode():
"aten::_log_softmax", _log_softmax_batch_invariant, "CUDA"
)
_batch_invariant_LIB.impl("aten::mean.dim", mean_batch_invariant, "CUDA")
_batch_invariant_LIB.impl("aten::bmm", bmm_batch_invariant, "CUDA")
# Also monkeypatch torch.bmm directly as a fallback
_original_torch_bmm = torch.bmm
torch.bmm = bmm_batch_invariant
if enable_bmm:
_batch_invariant_LIB.impl("aten::bmm", bmm_batch_invariant, "CUDA")
# Also monkeypatch torch.bmm directly as a fallback
_original_torch_bmm = torch.bmm
torch.bmm = bmm_batch_invariant
def disable_batch_invariant_mode():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment