Unverified Commit 89f5722c authored by Hubert Lu's avatar Hubert Lu Committed by GitHub
Browse files

Faster build (#95)

* Remove redundant import's and enable ninja for MHA extension

* Remove redundant CUDAExtension import's
parent 5acb8d00
......@@ -137,7 +137,7 @@ if TORCH_MAJOR == 0 and TORCH_MINOR < 4:
"Apex requires Pytorch 0.4 or newer.\nThe latest stable release can be obtained from https://pytorch.org/"
)
cmdclass = {}
# cmdclass = {}
ext_modules = []
extras = {}
......@@ -146,7 +146,6 @@ if "--cpp_ext" in sys.argv or "--cuda_ext" in sys.argv:
if TORCH_MAJOR == 0:
raise RuntimeError("--cpp_ext requires Pytorch 1.0 or later, "
"found torch.__version__ = {}".format(torch.__version__))
cmdclass['build_ext'] = BuildExtension
if "--cpp_ext" in sys.argv:
sys.argv.remove("--cpp_ext")
ext_modules.append(CppExtension("apex_C", ["csrc/flatten_unflatten.cpp"]))
......@@ -168,13 +167,9 @@ if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 4):
version_dependent_macros = version_ge_1_1 + version_ge_1_3 + version_ge_1_5
if "--distributed_adam" in sys.argv or "--cuda_ext" in sys.argv:
from torch.utils.cpp_extension import CUDAExtension
if "--distributed_adam" in sys.argv:
sys.argv.remove("--distributed_adam")
from torch.utils.cpp_extension import BuildExtension
cmdclass['build_ext'] = BuildExtension
if torch.utils.cpp_extension.CUDA_HOME is None and not IS_ROCM_PYTORCH:
raise RuntimeError("--distributed_adam was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
else:
......@@ -190,13 +185,9 @@ if "--distributed_adam" in sys.argv or "--cuda_ext" in sys.argv:
'nvcc':nvcc_args_adam if not IS_ROCM_PYTORCH else hipcc_args_adam}))
if "--distributed_lamb" in sys.argv or "--cuda_ext" in sys.argv:
from torch.utils.cpp_extension import CUDAExtension
if "--distributed_lamb" in sys.argv:
sys.argv.remove("--distributed_lamb")
from torch.utils.cpp_extension import BuildExtension
cmdclass['build_ext'] = BuildExtension
if torch.utils.cpp_extension.CUDA_HOME is None and not IS_ROCM_PYTORCH:
raise RuntimeError("--distributed_lamb was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
else:
......@@ -212,8 +203,6 @@ if "--distributed_lamb" in sys.argv or "--cuda_ext" in sys.argv:
'nvcc': nvcc_args_distributed_lamb if not IS_ROCM_PYTORCH else hipcc_args_distributed_lamb}))
if "--cuda_ext" in sys.argv:
from torch.utils.cpp_extension import CUDAExtension
if torch.utils.cpp_extension.CUDA_HOME is None and not IS_ROCM_PYTORCH:
raise RuntimeError("--cuda_ext was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
else:
......@@ -311,13 +300,9 @@ if "--cuda_ext" in sys.argv:
if "--bnp" in sys.argv or "--cuda_ext" in sys.argv:
from torch.utils.cpp_extension import CUDAExtension
if "--bnp" in sys.argv:
sys.argv.remove("--bnp")
from torch.utils.cpp_extension import BuildExtension
cmdclass['build_ext'] = BuildExtension
if torch.utils.cpp_extension.CUDA_HOME is None and not IS_ROCM_PYTORCH:
raise RuntimeError("--bnp was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
else:
......@@ -336,13 +321,9 @@ if "--bnp" in sys.argv or "--cuda_ext" in sys.argv:
'-D__CUDA_NO_HALF2_OPERATORS__'] + version_dependent_macros}))
if "--xentropy" in sys.argv or "--cuda_ext" in sys.argv:
from torch.utils.cpp_extension import CUDAExtension
if "--xentropy" in sys.argv:
sys.argv.remove("--xentropy")
from torch.utils.cpp_extension import BuildExtension
cmdclass['build_ext'] = BuildExtension
if torch.utils.cpp_extension.CUDA_HOME is None and not IS_ROCM_PYTORCH:
raise RuntimeError("--xentropy was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
else:
......@@ -393,13 +374,9 @@ if "--index_mul_2d" in sys.argv or "--cuda_ext" in sys.argv:
)
if "--deprecated_fused_adam" in sys.argv or "--cuda_ext" in sys.argv:
from torch.utils.cpp_extension import CUDAExtension
if "--deprecated_fused_adam" in sys.argv:
sys.argv.remove("--deprecated_fused_adam")
from torch.utils.cpp_extension import BuildExtension
cmdclass['build_ext'] = BuildExtension
if torch.utils.cpp_extension.CUDA_HOME is None and not IS_ROCM_PYTORCH:
raise RuntimeError("--deprecated_fused_adam was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
else:
......@@ -416,13 +393,9 @@ if "--deprecated_fused_adam" in sys.argv or "--cuda_ext" in sys.argv:
'nvcc' : nvcc_args_fused_adam if not IS_ROCM_PYTORCH else hipcc_args_fused_adam}))
if "--deprecated_fused_lamb" in sys.argv or "--cuda_ext" in sys.argv:
from torch.utils.cpp_extension import CUDAExtension
if "--deprecated_fused_lamb" in sys.argv:
sys.argv.remove("--deprecated_fused_lamb")
from torch.utils.cpp_extension import BuildExtension
cmdclass['build_ext'] = BuildExtension
if torch.utils.cpp_extension.CUDA_HOME is None and not IS_ROCM_PYTORCH:
raise RuntimeError("--deprecated_fused_lamb was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
else:
......@@ -511,13 +484,9 @@ if "--fmha" in sys.argv:
if "--fast_multihead_attn" in sys.argv or "--cuda_ext" in sys.argv:
from torch.utils.cpp_extension import CUDAExtension
if "--fast_multihead_attn" in sys.argv:
sys.argv.remove("--fast_multihead_attn")
from torch.utils.cpp_extension import BuildExtension
cmdclass['build_ext'] = BuildExtension.with_options(use_ninja=False)
if torch.utils.cpp_extension.CUDA_HOME is None and not IS_ROCM_PYTORCH:
raise RuntimeError("--fast_multihead_attn was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
else:
......@@ -688,7 +657,6 @@ setup(
),
description="PyTorch Extensions written by NVIDIA",
ext_modules=ext_modules,
cmdclass=cmdclass,
#cmdclass={'build_ext': BuildExtension} if ext_modules else {},
cmdclass={'build_ext': BuildExtension} if ext_modules else {},
extras_require=extras,
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment