Unverified Commit 9ce0a10f authored by Masaki Kozuki's avatar Masaki Kozuki Committed by GitHub
Browse files

enable ninja (#1164)

- passing include directories to `CUDAExtension`'s `include_dirs` argument
- removing `-I/path/to/dir` arguments from `extra_compile_args`
parent 54b93919
import torch import torch
from torch.utils import cpp_extension from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME
from setuptools import setup, find_packages from setuptools import setup, find_packages
import subprocess import subprocess
...@@ -32,7 +32,7 @@ if not torch.cuda.is_available(): ...@@ -32,7 +32,7 @@ if not torch.cuda.is_available():
'If you wish to cross-compile for a single specific architecture,\n' 'If you wish to cross-compile for a single specific architecture,\n'
'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n') 'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n')
if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None: if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None:
_, bare_metal_major, _ = get_cuda_bare_metal_version(cpp_extension.CUDA_HOME) _, bare_metal_major, _ = get_cuda_bare_metal_version(CUDA_HOME)
if int(bare_metal_major) == 11: if int(bare_metal_major) == 11:
os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0" os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0"
else: else:
...@@ -70,11 +70,8 @@ if "--cpp_ext" in sys.argv or "--cuda_ext" in sys.argv: ...@@ -70,11 +70,8 @@ if "--cpp_ext" in sys.argv or "--cuda_ext" in sys.argv:
if TORCH_MAJOR == 0: if TORCH_MAJOR == 0:
raise RuntimeError("--cpp_ext requires Pytorch 1.0 or later, " raise RuntimeError("--cpp_ext requires Pytorch 1.0 or later, "
"found torch.__version__ = {}".format(torch.__version__)) "found torch.__version__ = {}".format(torch.__version__))
from torch.utils.cpp_extension import BuildExtension
cmdclass['build_ext'] = BuildExtension
if "--cpp_ext" in sys.argv: if "--cpp_ext" in sys.argv:
from torch.utils.cpp_extension import CppExtension
sys.argv.remove("--cpp_ext") sys.argv.remove("--cpp_ext")
ext_modules.append( ext_modules.append(
CppExtension('apex_C', CppExtension('apex_C',
...@@ -124,13 +121,9 @@ if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 4): ...@@ -124,13 +121,9 @@ if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 4):
version_dependent_macros = version_ge_1_1 + version_ge_1_3 + version_ge_1_5 version_dependent_macros = version_ge_1_1 + version_ge_1_3 + version_ge_1_5
if "--distributed_adam" in sys.argv: if "--distributed_adam" in sys.argv:
from torch.utils.cpp_extension import CUDAExtension
sys.argv.remove("--distributed_adam") sys.argv.remove("--distributed_adam")
from torch.utils.cpp_extension import BuildExtension if CUDA_HOME is None:
cmdclass['build_ext'] = BuildExtension
if torch.utils.cpp_extension.CUDA_HOME is None:
raise RuntimeError("--distributed_adam was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.") raise RuntimeError("--distributed_adam was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
else: else:
ext_modules.append( ext_modules.append(
...@@ -143,13 +136,9 @@ if "--distributed_adam" in sys.argv: ...@@ -143,13 +136,9 @@ if "--distributed_adam" in sys.argv:
'--use_fast_math'] + version_dependent_macros})) '--use_fast_math'] + version_dependent_macros}))
if "--distributed_lamb" in sys.argv: if "--distributed_lamb" in sys.argv:
from torch.utils.cpp_extension import CUDAExtension
sys.argv.remove("--distributed_lamb") sys.argv.remove("--distributed_lamb")
from torch.utils.cpp_extension import BuildExtension if CUDA_HOME is None:
cmdclass['build_ext'] = BuildExtension
if torch.utils.cpp_extension.CUDA_HOME is None:
raise RuntimeError("--distributed_lamb was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.") raise RuntimeError("--distributed_lamb was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
else: else:
ext_modules.append( ext_modules.append(
...@@ -162,13 +151,12 @@ if "--distributed_lamb" in sys.argv: ...@@ -162,13 +151,12 @@ if "--distributed_lamb" in sys.argv:
'--use_fast_math'] + version_dependent_macros})) '--use_fast_math'] + version_dependent_macros}))
if "--cuda_ext" in sys.argv: if "--cuda_ext" in sys.argv:
from torch.utils.cpp_extension import CUDAExtension
sys.argv.remove("--cuda_ext") sys.argv.remove("--cuda_ext")
if torch.utils.cpp_extension.CUDA_HOME is None: if CUDA_HOME is None:
raise RuntimeError("--cuda_ext was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.") raise RuntimeError("--cuda_ext was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
else: else:
check_cuda_torch_binary_vs_bare_metal(torch.utils.cpp_extension.CUDA_HOME) check_cuda_torch_binary_vs_bare_metal(CUDA_HOME)
ext_modules.append( ext_modules.append(
CUDAExtension(name='amp_C', CUDAExtension(name='amp_C',
...@@ -219,13 +207,9 @@ if "--cuda_ext" in sys.argv: ...@@ -219,13 +207,9 @@ if "--cuda_ext" in sys.argv:
'nvcc':['-O3'] + version_dependent_macros})) 'nvcc':['-O3'] + version_dependent_macros}))
if "--bnp" in sys.argv: if "--bnp" in sys.argv:
from torch.utils.cpp_extension import CUDAExtension
sys.argv.remove("--bnp") sys.argv.remove("--bnp")
from torch.utils.cpp_extension import BuildExtension if CUDA_HOME is None:
cmdclass['build_ext'] = BuildExtension
if torch.utils.cpp_extension.CUDA_HOME is None:
raise RuntimeError("--bnp was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.") raise RuntimeError("--bnp was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
else: else:
ext_modules.append( ext_modules.append(
...@@ -242,13 +226,9 @@ if "--bnp" in sys.argv: ...@@ -242,13 +226,9 @@ if "--bnp" in sys.argv:
'-D__CUDA_NO_HALF2_OPERATORS__'] + version_dependent_macros})) '-D__CUDA_NO_HALF2_OPERATORS__'] + version_dependent_macros}))
if "--xentropy" in sys.argv: if "--xentropy" in sys.argv:
from torch.utils.cpp_extension import CUDAExtension
sys.argv.remove("--xentropy") sys.argv.remove("--xentropy")
from torch.utils.cpp_extension import BuildExtension if CUDA_HOME is None:
cmdclass['build_ext'] = BuildExtension
if torch.utils.cpp_extension.CUDA_HOME is None:
raise RuntimeError("--xentropy was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.") raise RuntimeError("--xentropy was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
else: else:
ext_modules.append( ext_modules.append(
...@@ -260,13 +240,9 @@ if "--xentropy" in sys.argv: ...@@ -260,13 +240,9 @@ if "--xentropy" in sys.argv:
'nvcc':['-O3'] + version_dependent_macros})) 'nvcc':['-O3'] + version_dependent_macros}))
if "--deprecated_fused_adam" in sys.argv: if "--deprecated_fused_adam" in sys.argv:
from torch.utils.cpp_extension import CUDAExtension
sys.argv.remove("--deprecated_fused_adam") sys.argv.remove("--deprecated_fused_adam")
from torch.utils.cpp_extension import BuildExtension if CUDA_HOME is None:
cmdclass['build_ext'] = BuildExtension
if torch.utils.cpp_extension.CUDA_HOME is None:
raise RuntimeError("--deprecated_fused_adam was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.") raise RuntimeError("--deprecated_fused_adam was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
else: else:
ext_modules.append( ext_modules.append(
...@@ -279,13 +255,9 @@ if "--deprecated_fused_adam" in sys.argv: ...@@ -279,13 +255,9 @@ if "--deprecated_fused_adam" in sys.argv:
'--use_fast_math'] + version_dependent_macros})) '--use_fast_math'] + version_dependent_macros}))
if "--deprecated_fused_lamb" in sys.argv: if "--deprecated_fused_lamb" in sys.argv:
from torch.utils.cpp_extension import CUDAExtension
sys.argv.remove("--deprecated_fused_lamb") sys.argv.remove("--deprecated_fused_lamb")
from torch.utils.cpp_extension import BuildExtension if CUDA_HOME is None:
cmdclass['build_ext'] = BuildExtension
if torch.utils.cpp_extension.CUDA_HOME is None:
raise RuntimeError("--deprecated_fused_lamb was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.") raise RuntimeError("--deprecated_fused_lamb was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
else: else:
ext_modules.append( ext_modules.append(
...@@ -305,18 +277,14 @@ if os.path.exists(os.path.join(torch_dir, 'include', 'ATen', 'CUDAGenerator.h')) ...@@ -305,18 +277,14 @@ if os.path.exists(os.path.join(torch_dir, 'include', 'ATen', 'CUDAGenerator.h'))
generator_flag = ['-DOLD_GENERATOR'] generator_flag = ['-DOLD_GENERATOR']
if "--fast_layer_norm" in sys.argv: if "--fast_layer_norm" in sys.argv:
from torch.utils.cpp_extension import CUDAExtension
sys.argv.remove("--fast_layer_norm") sys.argv.remove("--fast_layer_norm")
from torch.utils.cpp_extension import BuildExtension if CUDA_HOME is None:
cmdclass['build_ext'] = BuildExtension.with_options(use_ninja=False)
if torch.utils.cpp_extension.CUDA_HOME is None:
raise RuntimeError("--fast_layer_norm was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.") raise RuntimeError("--fast_layer_norm was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
else: else:
# Check, if CUDA11 is installed for compute capability 8.0 # Check, if CUDA11 is installed for compute capability 8.0
cc_flag = [] cc_flag = []
_, bare_metal_major, _ = get_cuda_bare_metal_version(cpp_extension.CUDA_HOME) _, bare_metal_major, _ = get_cuda_bare_metal_version(CUDA_HOME)
if int(bare_metal_major) >= 11: if int(bare_metal_major) >= 11:
cc_flag.append('-gencode') cc_flag.append('-gencode')
cc_flag.append('arch=compute_80,code=sm_80') cc_flag.append('arch=compute_80,code=sm_80')
...@@ -332,23 +300,19 @@ if "--fast_layer_norm" in sys.argv: ...@@ -332,23 +300,19 @@ if "--fast_layer_norm" in sys.argv:
'-gencode', 'arch=compute_70,code=sm_70', '-gencode', 'arch=compute_70,code=sm_70',
'-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF_CONVERSIONS__',
'-I./apex/contrib/csrc/layer_norm/',
'--expt-relaxed-constexpr', '--expt-relaxed-constexpr',
'--expt-extended-lambda', '--expt-extended-lambda',
'--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag})) '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag},
include_dirs=[os.path.join(this_dir, "apex/contrib/csrc/layer_norm")]))
if "--fmha" in sys.argv: if "--fmha" in sys.argv:
from torch.utils.cpp_extension import CUDAExtension
sys.argv.remove("--fmha") sys.argv.remove("--fmha")
from torch.utils.cpp_extension import BuildExtension if CUDA_HOME is None:
cmdclass['build_ext'] = BuildExtension.with_options(use_ninja=False)
if torch.utils.cpp_extension.CUDA_HOME is None:
raise RuntimeError("--fmha was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.") raise RuntimeError("--fmha was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
else: else:
# Check, if CUDA11 is installed for compute capability 8.0 # Check, if CUDA11 is installed for compute capability 8.0
cc_flag = [] cc_flag = []
_, bare_metal_major, _ = get_cuda_bare_metal_version(cpp_extension.CUDA_HOME) _, bare_metal_major, _ = get_cuda_bare_metal_version(CUDA_HOME)
if int(bare_metal_major) < 11: if int(bare_metal_major) < 11:
raise RuntimeError("--fmha only supported on SM80") raise RuntimeError("--fmha only supported on SM80")
...@@ -367,32 +331,26 @@ if "--fmha" in sys.argv: ...@@ -367,32 +331,26 @@ if "--fmha" in sys.argv:
'apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_512_64_kernel.sm80.cu', 'apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_512_64_kernel.sm80.cu',
], ],
extra_compile_args={'cxx': ['-O3', extra_compile_args={'cxx': ['-O3',
'-I./apex/contrib/csrc/fmha/src',
] + version_dependent_macros + generator_flag, ] + version_dependent_macros + generator_flag,
'nvcc':['-O3', 'nvcc':['-O3',
'-gencode', 'arch=compute_80,code=sm_80', '-gencode', 'arch=compute_80,code=sm_80',
'-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF_CONVERSIONS__',
'-I./apex/contrib/csrc/',
'-I./apex/contrib/csrc/fmha/src',
'--expt-relaxed-constexpr', '--expt-relaxed-constexpr',
'--expt-extended-lambda', '--expt-extended-lambda',
'--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag})) '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag},
include_dirs=[os.path.join(this_dir, "apex/contrib/csrc"), os.path.join(this_dir, "apex/contrib/csrc/fmha/src")]))
if "--fast_multihead_attn" in sys.argv: if "--fast_multihead_attn" in sys.argv:
from torch.utils.cpp_extension import CUDAExtension
sys.argv.remove("--fast_multihead_attn") sys.argv.remove("--fast_multihead_attn")
from torch.utils.cpp_extension import BuildExtension if CUDA_HOME is None:
cmdclass['build_ext'] = BuildExtension.with_options(use_ninja=False)
if torch.utils.cpp_extension.CUDA_HOME is None:
raise RuntimeError("--fast_multihead_attn was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.") raise RuntimeError("--fast_multihead_attn was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
else: else:
# Check, if CUDA11 is installed for compute capability 8.0 # Check, if CUDA11 is installed for compute capability 8.0
cc_flag = [] cc_flag = []
_, bare_metal_major, _ = get_cuda_bare_metal_version(cpp_extension.CUDA_HOME) _, bare_metal_major, _ = get_cuda_bare_metal_version(CUDA_HOME)
if int(bare_metal_major) >= 11: if int(bare_metal_major) >= 11:
cc_flag.append('-gencode') cc_flag.append('-gencode')
cc_flag.append('arch=compute_80,code=sm_80') cc_flag.append('arch=compute_80,code=sm_80')
...@@ -405,12 +363,12 @@ if "--fast_multihead_attn" in sys.argv: ...@@ -405,12 +363,12 @@ if "--fast_multihead_attn" in sys.argv:
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag, extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
'nvcc':['-O3', 'nvcc':['-O3',
'-gencode', 'arch=compute_70,code=sm_70', '-gencode', 'arch=compute_70,code=sm_70',
'-I./apex/contrib/csrc/multihead_attn/cutlass/',
'-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr', '--expt-relaxed-constexpr',
'--expt-extended-lambda', '--expt-extended-lambda',
'--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag})) '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag},
include_dirs=[os.path.join(this_dir, "apex/contrib/csrc/multihead_attn/cutlass")]))
ext_modules.append( ext_modules.append(
CUDAExtension(name='fast_mask_softmax_dropout', CUDAExtension(name='fast_mask_softmax_dropout',
sources=['apex/contrib/csrc/multihead_attn/masked_softmax_dropout.cpp', sources=['apex/contrib/csrc/multihead_attn/masked_softmax_dropout.cpp',
...@@ -418,12 +376,12 @@ if "--fast_multihead_attn" in sys.argv: ...@@ -418,12 +376,12 @@ if "--fast_multihead_attn" in sys.argv:
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag, extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
'nvcc':['-O3', 'nvcc':['-O3',
'-gencode', 'arch=compute_70,code=sm_70', '-gencode', 'arch=compute_70,code=sm_70',
'-I./apex/contrib/csrc/multihead_attn/cutlass/',
'-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr', '--expt-relaxed-constexpr',
'--expt-extended-lambda', '--expt-extended-lambda',
'--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag})) '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag},
include_dirs=[os.path.join(this_dir, "apex/contrib/csrc/multihead_attn/cutlass")]))
ext_modules.append( ext_modules.append(
CUDAExtension(name='fast_self_multihead_attn_bias_additive_mask', CUDAExtension(name='fast_self_multihead_attn_bias_additive_mask',
sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask.cpp', sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask.cpp',
...@@ -431,12 +389,12 @@ if "--fast_multihead_attn" in sys.argv: ...@@ -431,12 +389,12 @@ if "--fast_multihead_attn" in sys.argv:
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag, extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
'nvcc':['-O3', 'nvcc':['-O3',
'-gencode', 'arch=compute_70,code=sm_70', '-gencode', 'arch=compute_70,code=sm_70',
'-I./apex/contrib/csrc/multihead_attn/cutlass/',
'-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr', '--expt-relaxed-constexpr',
'--expt-extended-lambda', '--expt-extended-lambda',
'--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag})) '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag},
include_dirs=[os.path.join(this_dir, "apex/contrib/csrc/multihead_attn/cutlass")]))
ext_modules.append( ext_modules.append(
CUDAExtension(name='fast_self_multihead_attn_bias', CUDAExtension(name='fast_self_multihead_attn_bias',
sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn_bias.cpp', sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn_bias.cpp',
...@@ -444,12 +402,12 @@ if "--fast_multihead_attn" in sys.argv: ...@@ -444,12 +402,12 @@ if "--fast_multihead_attn" in sys.argv:
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag, extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
'nvcc':['-O3', 'nvcc':['-O3',
'-gencode', 'arch=compute_70,code=sm_70', '-gencode', 'arch=compute_70,code=sm_70',
'-I./apex/contrib/csrc/multihead_attn/cutlass/',
'-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr', '--expt-relaxed-constexpr',
'--expt-extended-lambda', '--expt-extended-lambda',
'--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag})) '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag},
include_dirs=[os.path.join(this_dir, "apex/contrib/csrc/multihead_attn/cutlass")]))
ext_modules.append( ext_modules.append(
CUDAExtension(name='fast_self_multihead_attn', CUDAExtension(name='fast_self_multihead_attn',
sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn.cpp', sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn.cpp',
...@@ -457,12 +415,12 @@ if "--fast_multihead_attn" in sys.argv: ...@@ -457,12 +415,12 @@ if "--fast_multihead_attn" in sys.argv:
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag, extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
'nvcc':['-O3', 'nvcc':['-O3',
'-gencode', 'arch=compute_70,code=sm_70', '-gencode', 'arch=compute_70,code=sm_70',
'-I./apex/contrib/csrc/multihead_attn/cutlass/',
'-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr', '--expt-relaxed-constexpr',
'--expt-extended-lambda', '--expt-extended-lambda',
'--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag})) '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag},
include_dirs=[os.path.join(this_dir, "apex/contrib/csrc/multihead_attn/cutlass")]))
ext_modules.append( ext_modules.append(
CUDAExtension(name='fast_self_multihead_attn_norm_add', CUDAExtension(name='fast_self_multihead_attn_norm_add',
sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add.cpp', sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add.cpp',
...@@ -470,12 +428,12 @@ if "--fast_multihead_attn" in sys.argv: ...@@ -470,12 +428,12 @@ if "--fast_multihead_attn" in sys.argv:
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag, extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
'nvcc':['-O3', 'nvcc':['-O3',
'-gencode', 'arch=compute_70,code=sm_70', '-gencode', 'arch=compute_70,code=sm_70',
'-I./apex/contrib/csrc/multihead_attn/cutlass/',
'-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr', '--expt-relaxed-constexpr',
'--expt-extended-lambda', '--expt-extended-lambda',
'--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag})) '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag},
include_dirs=[os.path.join(this_dir, "apex/contrib/csrc/multihead_attn/cutlass")]))
ext_modules.append( ext_modules.append(
CUDAExtension(name='fast_encdec_multihead_attn', CUDAExtension(name='fast_encdec_multihead_attn',
sources=['apex/contrib/csrc/multihead_attn/encdec_multihead_attn.cpp', sources=['apex/contrib/csrc/multihead_attn/encdec_multihead_attn.cpp',
...@@ -483,12 +441,12 @@ if "--fast_multihead_attn" in sys.argv: ...@@ -483,12 +441,12 @@ if "--fast_multihead_attn" in sys.argv:
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag, extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
'nvcc':['-O3', 'nvcc':['-O3',
'-gencode', 'arch=compute_70,code=sm_70', '-gencode', 'arch=compute_70,code=sm_70',
'-I./apex/contrib/csrc/multihead_attn/cutlass/',
'-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr', '--expt-relaxed-constexpr',
'--expt-extended-lambda', '--expt-extended-lambda',
'--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag})) '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag},
include_dirs=[os.path.join(this_dir, "apex/contrib/csrc/multihead_attn/cutlass")]))
ext_modules.append( ext_modules.append(
CUDAExtension(name='fast_encdec_multihead_attn_norm_add', CUDAExtension(name='fast_encdec_multihead_attn_norm_add',
sources=['apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add.cpp', sources=['apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add.cpp',
...@@ -496,31 +454,26 @@ if "--fast_multihead_attn" in sys.argv: ...@@ -496,31 +454,26 @@ if "--fast_multihead_attn" in sys.argv:
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag, extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
'nvcc':['-O3', 'nvcc':['-O3',
'-gencode', 'arch=compute_70,code=sm_70', '-gencode', 'arch=compute_70,code=sm_70',
'-I./apex/contrib/csrc/multihead_attn/cutlass/',
'-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr', '--expt-relaxed-constexpr',
'--expt-extended-lambda', '--expt-extended-lambda',
'--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag})) '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag},
include_dirs=[os.path.join(this_dir, "apex/contrib/csrc/multihead_attn/cutlass")]))
if "--transducer" in sys.argv: if "--transducer" in sys.argv:
from torch.utils.cpp_extension import CUDAExtension
sys.argv.remove("--transducer") sys.argv.remove("--transducer")
from torch.utils.cpp_extension import BuildExtension if CUDA_HOME is None:
cmdclass['build_ext'] = BuildExtension.with_options(use_ninja=False)
if torch.utils.cpp_extension.CUDA_HOME is None:
raise RuntimeError("--transducer was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.") raise RuntimeError("--transducer was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
else: else:
ext_modules.append( ext_modules.append(
CUDAExtension(name='transducer_joint_cuda', CUDAExtension(name='transducer_joint_cuda',
sources=['apex/contrib/csrc/transducer/transducer_joint.cpp', sources=['apex/contrib/csrc/transducer/transducer_joint.cpp',
'apex/contrib/csrc/transducer/transducer_joint_kernel.cu'], 'apex/contrib/csrc/transducer/transducer_joint_kernel.cu'],
include_dirs=[os.path.join(this_dir, 'csrc')],
extra_compile_args={'cxx': ['-O3'] + version_dependent_macros, extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
'nvcc':['-O3', 'nvcc': ['-O3'] + version_dependent_macros},
'-I./apex/contrib/csrc/multihead_attn/'] + version_dependent_macros})) include_dirs=[os.path.join(this_dir, 'csrc'), os.path.join(this_dir, "apex/contrib/csrc/multihead_attn")]))
ext_modules.append( ext_modules.append(
CUDAExtension(name='transducer_loss_cuda', CUDAExtension(name='transducer_loss_cuda',
sources=['apex/contrib/csrc/transducer/transducer_loss.cpp', sources=['apex/contrib/csrc/transducer/transducer_loss.cpp',
...@@ -530,20 +483,16 @@ if "--transducer" in sys.argv: ...@@ -530,20 +483,16 @@ if "--transducer" in sys.argv:
'nvcc':['-O3'] + version_dependent_macros})) 'nvcc':['-O3'] + version_dependent_macros}))
if "--fast_bottleneck" in sys.argv: if "--fast_bottleneck" in sys.argv:
from torch.utils.cpp_extension import CUDAExtension
sys.argv.remove("--fast_bottleneck") sys.argv.remove("--fast_bottleneck")
from torch.utils.cpp_extension import BuildExtension if CUDA_HOME is None:
cmdclass['build_ext'] = BuildExtension.with_options(use_ninja=False)
if torch.utils.cpp_extension.CUDA_HOME is None:
raise RuntimeError("--fast_bottleneck was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.") raise RuntimeError("--fast_bottleneck was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
else: else:
subprocess.run(["git", "submodule", "update", "--init", "apex/contrib/csrc/cudnn-frontend/"]) subprocess.run(["git", "submodule", "update", "--init", "apex/contrib/csrc/cudnn-frontend/"])
ext_modules.append( ext_modules.append(
CUDAExtension(name='fast_bottleneck', CUDAExtension(name='fast_bottleneck',
sources=['apex/contrib/csrc/bottleneck/bottleneck.cpp'], sources=['apex/contrib/csrc/bottleneck/bottleneck.cpp'],
include_dirs=['apex/contrib/csrc/cudnn-frontend/include'], include_dirs=[os.path.join(this_dir, 'apex/contrib/csrc/cudnn-frontend/include')],
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag})) extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag}))
setup( setup(
...@@ -560,6 +509,6 @@ setup( ...@@ -560,6 +509,6 @@ setup(
'apex.egg-info',)), 'apex.egg-info',)),
description='PyTorch Extensions written by NVIDIA', description='PyTorch Extensions written by NVIDIA',
ext_modules=ext_modules, ext_modules=ext_modules,
cmdclass=cmdclass, cmdclass={'build_ext': BuildExtension} if ext_modules else {},
extras_require=extras, extras_require=extras,
) )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment