Unverified Commit a41c2043 authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

Add compute capability 8.9 to default targets (#829)

parent eedac9db
...@@ -22,7 +22,7 @@ NVCC_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"] ...@@ -22,7 +22,7 @@ NVCC_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]
if CUDA_HOME is None: if CUDA_HOME is None:
raise RuntimeError( raise RuntimeError(
f"Cannot find CUDA_HOME. CUDA must be available in order to build the package.") f"Cannot find CUDA_HOME. CUDA must be available to build the package.")
def get_nvcc_cuda_version(cuda_dir: str) -> Version: def get_nvcc_cuda_version(cuda_dir: str) -> Version:
...@@ -55,6 +55,14 @@ if nvcc_cuda_version < Version("11.0"): ...@@ -55,6 +55,14 @@ if nvcc_cuda_version < Version("11.0"):
if 86 in compute_capabilities and nvcc_cuda_version < Version("11.1"): if 86 in compute_capabilities and nvcc_cuda_version < Version("11.1"):
raise RuntimeError( raise RuntimeError(
"CUDA 11.1 or higher is required for GPUs with compute capability 8.6.") "CUDA 11.1 or higher is required for GPUs with compute capability 8.6.")
if 89 in compute_capabilities and nvcc_cuda_version < Version("11.8"):
# CUDA 11.8 is required to generate the code targeting compute capability 8.9.
# However, GPUs with compute capability 8.9 can also run the code generated by
# the previous versions of CUDA 11 and targeting compute capability 8.0.
# Therefore, if CUDA 11.8 is not available, we target compute capability 8.0
# instead of 8.9.
compute_capabilities.remove(89)
compute_capabilities.add(80)
if 90 in compute_capabilities and nvcc_cuda_version < Version("11.8"): if 90 in compute_capabilities and nvcc_cuda_version < Version("11.8"):
raise RuntimeError( raise RuntimeError(
"CUDA 11.8 or higher is required for GPUs with compute capability 9.0.") "CUDA 11.8 or higher is required for GPUs with compute capability 9.0.")
...@@ -65,6 +73,7 @@ if not compute_capabilities: ...@@ -65,6 +73,7 @@ if not compute_capabilities:
if nvcc_cuda_version >= Version("11.1"): if nvcc_cuda_version >= Version("11.1"):
compute_capabilities.add(86) compute_capabilities.add(86)
if nvcc_cuda_version >= Version("11.8"): if nvcc_cuda_version >= Version("11.8"):
compute_capabilities.add(89)
compute_capabilities.add(90) compute_capabilities.add(90)
# Add target compute capabilities to NVCC flags. # Add target compute capabilities to NVCC flags.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment