Unverified commit 16a65e41, authored by Yusuf Mohammad, committed by GitHub
Browse files

[Bugfix] Enable batch-invariant Triton matmul on all Ampere GPUs (SM 8x) (#38427)


Signed-off-by: yusuf <yusufmohammad@live.com>
Signed-off-by: yusuf <yusuf@deeplearningmachine.mynet>
Signed-off-by: Yusuf Mohammad <79484377+YM2132@users.noreply.github.com>
Signed-off-by: <>
Co-authored-by: Claude <noreply@anthropic.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Co-authored-by: yusuf <yusuf@deeplearningmachine.mynet>
parent c0817e4d
...@@ -935,11 +935,9 @@ def enable_batch_invariant_mode(): ...@@ -935,11 +935,9 @@ def enable_batch_invariant_mode():
_batch_invariant_MODE = True _batch_invariant_MODE = True
_batch_invariant_LIB = torch.library.Library("aten", "IMPL") _batch_invariant_LIB = torch.library.Library("aten", "IMPL")
if ( if current_platform.is_device_capability_family(
current_platform.is_device_capability_family(100) 100
or current_platform.is_device_capability(80) ) or current_platform.is_device_capability_family(80):
or current_platform.is_device_capability(89)
):
# For PyTorch 2.9, B200 uses GEMV for bs=1 # For PyTorch 2.9, B200 uses GEMV for bs=1
# Requires https://github.com/pytorch/pytorch/pull/166735 # Requires https://github.com/pytorch/pytorch/pull/166735
_batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, "CUDA") _batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, "CUDA")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment