Unverified Commit 05abd126 authored by Yineng Zhang's avatar Yineng Zhang Committed by GitHub
Browse files

misc: add compute capability in check_env (#965)

parent 5f6fa04a
......@@ -73,10 +73,26 @@ def _get_gpu_info():
Get information about available GPUs.
"""
devices = defaultdict(list)
capabilities = defaultdict(list)
for k in range(torch.cuda.device_count()):
devices[torch.cuda.get_device_name(k)].append(str(k))
capability = torch.cuda.get_device_capability(k)
capabilities[f"{capability[0]}.{capability[1]}"].append(str(k))
return {f"GPU {','.join(device_ids)}": name for name, device_ids in devices.items()}
gpu_info = {}
for name, device_ids in devices.items():
gpu_info[f"GPU {','.join(device_ids)}"] = name
if len(capabilities) == 1:
# All GPUs have the same compute capability
cap, gpu_ids = list(capabilities.items())[0]
gpu_info[f"GPU {','.join(gpu_ids)} Compute Capability"] = cap
else:
# GPUs have different compute capabilities
for cap, gpu_ids in capabilities.items():
gpu_info[f"GPU {','.join(gpu_ids)} Compute Capability"] = cap
return gpu_info
def _get_cuda_version_info():
......@@ -118,6 +134,7 @@ def _get_cuda_driver_version():
"""
Get CUDA driver version.
"""
versions = set()
try:
output = subprocess.check_output(
[
......@@ -126,7 +143,11 @@ def _get_cuda_driver_version():
"--format=csv,noheader,nounits",
]
)
return {"CUDA Driver Version": output.decode().strip()}
versions = set(output.decode().strip().split("\n"))
if len(versions) == 1:
return {"CUDA Driver Version": versions.pop()}
else:
return {"CUDA Driver Versions": ", ".join(sorted(versions))}
except subprocess.SubprocessError:
return {"CUDA Driver Version": "Not Available"}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment