"git@developer.sourcefind.cn:OpenDAS/torch-cluster.git" did not exist on "824fb65adb1e9f8a0938fa8cc0bac47c4e1539a1"
Unverified Commit 65fc1c31 authored by Xudong Zhang's avatar Xudong Zhang Committed by GitHub
Browse files

set default coompute capability according to cuda version (#773)

parent c393af6c
...@@ -47,12 +47,6 @@ for i in range(device_count): ...@@ -47,12 +47,6 @@ for i in range(device_count):
raise RuntimeError( raise RuntimeError(
"GPUs with compute capability less than 7.0 are not supported.") "GPUs with compute capability less than 7.0 are not supported.")
compute_capabilities.add(major * 10 + minor) compute_capabilities.add(major * 10 + minor)
# If no GPU is available, add all supported compute capabilities.
if not compute_capabilities:
compute_capabilities = {70, 75, 80, 86, 90}
# Add target compute capabilities to NVCC flags.
for capability in compute_capabilities:
NVCC_FLAGS += ["-gencode", f"arch=compute_{capability},code=sm_{capability}"]
# Validate the NVCC CUDA version. # Validate the NVCC CUDA version.
nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME) nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME)
...@@ -65,6 +59,18 @@ if 90 in compute_capabilities and nvcc_cuda_version < Version("11.8"): ...@@ -65,6 +59,18 @@ if 90 in compute_capabilities and nvcc_cuda_version < Version("11.8"):
raise RuntimeError( raise RuntimeError(
"CUDA 11.8 or higher is required for GPUs with compute capability 9.0.") "CUDA 11.8 or higher is required for GPUs with compute capability 9.0.")
# If no GPU is available, add all supported compute capabilities.
if not compute_capabilities:
compute_capabilities = {70, 75, 80}
if nvcc_cuda_version >= Version("11.1"):
compute_capabilities.add(86)
if nvcc_cuda_version >= Version("11.8"):
compute_capabilities.add(90)
# Add target compute capabilities to NVCC flags.
for capability in compute_capabilities:
NVCC_FLAGS += ["-gencode", f"arch=compute_{capability},code=sm_{capability}"]
# Use NVCC threads to parallelize the build. # Use NVCC threads to parallelize the build.
if nvcc_cuda_version >= Version("11.2"): if nvcc_cuda_version >= Version("11.2"):
num_threads = min(os.cpu_count(), 8) num_threads = min(os.cpu_count(), 8)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment