Unverified Commit 66287582 authored by Wentao Ye's avatar Wentao Ye Committed by GitHub
Browse files

[Bug] Fix batch invariant in torch 2.10 (#30907)


Signed-off-by: default avataryewentao256 <zhyanwentao@126.com>
Signed-off-by: default avatarWentao Ye <44945378+yewentao256@users.noreply.github.com>
Co-authored-by: default avatarCyrus Leung <tlleungac@connect.ust.hk>
parent eee600c3
...@@ -933,8 +933,6 @@ def enable_batch_invariant_mode(): ...@@ -933,8 +933,6 @@ def enable_batch_invariant_mode():
_batch_invariant_MODE = True _batch_invariant_MODE = True
_batch_invariant_LIB = torch.library.Library("aten", "IMPL") _batch_invariant_LIB = torch.library.Library("aten", "IMPL")
# Batch invariant matmuls are no longer needed after cublas overrides
if not is_torch_equal_or_newer("2.10.0.dev"):
if ( if (
current_platform.is_device_capability_family(100) current_platform.is_device_capability_family(100)
or current_platform.is_device_capability(80) or current_platform.is_device_capability(80)
...@@ -949,9 +947,7 @@ def enable_batch_invariant_mode(): ...@@ -949,9 +947,7 @@ def enable_batch_invariant_mode():
else: else:
# Only source of batch invariance for Hopper is split-k, can disable through # Only source of batch invariance for Hopper is split-k, can disable through
# cuBLAS workspace config # cuBLAS workspace config
_original_cublas_workspace_cfg = os.environ.get( _original_cublas_workspace_cfg = os.environ.get("CUBLAS_WORKSPACE_CONFIG", None)
"CUBLAS_WORKSPACE_CONFIG", None
)
_original_cublaslt_workspace_size = os.environ.get( _original_cublaslt_workspace_size = os.environ.get(
"CUBLASLT_WORKSPACE_SIZE", None "CUBLASLT_WORKSPACE_SIZE", None
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment