Unverified commit 5b53121a, authored by ptrblck, committed by GitHub

Add sm80 for CUDA >= 11 (#925)

parent 700d6825
setup.py

 import torch
+from torch.utils import cpp_extension
 from setuptools import setup, find_packages
 import subprocess
@@ -9,6 +10,16 @@ import os
 # ninja build does not work unless include_dirs are abs path
 this_dir = os.path.dirname(os.path.abspath(__file__))
 
+def get_cuda_bare_metal_version(cuda_dir):
+    raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
+    output = raw_output.split()
+    release_idx = output.index("release") + 1
+    release = output[release_idx].split(".")
+    bare_metal_major = release[0]
+    bare_metal_minor = release[1][0]
+    return raw_output, bare_metal_major, bare_metal_minor
+
+
 if not torch.cuda.is_available():
     # https://github.com/NVIDIA/apex/issues/486
     # Extension builds after https://github.com/pytorch/pytorch/pull/23408 attempt to query torch.cuda.get_device_capability(),
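
For context, the helper added above just scrapes the release number out of the nvcc -V banner. Below is a minimal sketch of how that parsing behaves on a typical CUDA 11.0 banner; the sample text is an approximation of nvcc's output, not something captured from this build.

    # Rough stand-in for the nvcc -V banner of a CUDA 11.0 toolkit (approximate).
    sample_output = (
        "nvcc: NVIDIA (R) Cuda compiler driver\n"
        "Cuda compilation tools, release 11.0, V11.0.221\n"
    )
    tokens = sample_output.split()
    release = tokens[tokens.index("release") + 1].split(".")  # ["11", "0,"]
    bare_metal_major = release[0]     # "11"
    bare_metal_minor = release[1][0]  # "0" (first character only, which drops the trailing comma)
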
@@ -16,10 +27,15 @@ if not torch.cuda.is_available():
     print('\nWarning: Torch did not find available GPUs on this system.\n',
           'If your intention is to cross-compile, this is not an error.\n'
           'By default, Apex will cross-compile for Pascal (compute capabilities 6.0, 6.1, 6.2),\n'
-          'Volta (compute capability 7.0), and Turing (compute capability 7.5).\n'
+          'Volta (compute capability 7.0), Turing (compute capability 7.5),\n'
+          'and, if the CUDA version is >= 11.0, Ampere (compute capability 8.0).\n'
           'If you wish to cross-compile for a single specific architecture,\n'
           'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n')
     if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None:
-        os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5"
+        _, bare_metal_major, _ = get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
+        if int(bare_metal_major) == 11:
+            os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0"
+        else:
+            os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5"
 
 print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
@@ -64,13 +80,18 @@ if "--cpp_ext" in sys.argv:
         CppExtension('apex_C',
                      ['csrc/flatten_unflatten.cpp',]))
 
-def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
+def get_cuda_bare_metal_version(cuda_dir):
     raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
     output = raw_output.split()
     release_idx = output.index("release") + 1
     release = output[release_idx].split(".")
     bare_metal_major = release[0]
     bare_metal_minor = release[1][0]
+    return raw_output, bare_metal_major, bare_metal_minor
+
+
+def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
+    raw_output, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir)
 
     torch_binary_major = torch.version.cuda.split(".")[0]
     torch_binary_minor = torch.version.cuda.split(".")[1]
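
The refactor above does not change what is being checked: the major/minor pair reported by the local nvcc is still compared against the CUDA version the installed PyTorch binary was built with (torch.version.cuda, a string such as "11.0", or None for CPU-only builds). A rough, hypothetical sketch of that comparison:

    import torch

    def cuda_versions_match(bare_metal_major, bare_metal_minor):
        # The real check assumes a CUDA build of PyTorch, i.e. torch.version.cuda is not None.
        torch_binary_major, torch_binary_minor = torch.version.cuda.split(".")[:2]
        return (torch_binary_major == bare_metal_major
                and torch_binary_minor == bare_metal_minor)

When the versions differ, setup.py raises the RuntimeError whose message is partly visible in the next hunk.
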
@@ -85,6 +106,7 @@ def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
                            "https://github.com/NVIDIA/apex/pull/323#discussion_r287021798. "
                            "You can try commenting out this check (at your own risk).")
 
+
 # Set up macros for forward/backward compatibility hack around
 # https://github.com/pytorch/pytorch/commit/4404762d7dd955383acee92e6f06b48144a0742e
 # and
@@ -256,6 +278,13 @@ torch_dir = torch.__path__[0]
 if os.path.exists(os.path.join(torch_dir, 'include', 'ATen', 'CUDAGenerator.h')):
     generator_flag = ['-DOLD_GENERATOR']
 
+# Check, if CUDA11 is installed for compute capability 8.0
+cc_flag = []
+_, bare_metal_major, _ = get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
+if int(bare_metal_major) >= 11:
+    cc_flag.append('-gencode')
+    cc_flag.append('arch=compute_80,code=sm_80')
+
 if "--fast_multihead_attn" in sys.argv:
     from torch.utils.cpp_extension import CUDAExtension
     sys.argv.remove("--fast_multihead_attn")
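
To make the effect concrete: with a CUDA 11 toolkit, cc_flag carries one extra -gencode pair that every nvcc invocation below picks up, while on CUDA 10.x it stays empty and nothing changes. A trimmed illustration (the flag list here is a placeholder, not the full set the extensions use):

    cc_flag = []
    bare_metal_major = "11"  # pretend get_cuda_bare_metal_version() reported a CUDA 11 toolkit
    if int(bare_metal_major) >= 11:
        cc_flag.append('-gencode')
        cc_flag.append('arch=compute_80,code=sm_80')

    # Each extension below appends cc_flag to its nvcc arguments:
    nvcc_args = ['-O3', '--use_fast_math'] + cc_flag
    # -> ['-O3', '--use_fast_math', '-gencode', 'arch=compute_80,code=sm_80']
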
@@ -279,7 +308,7 @@ if "--fast_multihead_attn" in sys.argv:
                                    '-U__CUDA_NO_HALF_CONVERSIONS__',
                                    '--expt-relaxed-constexpr',
                                    '--expt-extended-lambda',
-                                   '--use_fast_math'] + version_dependent_macros + generator_flag}))
+                                   '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag}))
     ext_modules.append(
         CUDAExtension(name='fast_mask_softmax_dropout',
                       sources=['apex/contrib/csrc/multihead_attn/masked_softmax_dropout.cpp',
@@ -292,7 +321,7 @@ if "--fast_multihead_attn" in sys.argv:
                                    '-U__CUDA_NO_HALF_CONVERSIONS__',
                                    '--expt-relaxed-constexpr',
                                    '--expt-extended-lambda',
-                                   '--use_fast_math'] + version_dependent_macros + generator_flag}))
+                                   '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag}))
     ext_modules.append(
         CUDAExtension(name='fast_self_multihead_attn_bias_additive_mask',
                       sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask.cpp',
@@ -305,7 +334,7 @@ if "--fast_multihead_attn" in sys.argv:
                                    '-U__CUDA_NO_HALF_CONVERSIONS__',
                                    '--expt-relaxed-constexpr',
                                    '--expt-extended-lambda',
-                                   '--use_fast_math'] + version_dependent_macros + generator_flag}))
+                                   '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag}))
     ext_modules.append(
         CUDAExtension(name='fast_self_multihead_attn_bias',
                       sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn_bias.cpp',
@@ -318,7 +347,7 @@ if "--fast_multihead_attn" in sys.argv:
                                    '-U__CUDA_NO_HALF_CONVERSIONS__',
                                    '--expt-relaxed-constexpr',
                                    '--expt-extended-lambda',
-                                   '--use_fast_math'] + version_dependent_macros + generator_flag}))
+                                   '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag}))
     ext_modules.append(
         CUDAExtension(name='fast_self_multihead_attn',
                       sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn.cpp',
@@ -331,7 +360,7 @@ if "--fast_multihead_attn" in sys.argv:
                                    '-U__CUDA_NO_HALF_CONVERSIONS__',
                                    '--expt-relaxed-constexpr',
                                    '--expt-extended-lambda',
-                                   '--use_fast_math'] + version_dependent_macros + generator_flag}))
+                                   '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag}))
     ext_modules.append(
         CUDAExtension(name='fast_self_multihead_attn_norm_add',
                       sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add.cpp',
@@ -344,7 +373,7 @@ if "--fast_multihead_attn" in sys.argv:
                                    '-U__CUDA_NO_HALF_CONVERSIONS__',
                                    '--expt-relaxed-constexpr',
                                    '--expt-extended-lambda',
-                                   '--use_fast_math'] + version_dependent_macros + generator_flag}))
+                                   '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag}))
     ext_modules.append(
         CUDAExtension(name='fast_encdec_multihead_attn',
                       sources=['apex/contrib/csrc/multihead_attn/encdec_multihead_attn.cpp',
@@ -357,7 +386,7 @@ if "--fast_multihead_attn" in sys.argv:
                                    '-U__CUDA_NO_HALF_CONVERSIONS__',
                                    '--expt-relaxed-constexpr',
                                    '--expt-extended-lambda',
-                                   '--use_fast_math'] + version_dependent_macros + generator_flag}))
+                                   '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag}))
     ext_modules.append(
         CUDAExtension(name='fast_encdec_multihead_attn_norm_add',
                       sources=['apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add.cpp',
@@ -370,7 +399,7 @@ if "--fast_multihead_attn" in sys.argv:
                                    '-U__CUDA_NO_HALF_CONVERSIONS__',
                                    '--expt-relaxed-constexpr',
                                    '--expt-extended-lambda',
-                                   '--use_fast_math'] + version_dependent_macros + generator_flag}))
+                                   '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag}))
 
 setup(
     name='apex',
...
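
Reassembling one of the hunks above: after this change, each of the fast multihead attention extensions is declared roughly as below. This is a trimmed sketch rather than the literal setup.py text: only one source file and the nvcc flags visible in the hunks are shown, the cxx flags are an approximation, and the placeholder lists stand in for the ones built earlier in the file. The only part this commit changes is the trailing "+ cc_flag".

    from torch.utils.cpp_extension import CUDAExtension

    # Placeholder stand-ins for the lists setup.py builds earlier in the file.
    ext_modules = []
    version_dependent_macros = []
    generator_flag = []
    cc_flag = ['-gencode', 'arch=compute_80,code=sm_80']  # CUDA >= 11 case

    ext_modules.append(
        CUDAExtension(name='fast_self_multihead_attn',
                      sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn.cpp'],  # .cu sources omitted
                      extra_compile_args={'cxx': ['-O3'] + version_dependent_macros + generator_flag,
                                          'nvcc': ['-U__CUDA_NO_HALF_CONVERSIONS__',
                                                   '--expt-relaxed-constexpr',
                                                   '--expt-extended-lambda',
                                                   '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag}))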