Unverified Commit 1e0f9bc6 authored by Jithun Nair's avatar Jithun Nair Committed by GitHub
Browse files

Enable all supported CUDA extensions using --cuda_ext flag (#59)

* Use --cuda_ext flag to build all supported extensions

* Don't remove --cuda_ext since it'll be needed to build other extensions

* Need to clear all cmdline args so setup.py doesn't complain
parent 541da7a0
@@ -137,9 +137,10 @@ if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 4):
     version_ge_1_5 = ['-DVERSION_GE_1_5']
 version_dependent_macros = version_ge_1_1 + version_ge_1_3 + version_ge_1_5
-if "--distributed_adam" in sys.argv:
+if "--distributed_adam" in sys.argv or "--cuda_ext" in sys.argv:
     from torch.utils.cpp_extension import CUDAExtension
-    sys.argv.remove("--distributed_adam")
+    if "--distributed_adam" in sys.argv:
+        sys.argv.remove("--distributed_adam")
     from torch.utils.cpp_extension import BuildExtension
     cmdclass['build_ext'] = BuildExtension
@@ -158,9 +159,10 @@ if "--distributed_adam" in sys.argv:
                               extra_compile_args={'cxx': ['-O3',] + version_dependent_macros,
                                                   'nvcc':nvcc_args_adam if not IS_ROCM_PYTORCH else hipcc_args_adam}))
-if "--distributed_lamb" in sys.argv:
+if "--distributed_lamb" in sys.argv or "--cuda_ext" in sys.argv:
     from torch.utils.cpp_extension import CUDAExtension
-    sys.argv.remove("--distributed_lamb")
+    if "--distributed_lamb" in sys.argv:
+        sys.argv.remove("--distributed_lamb")
     from torch.utils.cpp_extension import BuildExtension
     cmdclass['build_ext'] = BuildExtension
@@ -181,7 +183,6 @@ if "--distributed_lamb" in sys.argv:
 if "--cuda_ext" in sys.argv:
     from torch.utils.cpp_extension import CUDAExtension
-    sys.argv.remove("--cuda_ext")
     if torch.utils.cpp_extension.CUDA_HOME is None and not IS_ROCM_PYTORCH:
         raise RuntimeError("--cuda_ext was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
@@ -238,9 +239,10 @@ if "--cuda_ext" in sys.argv:
                               extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
                                                   'nvcc':['-O3'] + version_dependent_macros}))
-if "--bnp" in sys.argv:
+if "--bnp" in sys.argv or "--cuda_ext" in sys.argv:
     from torch.utils.cpp_extension import CUDAExtension
-    sys.argv.remove("--bnp")
+    if "--bnp" in sys.argv:
+        sys.argv.remove("--bnp")
     from torch.utils.cpp_extension import BuildExtension
     cmdclass['build_ext'] = BuildExtension
@@ -262,9 +264,10 @@ if "--bnp" in sys.argv:
                                                   '-D__CUDA_NO_HALF_CONVERSIONS__',
                                                   '-D__CUDA_NO_HALF2_OPERATORS__'] + version_dependent_macros}))
-if "--xentropy" in sys.argv:
+if "--xentropy" in sys.argv or "--cuda_ext" in sys.argv:
     from torch.utils.cpp_extension import CUDAExtension
-    sys.argv.remove("--xentropy")
+    if "--xentropy" in sys.argv:
+        sys.argv.remove("--xentropy")
     from torch.utils.cpp_extension import BuildExtension
     cmdclass['build_ext'] = BuildExtension
@@ -283,9 +286,10 @@ if "--xentropy" in sys.argv:
                                                   'nvcc':['-O3'] + version_dependent_macros}))
-if "--deprecated_fused_adam" in sys.argv:
+if "--deprecated_fused_adam" in sys.argv or "--cuda_ext" in sys.argv:
     from torch.utils.cpp_extension import CUDAExtension
-    sys.argv.remove("--deprecated_fused_adam")
+    if "--deprecated_fused_adam" in sys.argv:
+        sys.argv.remove("--deprecated_fused_adam")
     from torch.utils.cpp_extension import BuildExtension
     cmdclass['build_ext'] = BuildExtension
@@ -305,9 +309,10 @@ if "--deprecated_fused_adam" in sys.argv:
                               extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
                                                   'nvcc' : nvcc_args_fused_adam if not IS_ROCM_PYTORCH else hipcc_args_fused_adam}))
-if "--deprecated_fused_lamb" in sys.argv:
+if "--deprecated_fused_lamb" in sys.argv or "--cuda_ext" in sys.argv:
     from torch.utils.cpp_extension import CUDAExtension
-    sys.argv.remove("--deprecated_fused_lamb")
+    if "--deprecated_fused_lamb" in sys.argv:
+        sys.argv.remove("--deprecated_fused_lamb")
     from torch.utils.cpp_extension import BuildExtension
     cmdclass['build_ext'] = BuildExtension
@@ -365,9 +370,10 @@ if "--fast_layer_norm" in sys.argv:
                                                   '--expt-extended-lambda',
                                                   '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag}))
-if "--fast_multihead_attn" in sys.argv:
+if "--fast_multihead_attn" in sys.argv or "--cuda_ext" in sys.argv:
     from torch.utils.cpp_extension import CUDAExtension
-    sys.argv.remove("--fast_multihead_attn")
+    if "--fast_multihead_attn" in sys.argv:
+        sys.argv.remove("--fast_multihead_attn")
     from torch.utils.cpp_extension import BuildExtension
     cmdclass['build_ext'] = BuildExtension.with_options(use_ninja=False)
@@ -465,6 +471,9 @@ if "--fast_multihead_attn" in sys.argv:
                               extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
                                                   'nvcc':nvcc_args_mha if not IS_ROCM_PYTORCH else hipcc_args_mha}))
+# All per-extension handlers above have run; now clear the flag so setup() does not see it.
+if "--cuda_ext" in sys.argv:
+    sys.argv.remove("--cuda_ext")
 setup(
     name='apex',
     version='0.1',
...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment