Unverified Commit be33bea4 authored by Jeff Rasley's avatar Jeff Rasley Committed by GitHub
Browse files

Add compute capability 8.0 if on cuda 11+ (#572)

parent 2d1f7c01
......@@ -5,6 +5,7 @@ from .sparse_attn import SparseAttnBuilder
from .transformer import TransformerBuilder
from .stochastic_transformer import StochasticTransformerBuilder
from .utils import UtilsBuilder
from .builder import get_default_compute_capatabilities
# TODO: infer this list instead of hard coded
# List of all available ops
......
......@@ -11,9 +11,10 @@ END = '\033[0m'
WARNING = f"{YELLOW} [WARNING] {END}"
DEFAULT_TORCH_EXTENSION_PATH = "/tmp/torch_extensions"
DEFAULT_COMPUTE_CAPABILITIES = "6.0;6.1;7.0"
def assert_no_cuda_mismatch():
def installed_cuda_version():
import torch.utils.cpp_extension
cuda_home = torch.utils.cpp_extension.CUDA_HOME
assert cuda_home is not None, "CUDA_HOME does not exist, unable to compile CUDA op(s)"
......@@ -25,12 +26,26 @@ def assert_no_cuda_mismatch():
release_idx = output_split.index("release")
release = output_split[release_idx + 1].replace(',', '').split(".")
# Ignore patch versions, only look at major + minor
cuda_major, cuda_minor = release[:2]
installed_cuda_version = ".".join(release[:2])
return int(cuda_major), int(cuda_minor)
def get_default_compute_capatabilities():
    """Return the default semicolon-separated CUDA compute-capability list.

    Starts from the module-level DEFAULT_COMPUTE_CAPABILITIES ("6.0;6.1;7.0")
    and appends "8.0" (Ampere) when the installed CUDA toolkit is 11 or newer.

    NOTE(review): the misspelling "capatabilities" is kept because other
    modules import the function under this exact name.
    """
    cuda_major = installed_cuda_version()[0]
    if cuda_major >= 11:
        return DEFAULT_COMPUTE_CAPABILITIES + ";8.0"
    return DEFAULT_COMPUTE_CAPABILITIES
def assert_no_cuda_mismatch():
    """Raise if the system CUDA toolkit and torch's CUDA version disagree.

    Compares only major.minor (patch versions are ignored, matching the
    parsing done in installed_cuda_version()).

    Raises:
        Exception: when the nvcc-reported CUDA version does not match the
            CUDA version torch was compiled against, since mixed versions
            make cuda/cpp extension compilation fail.
    """
    cuda_major, cuda_minor = installed_cuda_version()
    sys_cuda_version = f'{cuda_major}.{cuda_minor}'
    torch_cuda_version = ".".join(torch.version.cuda.split('.')[:2])
    # This is a show-stopping error, should probably not proceed past this
    if sys_cuda_version != torch_cuda_version:
        raise Exception(
            f"Installed CUDA version {sys_cuda_version} does not match the "
            f"version torch was compiled with {torch.version.cuda}, unable to compile "
            "cuda/cpp extensions without a matching cuda version.")
......@@ -197,7 +212,10 @@ class OpBuilder(ABC):
class CUDAOpBuilder(OpBuilder):
def compute_capability_args(self, cross_compile_archs=['60', '61', '70']):
def compute_capability_args(self, cross_compile_archs=None):
if cross_compile_archs is None:
cross_compile_archs = get_default_compute_capatabilities()
args = []
if self.jit_mode:
# Compile for underlying architecture since we know it at runtime
......@@ -208,7 +226,8 @@ class CUDAOpBuilder(OpBuilder):
f'arch=compute_{compute_capability},code=compute_{compute_capability}')
else:
# Cross-compile mode, compile for various architectures
for compute_capability in cross_compile_archs:
for compute_capability in cross_compile_archs.split(';'):
compute_capability = compute_capability.replace('.', '')
args.append('-gencode')
args.append(
f'arch=compute_{compute_capability},code=compute_{compute_capability}'
......
......@@ -21,7 +21,7 @@ except ImportError:
raise ImportError('Unable to import torch, please visit https://pytorch.org/ '
'to see how to properly install torch on your system.')
import op_builder
from op_builder import ALL_OPS, get_default_compute_capatabilities
def fetch_requirements(path):
......@@ -64,12 +64,10 @@ if not torch.cuda.is_available():
"you can ignore this message. Adding compute capability for Pascal, Volta, and Turing "
"(compute capabilities 6.0, 6.1, 6.2)")
if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None:
os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5"
os.environ["TORCH_CUDA_ARCH_LIST"] = get_default_compute_capatabilities()
ext_modules = []
from op_builder import ALL_OPS
# Default to pre-install kernels to false so we rely on JIT
BUILD_OP_DEFAULT = int(os.environ.get('DS_BUILD_OPS', 0))
print(f"DS_BUILD_OPS={BUILD_OP_DEFAULT}")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment