Unverified Commit be33bea4 authored by Jeff Rasley, committed by GitHub

Add compute capability 8.0 if on cuda 11+ (#572)

parent 2d1f7c01
op_builder/__init__.py

@@ -5,6 +5,7 @@ from .sparse_attn import SparseAttnBuilder
 from .transformer import TransformerBuilder
 from .stochastic_transformer import StochasticTransformerBuilder
 from .utils import UtilsBuilder
+from .builder import get_default_compute_capatabilities

 # TODO: infer this list instead of hard coded
 # List of all available ops
...
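Re-exporting the helper here makes it importable from the package root, which setup.py relies on below. A minimal usage sketch (not part of the commit):

    from op_builder import get_default_compute_capatabilities
    # "6.0;6.1;7.0" on CUDA 10.x toolkits, "6.0;6.1;7.0;8.0" on CUDA 11+
    print(get_default_compute_capatabilities())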
op_builder/builder.py

@@ -11,9 +11,10 @@ END = '\033[0m'
 WARNING = f"{YELLOW} [WARNING] {END}"
 DEFAULT_TORCH_EXTENSION_PATH = "/tmp/torch_extensions"
+DEFAULT_COMPUTE_CAPABILITIES = "6.0;6.1;7.0"

-def assert_no_cuda_mismatch():
+def installed_cuda_version():
     import torch.utils.cpp_extension
     cuda_home = torch.utils.cpp_extension.CUDA_HOME
     assert cuda_home is not None, "CUDA_HOME does not exist, unable to compile CUDA op(s)"
@@ -25,12 +26,26 @@ def assert_no_cuda_mismatch():
     release_idx = output_split.index("release")
     release = output_split[release_idx + 1].replace(',', '').split(".")
     # Ignore patch versions, only look at major + minor
+    cuda_major, cuda_minor = release[:2]
     installed_cuda_version = ".".join(release[:2])
+    return int(cuda_major), int(cuda_minor)
+
+
+def get_default_compute_capatabilities():
+    compute_caps = DEFAULT_COMPUTE_CAPABILITIES
+    if installed_cuda_version()[0] >= 11:
+        compute_caps += ";8.0"
+    return compute_caps
+
+
+def assert_no_cuda_mismatch():
+    cuda_major, cuda_minor = installed_cuda_version()
+    sys_cuda_version = f'{cuda_major}.{cuda_minor}'
     torch_cuda_version = ".".join(torch.version.cuda.split('.')[:2])
     # This is a show-stopping error, should probably not proceed past this
-    if installed_cuda_version != torch_cuda_version:
+    if sys_cuda_version != torch_cuda_version:
         raise Exception(
-            f"Installed CUDA version {installed_cuda_version} does not match the "
+            f"Installed CUDA version {sys_cuda_version} does not match the "
             f"version torch was compiled with {torch.version.cuda}, unable to compile "
             "cuda/cpp extensions without a matching cuda version.")
@@ -197,7 +212,10 @@ class OpBuilder(ABC):

 class CUDAOpBuilder(OpBuilder):
-    def compute_capability_args(self, cross_compile_archs=['60', '61', '70']):
+    def compute_capability_args(self, cross_compile_archs=None):
+        if cross_compile_archs is None:
+            cross_compile_archs = get_default_compute_capatabilities()
         args = []
         if self.jit_mode:
             # Compile for underlying architecture since we know it at runtime
...
@@ -208,7 +226,8 @@ class CUDAOpBuilder(OpBuilder):
                 f'arch=compute_{compute_capability},code=compute_{compute_capability}')
         else:
             # Cross-compile mode, compile for various architectures
-            for compute_capability in cross_compile_archs:
+            for compute_capability in cross_compile_archs.split(';'):
+                compute_capability = compute_capability.replace('.', '')
                 args.append('-gencode')
                 args.append(
                     f'arch=compute_{compute_capability},code=compute_{compute_capability}'
...
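Because the default list is now the dotted "X.Y" form rather than bare arch numbers, the cross-compile loop strips the dot before building the -gencode flags. A small sketch of the expansion this produces, assuming the CUDA 11 default list:

    args = []
    for compute_capability in "6.0;6.1;7.0;8.0".split(';'):
        compute_capability = compute_capability.replace('.', '')
        args.append('-gencode')
        args.append(
            f'arch=compute_{compute_capability},code=compute_{compute_capability}')
    # args == ['-gencode', 'arch=compute_60,code=compute_60',
    #          '-gencode', 'arch=compute_61,code=compute_61',
    #          '-gencode', 'arch=compute_70,code=compute_70',
    #          '-gencode', 'arch=compute_80,code=compute_80']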
setup.py

@@ -21,7 +21,7 @@ except ImportError:
     raise ImportError('Unable to import torch, please visit https://pytorch.org/ '
                       'to see how to properly install torch on your system.')

-import op_builder
+from op_builder import ALL_OPS, get_default_compute_capatabilities


 def fetch_requirements(path):
...
@@ -64,12 +64,10 @@ if not torch.cuda.is_available():
         "you can ignore this message. Adding compute capability for Pascal, Volta, and Turing "
         "(compute capabilities 6.0, 6.1, 6.2)")
     if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None:
-        os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5"
+        os.environ["TORCH_CUDA_ARCH_LIST"] = get_default_compute_capatabilities()

 ext_modules = []
-from op_builder import ALL_OPS
-
 # Default to pre-install kernels to false so we rely on JIT
 BUILD_OP_DEFAULT = int(os.environ.get('DS_BUILD_OPS', 0))
 print(f"DS_BUILD_OPS={BUILD_OP_DEFAULT}")
...
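With this change, a pre-install build that does not set TORCH_CUDA_ARCH_LIST explicitly falls back to the same CUDA-version-aware default instead of the old hard-coded list. A sketch of the fallback, with the helper stubbed out for illustration:

    import os

    def get_default_compute_capatabilities():
        # Stub for illustration; assumes a CUDA 11+ toolkit is installed.
        return "6.0;6.1;7.0;8.0"

    if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None:
        os.environ["TORCH_CUDA_ARCH_LIST"] = get_default_compute_capatabilities()
    print(os.environ["TORCH_CUDA_ARCH_LIST"])  # 6.0;6.1;7.0;8.0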