Unverified Commit feaedbb0 authored by Vedant's avatar Vedant Committed by GitHub
Browse files

fix: Improve CUDA version detection and error handling (#1599)

* fix: Improve CUDA version detection and error handling

* lint fix

* lint fix
parent b9827962
...@@ -21,26 +21,55 @@ def get_compute_capabilities() -> list[tuple[int, int]]: ...@@ -21,26 +21,55 @@ def get_compute_capabilities() -> list[tuple[int, int]]:
@lru_cache(None) @lru_cache(None)
def get_cuda_version_tuple() -> tuple[int, int]: def get_cuda_version_tuple() -> Optional[tuple[int, int]]:
if torch.version.cuda: """Get CUDA/HIP version as a tuple of (major, minor)."""
return tuple(map(int, torch.version.cuda.split(".")[0:2])) try:
elif torch.version.hip: if torch.version.cuda:
return tuple(map(int, torch.version.hip.split(".")[0:2])) version_str = torch.version.cuda
elif torch.version.hip:
version_str = torch.version.hip
else:
return None
return None parts = version_str.split(".")
if len(parts) >= 2:
return tuple(map(int, parts[:2]))
return None
except (AttributeError, ValueError, IndexError):
return None
def get_cuda_version_string() -> str: def get_cuda_version_string() -> Optional[str]:
major, minor = get_cuda_version_tuple() """Get CUDA/HIP version as a string."""
version_tuple = get_cuda_version_tuple()
if version_tuple is None:
return None
major, minor = version_tuple
return f"{major * 10 + minor}" return f"{major * 10 + minor}"
def get_cuda_specs() -> Optional[CUDASpecs]: def get_cuda_specs() -> Optional[CUDASpecs]:
"""Get CUDA/HIP specifications."""
if not torch.cuda.is_available(): if not torch.cuda.is_available():
return None return None
return CUDASpecs( try:
highest_compute_capability=(get_compute_capabilities()[-1]), compute_capabilities = get_compute_capabilities()
cuda_version_string=(get_cuda_version_string()), if not compute_capabilities:
cuda_version_tuple=get_cuda_version_tuple(), return None
)
version_tuple = get_cuda_version_tuple()
if version_tuple is None:
return None
version_string = get_cuda_version_string()
if version_string is None:
return None
return CUDASpecs(
highest_compute_capability=compute_capabilities[-1],
cuda_version_string=version_string,
cuda_version_tuple=version_tuple,
)
except Exception:
return None
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment