Unverified Commit 5d9b5cbc authored by ptrblck's avatar ptrblck Committed by GitHub
Browse files

move sm80 code inside MHA (#937)


Co-authored-by: default avatarpbialecki <pbialecki@nvidia.com>
parent 85b17833
...@@ -278,12 +278,6 @@ torch_dir = torch.__path__[0] ...@@ -278,12 +278,6 @@ torch_dir = torch.__path__[0]
if os.path.exists(os.path.join(torch_dir, 'include', 'ATen', 'CUDAGenerator.h')): if os.path.exists(os.path.join(torch_dir, 'include', 'ATen', 'CUDAGenerator.h')):
generator_flag = ['-DOLD_GENERATOR'] generator_flag = ['-DOLD_GENERATOR']
# Check, if CUDA11 is installed for compute capability 8.0
cc_flag = []
_, bare_metal_major, _ = get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
if int(bare_metal_major) >= 11:
cc_flag.append('-gencode')
cc_flag.append('arch=compute_80,code=sm_80')
if "--fast_multihead_attn" in sys.argv: if "--fast_multihead_attn" in sys.argv:
from torch.utils.cpp_extension import CUDAExtension from torch.utils.cpp_extension import CUDAExtension
...@@ -295,6 +289,13 @@ if "--fast_multihead_attn" in sys.argv: ...@@ -295,6 +289,13 @@ if "--fast_multihead_attn" in sys.argv:
if torch.utils.cpp_extension.CUDA_HOME is None: if torch.utils.cpp_extension.CUDA_HOME is None:
raise RuntimeError("--fast_multihead_attn was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.") raise RuntimeError("--fast_multihead_attn was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
else: else:
# Check, if CUDA11 is installed for compute capability 8.0
cc_flag = []
_, bare_metal_major, _ = get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
if int(bare_metal_major) >= 11:
cc_flag.append('-gencode')
cc_flag.append('arch=compute_80,code=sm_80')
subprocess.run(["git", "submodule", "update", "--init", "apex/contrib/csrc/multihead_attn/cutlass"]) subprocess.run(["git", "submodule", "update", "--init", "apex/contrib/csrc/multihead_attn/cutlass"])
ext_modules.append( ext_modules.append(
CUDAExtension(name='fast_additive_mask_softmax_dropout', CUDAExtension(name='fast_additive_mask_softmax_dropout',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment