Unverified commit 89f5722c authored by Hubert Lu, committed by GitHub

Faster build (#95)

* Remove redundant imports and enable ninja for MHA extension

* Remove redundant CUDAExtension imports
parent 5acb8d00
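
The removed per-flag imports imply that CUDAExtension and BuildExtension are now bound once at module level. A minimal sketch of what the consolidated top of setup.py would look like under that assumption (the import placement itself is not part of this diff):

# Hypothetical consolidated imports (not shown in this diff): previously
# CUDAExtension and BuildExtension were re-imported inside every "--<flag>" block.
import sys
import torch
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension

ext_modules = []   # extensions are still appended here per flag
extras = {}        # build_ext is now selected once, inside setup()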
setup.py
@@ -137,7 +137,7 @@ if TORCH_MAJOR == 0 and TORCH_MINOR < 4:
         "Apex requires Pytorch 0.4 or newer.\nThe latest stable release can be obtained from https://pytorch.org/"
     )
-cmdclass = {}
+# cmdclass = {}
 ext_modules = []
 extras = {}
@@ -146,7 +146,6 @@ if "--cpp_ext" in sys.argv or "--cuda_ext" in sys.argv:
     if TORCH_MAJOR == 0:
         raise RuntimeError("--cpp_ext requires Pytorch 1.0 or later, "
                            "found torch.__version__ = {}".format(torch.__version__))
-    cmdclass['build_ext'] = BuildExtension
 if "--cpp_ext" in sys.argv:
     sys.argv.remove("--cpp_ext")
     ext_modules.append(CppExtension("apex_C", ["csrc/flatten_unflatten.cpp"]))
@@ -168,13 +167,9 @@ if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 4):
 version_dependent_macros = version_ge_1_1 + version_ge_1_3 + version_ge_1_5
 if "--distributed_adam" in sys.argv or "--cuda_ext" in sys.argv:
-    from torch.utils.cpp_extension import CUDAExtension
     if "--distributed_adam" in sys.argv:
         sys.argv.remove("--distributed_adam")
-    from torch.utils.cpp_extension import BuildExtension
-    cmdclass['build_ext'] = BuildExtension
     if torch.utils.cpp_extension.CUDA_HOME is None and not IS_ROCM_PYTORCH:
         raise RuntimeError("--distributed_adam was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
     else:
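
A note on the "redundant" imports (context only, not part of the diff): Python caches imported modules in sys.modules, so repeating `from torch.utils.cpp_extension import CUDAExtension` inside each flag block never re-executed the module; removing the duplicates is a readability cleanup, while the actual build-time win comes from the ninja change further down. A minimal illustration:

# Re-importing a cached module is just a sys.modules lookup plus a name binding.
import sys
import torch.utils.cpp_extension                      # first import executes the module
assert "torch.utils.cpp_extension" in sys.modules     # cached from here on
from torch.utils.cpp_extension import CUDAExtension   # no re-execution, only binds the name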
@@ -190,13 +185,9 @@ if "--distributed_adam" in sys.argv or "--cuda_ext" in sys.argv:
                                            'nvcc':nvcc_args_adam if not IS_ROCM_PYTORCH else hipcc_args_adam}))
 if "--distributed_lamb" in sys.argv or "--cuda_ext" in sys.argv:
-    from torch.utils.cpp_extension import CUDAExtension
     if "--distributed_lamb" in sys.argv:
         sys.argv.remove("--distributed_lamb")
-    from torch.utils.cpp_extension import BuildExtension
-    cmdclass['build_ext'] = BuildExtension
     if torch.utils.cpp_extension.CUDA_HOME is None and not IS_ROCM_PYTORCH:
         raise RuntimeError("--distributed_lamb was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
     else:
@@ -212,8 +203,6 @@ if "--distributed_lamb" in sys.argv or "--cuda_ext" in sys.argv:
                                            'nvcc': nvcc_args_distributed_lamb if not IS_ROCM_PYTORCH else hipcc_args_distributed_lamb}))
 if "--cuda_ext" in sys.argv:
-    from torch.utils.cpp_extension import CUDAExtension
     if torch.utils.cpp_extension.CUDA_HOME is None and not IS_ROCM_PYTORCH:
         raise RuntimeError("--cuda_ext was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
     else:
@@ -311,13 +300,9 @@ if "--cuda_ext" in sys.argv:
 if "--bnp" in sys.argv or "--cuda_ext" in sys.argv:
-    from torch.utils.cpp_extension import CUDAExtension
     if "--bnp" in sys.argv:
         sys.argv.remove("--bnp")
-    from torch.utils.cpp_extension import BuildExtension
-    cmdclass['build_ext'] = BuildExtension
     if torch.utils.cpp_extension.CUDA_HOME is None and not IS_ROCM_PYTORCH:
         raise RuntimeError("--bnp was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
     else:
@@ -336,13 +321,9 @@ if "--bnp" in sys.argv or "--cuda_ext" in sys.argv:
                                            '-D__CUDA_NO_HALF2_OPERATORS__'] + version_dependent_macros}))
 if "--xentropy" in sys.argv or "--cuda_ext" in sys.argv:
-    from torch.utils.cpp_extension import CUDAExtension
     if "--xentropy" in sys.argv:
         sys.argv.remove("--xentropy")
-    from torch.utils.cpp_extension import BuildExtension
-    cmdclass['build_ext'] = BuildExtension
     if torch.utils.cpp_extension.CUDA_HOME is None and not IS_ROCM_PYTORCH:
         raise RuntimeError("--xentropy was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
     else:
@@ -393,13 +374,9 @@ if "--index_mul_2d" in sys.argv or "--cuda_ext" in sys.argv:
     )
 if "--deprecated_fused_adam" in sys.argv or "--cuda_ext" in sys.argv:
-    from torch.utils.cpp_extension import CUDAExtension
     if "--deprecated_fused_adam" in sys.argv:
         sys.argv.remove("--deprecated_fused_adam")
-    from torch.utils.cpp_extension import BuildExtension
-    cmdclass['build_ext'] = BuildExtension
     if torch.utils.cpp_extension.CUDA_HOME is None and not IS_ROCM_PYTORCH:
         raise RuntimeError("--deprecated_fused_adam was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
     else:
@@ -416,13 +393,9 @@ if "--deprecated_fused_adam" in sys.argv or "--cuda_ext" in sys.argv:
                                            'nvcc' : nvcc_args_fused_adam if not IS_ROCM_PYTORCH else hipcc_args_fused_adam}))
 if "--deprecated_fused_lamb" in sys.argv or "--cuda_ext" in sys.argv:
-    from torch.utils.cpp_extension import CUDAExtension
     if "--deprecated_fused_lamb" in sys.argv:
         sys.argv.remove("--deprecated_fused_lamb")
-    from torch.utils.cpp_extension import BuildExtension
-    cmdclass['build_ext'] = BuildExtension
     if torch.utils.cpp_extension.CUDA_HOME is None and not IS_ROCM_PYTORCH:
         raise RuntimeError("--deprecated_fused_lamb was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
     else:
@@ -511,13 +484,9 @@ if "--fmha" in sys.argv:
 if "--fast_multihead_attn" in sys.argv or "--cuda_ext" in sys.argv:
-    from torch.utils.cpp_extension import CUDAExtension
     if "--fast_multihead_attn" in sys.argv:
         sys.argv.remove("--fast_multihead_attn")
-    from torch.utils.cpp_extension import BuildExtension
-    cmdclass['build_ext'] = BuildExtension.with_options(use_ninja=False)
     if torch.utils.cpp_extension.CUDA_HOME is None and not IS_ROCM_PYTORCH:
         raise RuntimeError("--fast_multihead_attn was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
     else:
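
The fast_multihead_attn block previously forced `BuildExtension.with_options(use_ninja=False)`, which falls back to the sequential setuptools/distutils compile path. With that override removed, the extension builds through the same BuildExtension as everything else, which uses ninja when it is available and compiles source files in parallel. A minimal sketch of the two behaviours (the names below mirror the PyTorch API, not this repository's code):

from torch.utils.cpp_extension import BuildExtension

# Default: ninja is used if installed; parallelism can be capped via the
# MAX_JOBS environment variable that BuildExtension honours.
cmdclass_fast = {'build_ext': BuildExtension}

# Old behaviour for this extension: ninja explicitly disabled, so sources
# were compiled one translation unit at a time.
cmdclass_slow = {'build_ext': BuildExtension.with_options(use_ninja=False)}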
@@ -688,7 +657,6 @@ setup(
     ),
     description="PyTorch Extensions written by NVIDIA",
     ext_modules=ext_modules,
-    cmdclass=cmdclass,
-    #cmdclass={'build_ext': BuildExtension} if ext_modules else {},
+    cmdclass={'build_ext': BuildExtension} if ext_modules else {},
     extras_require=extras,
 )
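
Taken together, the per-extension blocks follow a single pattern, condensed below as an illustrative sketch (the flag and source names are hypothetical, chosen to match the style of this setup.py rather than copied from any one block):

import sys
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

ext_modules = []
if "--some_ext" in sys.argv or "--cuda_ext" in sys.argv:    # hypothetical flag
    if "--some_ext" in sys.argv:
        sys.argv.remove("--some_ext")                       # consume the flag before setuptools parses argv
    ext_modules.append(CUDAExtension("some_ext", ["csrc/some_ext.cpp"]))  # illustrative sources

# build_ext is only registered when at least one extension was requested.
cmdclass = {'build_ext': BuildExtension} if ext_modules else {}

With this arrangement an invocation such as `python setup.py install --cpp_ext --cuda_ext` still works as before: the custom flags are stripped from sys.argv before setuptools sees its own arguments, and the ninja-backed BuildExtension is attached only when there is something to compile.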