Commit 2332c4d6 authored by Jeff Daily's avatar Jeff Daily
Browse files

update setup.py to more closely align with upstream

Mostly whitespace or formatting issues addressed.
Diff with upstream is reduced; ROCm changes are more clear.
parent dcc7b513
...@@ -114,6 +114,8 @@ def check_if_rocm_pytorch(): ...@@ -114,6 +114,8 @@ def check_if_rocm_pytorch():
return is_rocm_pytorch return is_rocm_pytorch
IS_ROCM_PYTORCH = check_if_rocm_pytorch()
# Set up macros for forward/backward compatibility hack around # Set up macros for forward/backward compatibility hack around
# https://github.com/pytorch/pytorch/commit/4404762d7dd955383acee92e6f06b48144a0742e # https://github.com/pytorch/pytorch/commit/4404762d7dd955383acee92e6f06b48144a0742e
# and # and
...@@ -172,15 +174,10 @@ if "--cuda_ext" in sys.argv: ...@@ -172,15 +174,10 @@ if "--cuda_ext" in sys.argv:
from torch.utils.cpp_extension import CUDAExtension from torch.utils.cpp_extension import CUDAExtension
sys.argv.remove("--cuda_ext") sys.argv.remove("--cuda_ext")
is_rocm_pytorch = False if torch.utils.cpp_extension.CUDA_HOME is None and not IS_ROCM_PYTORCH:
if torch.__version__ >= '1.5':
from torch.utils.cpp_extension import ROCM_HOME
is_rocm_pytorch = True if ((torch.version.hip is not None) and (ROCM_HOME is not None)) else False
if torch.utils.cpp_extension.CUDA_HOME is None and (not is_rocm_pytorch):
raise RuntimeError("--cuda_ext was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.") raise RuntimeError("--cuda_ext was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
else: else:
if not is_rocm_pytorch: if not IS_ROCM_PYTORCH:
check_cuda_torch_binary_vs_bare_metal(torch.utils.cpp_extension.CUDA_HOME) check_cuda_torch_binary_vs_bare_metal(torch.utils.cpp_extension.CUDA_HOME)
print ("INFO: Building the multi-tensor apply extension.") print ("INFO: Building the multi-tensor apply extension.")
...@@ -199,17 +196,18 @@ if "--cuda_ext" in sys.argv: ...@@ -199,17 +196,18 @@ if "--cuda_ext" in sys.argv:
'csrc/multi_tensor_adagrad.cu', 'csrc/multi_tensor_adagrad.cu',
'csrc/multi_tensor_novograd.cu', 'csrc/multi_tensor_novograd.cu',
'csrc/multi_tensor_lamb.cu'], 'csrc/multi_tensor_lamb.cu'],
extra_compile_args = { 'cxx' : ['-O3'] + version_dependent_macros, extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
'nvcc': nvcc_args_multi_tensor if not is_rocm_pytorch else hipcc_args_multi_tensor})) 'nvcc': nvcc_args_multi_tensor if not IS_ROCM_PYTORCH else hipcc_args_multi_tensor}))
print ("INFO: Building syncbn extension.") print ("INFO: Building syncbn extension.")
ext_modules.append( ext_modules.append(
CUDAExtension(name='syncbn', CUDAExtension(name='syncbn',
sources=['csrc/syncbn.cpp', sources=['csrc/syncbn.cpp',
'csrc/welford.cu'], 'csrc/welford.cu'],
extra_compile_args= ['-O3'] + version_dependent_macros)) extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
'nvcc':['-O3'] + version_dependent_macros}))
nvcc_args_layer_norm = ['maxrregcount=50', '-O3', '--use_fast_math'] + version_dependent_macros nvcc_args_layer_norm = ['-maxrregcount=50', '-O3', '--use_fast_math'] + version_dependent_macros
hipcc_args_layer_norm = ['-O3'] + version_dependent_macros hipcc_args_layer_norm = ['-O3'] + version_dependent_macros
print ("INFO: Building fused layernorm extension.") print ("INFO: Building fused layernorm extension.")
ext_modules.append( ext_modules.append(
...@@ -217,14 +215,15 @@ if "--cuda_ext" in sys.argv: ...@@ -217,14 +215,15 @@ if "--cuda_ext" in sys.argv:
sources=['csrc/layer_norm_cuda.cpp', sources=['csrc/layer_norm_cuda.cpp',
'csrc/layer_norm_cuda_kernel.cu'], 'csrc/layer_norm_cuda_kernel.cu'],
extra_compile_args={'cxx': ['-O3'] + version_dependent_macros, extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
'nvcc': nvcc_args_layer_norm if not is_rocm_pytorch else hipcc_args_layer_norm})) 'nvcc': nvcc_args_layer_norm if not IS_ROCM_PYTORCH else hipcc_args_layer_norm}))
print ("INFO: Building the MLP Extension.") print ("INFO: Building the MLP Extension.")
ext_modules.append( ext_modules.append(
CUDAExtension(name='mlp_cuda', CUDAExtension(name='mlp_cuda',
sources=['csrc/mlp.cpp', sources=['csrc/mlp.cpp',
'csrc/mlp_cuda.cu'], 'csrc/mlp_cuda.cu'],
extra_compile_args=['-O3'] + version_dependent_macros)) extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
'nvcc':['-O3'] + version_dependent_macros}))
if "--bnp" in sys.argv: if "--bnp" in sys.argv:
from torch.utils.cpp_extension import CUDAExtension from torch.utils.cpp_extension import CUDAExtension
...@@ -256,9 +255,7 @@ if "--xentropy" in sys.argv: ...@@ -256,9 +255,7 @@ if "--xentropy" in sys.argv:
from torch.utils.cpp_extension import BuildExtension from torch.utils.cpp_extension import BuildExtension
cmdclass['build_ext'] = BuildExtension cmdclass['build_ext'] = BuildExtension
is_rocm_pytorch = check_if_rocm_pytorch() if torch.utils.cpp_extension.CUDA_HOME is None and not IS_ROCM_PYTORCH:
if torch.utils.cpp_extension.CUDA_HOME is None and (not is_rocm_pytorch):
raise RuntimeError("--xentropy was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.") raise RuntimeError("--xentropy was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
else: else:
print ("INFO: Building the xentropy extension.") print ("INFO: Building the xentropy extension.")
...@@ -267,7 +264,8 @@ if "--xentropy" in sys.argv: ...@@ -267,7 +264,8 @@ if "--xentropy" in sys.argv:
sources=['apex/contrib/csrc/xentropy/interface.cpp', sources=['apex/contrib/csrc/xentropy/interface.cpp',
'apex/contrib/csrc/xentropy/xentropy_kernel.cu'], 'apex/contrib/csrc/xentropy/xentropy_kernel.cu'],
include_dirs=[os.path.join(this_dir, 'csrc')], include_dirs=[os.path.join(this_dir, 'csrc')],
extra_compile_args=['-O3'] + version_dependent_macros)) extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
'nvcc':['-O3'] + version_dependent_macros}))
if "--deprecated_fused_adam" in sys.argv: if "--deprecated_fused_adam" in sys.argv:
...@@ -277,9 +275,7 @@ if "--deprecated_fused_adam" in sys.argv: ...@@ -277,9 +275,7 @@ if "--deprecated_fused_adam" in sys.argv:
from torch.utils.cpp_extension import BuildExtension from torch.utils.cpp_extension import BuildExtension
cmdclass['build_ext'] = BuildExtension cmdclass['build_ext'] = BuildExtension
is_rocm_pytorch = check_if_rocm_pytorch() if torch.utils.cpp_extension.CUDA_HOME is None and not IS_ROCM_PYTORCH:
if torch.utils.cpp_extension.CUDA_HOME is None and (not is_rocm_pytorch):
raise RuntimeError("--deprecated_fused_adam was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.") raise RuntimeError("--deprecated_fused_adam was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
else: else:
print ("INFO: Building deprecated fused adam extension.") print ("INFO: Building deprecated fused adam extension.")
...@@ -291,7 +287,8 @@ if "--deprecated_fused_adam" in sys.argv: ...@@ -291,7 +287,8 @@ if "--deprecated_fused_adam" in sys.argv:
'apex/contrib/csrc/optimizers/fused_adam_cuda_kernel.cu'], 'apex/contrib/csrc/optimizers/fused_adam_cuda_kernel.cu'],
include_dirs=[os.path.join(this_dir, 'csrc')], include_dirs=[os.path.join(this_dir, 'csrc')],
extra_compile_args={'cxx': ['-O3'] + version_dependent_macros, extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
'nvcc' : nvcc_args_fused_adam if not is_rocm_pytorch else hipcc_args_fused_adam})) 'nvcc' : nvcc_args_fused_adam if not IS_ROCM_PYTORCH else hipcc_args_fused_adam}))
if "--deprecated_fused_lamb" in sys.argv: if "--deprecated_fused_lamb" in sys.argv:
from torch.utils.cpp_extension import CUDAExtension from torch.utils.cpp_extension import CUDAExtension
sys.argv.remove("--deprecated_fused_lamb") sys.argv.remove("--deprecated_fused_lamb")
...@@ -299,9 +296,7 @@ if "--deprecated_fused_lamb" in sys.argv: ...@@ -299,9 +296,7 @@ if "--deprecated_fused_lamb" in sys.argv:
from torch.utils.cpp_extension import BuildExtension from torch.utils.cpp_extension import BuildExtension
cmdclass['build_ext'] = BuildExtension cmdclass['build_ext'] = BuildExtension
is_rocm_pytorch = check_if_rocm_pytorch() if torch.utils.cpp_extension.CUDA_HOME is None and not IS_ROCM_PYTORCH:
if torch.utils.cpp_extension.CUDA_HOME is None and (not is_rocm_pytorch):
raise RuntimeError("--deprecated_fused_lamb was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.") raise RuntimeError("--deprecated_fused_lamb was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
else: else:
print ("INFO: Building deprecated fused lamb extension.") print ("INFO: Building deprecated fused lamb extension.")
...@@ -313,7 +308,7 @@ if "--deprecated_fused_lamb" in sys.argv: ...@@ -313,7 +308,7 @@ if "--deprecated_fused_lamb" in sys.argv:
'apex/contrib/csrc/optimizers/fused_lamb_cuda_kernel.cu', 'apex/contrib/csrc/optimizers/fused_lamb_cuda_kernel.cu',
'csrc/multi_tensor_l2norm_kernel.cu'], 'csrc/multi_tensor_l2norm_kernel.cu'],
include_dirs=[os.path.join(this_dir, 'csrc')], include_dirs=[os.path.join(this_dir, 'csrc')],
extra_compile_args = nvcc_args_fused_lamb if not is_rocm_pytorch else hipcc_args_fused_lamb)) extra_compile_args = nvcc_args_fused_lamb if not IS_ROCM_PYTORCH else hipcc_args_fused_lamb))
# Check, if ATen/CUDAGenerator.h is found, otherwise use the new ATen/CUDAGeneratorImpl.h, due to breaking change in https://github.com/pytorch/pytorch/pull/36026 # Check, if ATen/CUDAGenerator.h is found, otherwise use the new ATen/CUDAGeneratorImpl.h, due to breaking change in https://github.com/pytorch/pytorch/pull/36026
generator_flag = [] generator_flag = []
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment