"vscode:/vscode.git/clone" did not exist on "4edbe0d534debd907f75068bb520a5b9d42a3790"
Commit 2332c4d6 authored by Jeff Daily
Browse files

update setup.py to more closely align with upstream

Mostly whitespace or formatting issues addressed.
Diff with upstream is reduced; ROCm changes are more clear.
parent dcc7b513
...@@ -114,6 +114,8 @@ def check_if_rocm_pytorch(): ...@@ -114,6 +114,8 @@ def check_if_rocm_pytorch():
return is_rocm_pytorch return is_rocm_pytorch
IS_ROCM_PYTORCH = check_if_rocm_pytorch()
# Set up macros for forward/backward compatibility hack around # Set up macros for forward/backward compatibility hack around
# https://github.com/pytorch/pytorch/commit/4404762d7dd955383acee92e6f06b48144a0742e # https://github.com/pytorch/pytorch/commit/4404762d7dd955383acee92e6f06b48144a0742e
# and # and
...@@ -172,59 +174,56 @@ if "--cuda_ext" in sys.argv: ...@@ -172,59 +174,56 @@ if "--cuda_ext" in sys.argv:
from torch.utils.cpp_extension import CUDAExtension from torch.utils.cpp_extension import CUDAExtension
sys.argv.remove("--cuda_ext") sys.argv.remove("--cuda_ext")
is_rocm_pytorch = False if torch.utils.cpp_extension.CUDA_HOME is None and not IS_ROCM_PYTORCH:
if torch.__version__ >= '1.5':
from torch.utils.cpp_extension import ROCM_HOME
is_rocm_pytorch = True if ((torch.version.hip is not None) and (ROCM_HOME is not None)) else False
if torch.utils.cpp_extension.CUDA_HOME is None and (not is_rocm_pytorch):
raise RuntimeError("--cuda_ext was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.") raise RuntimeError("--cuda_ext was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
else: else:
if not is_rocm_pytorch: if not IS_ROCM_PYTORCH:
check_cuda_torch_binary_vs_bare_metal(torch.utils.cpp_extension.CUDA_HOME) check_cuda_torch_binary_vs_bare_metal(torch.utils.cpp_extension.CUDA_HOME)
print ("INFO: Building the multi-tensor apply extension.") print ("INFO: Building the multi-tensor apply extension.")
nvcc_args_multi_tensor = ['-lineinfo', '-O3', '--use_fast_math'] + version_dependent_macros nvcc_args_multi_tensor = ['-lineinfo', '-O3', '--use_fast_math'] + version_dependent_macros
hipcc_args_multi_tensor = ['-O3'] + version_dependent_macros hipcc_args_multi_tensor = ['-O3'] + version_dependent_macros
ext_modules.append( ext_modules.append(
CUDAExtension(name='amp_C', CUDAExtension(name='amp_C',
sources=['csrc/amp_C_frontend.cpp', sources=['csrc/amp_C_frontend.cpp',
'csrc/multi_tensor_sgd_kernel.cu', 'csrc/multi_tensor_sgd_kernel.cu',
'csrc/multi_tensor_scale_kernel.cu', 'csrc/multi_tensor_scale_kernel.cu',
'csrc/multi_tensor_axpby_kernel.cu', 'csrc/multi_tensor_axpby_kernel.cu',
'csrc/multi_tensor_l2norm_kernel.cu', 'csrc/multi_tensor_l2norm_kernel.cu',
'csrc/multi_tensor_lamb_stage_1.cu', 'csrc/multi_tensor_lamb_stage_1.cu',
'csrc/multi_tensor_lamb_stage_2.cu', 'csrc/multi_tensor_lamb_stage_2.cu',
'csrc/multi_tensor_adam.cu', 'csrc/multi_tensor_adam.cu',
'csrc/multi_tensor_adagrad.cu', 'csrc/multi_tensor_adagrad.cu',
'csrc/multi_tensor_novograd.cu', 'csrc/multi_tensor_novograd.cu',
'csrc/multi_tensor_lamb.cu'], 'csrc/multi_tensor_lamb.cu'],
extra_compile_args = { 'cxx' : ['-O3'] + version_dependent_macros, extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
'nvcc': nvcc_args_multi_tensor if not is_rocm_pytorch else hipcc_args_multi_tensor})) 'nvcc': nvcc_args_multi_tensor if not IS_ROCM_PYTORCH else hipcc_args_multi_tensor}))
print ("INFO: Building syncbn extension.") print ("INFO: Building syncbn extension.")
ext_modules.append( ext_modules.append(
CUDAExtension(name='syncbn', CUDAExtension(name='syncbn',
sources=['csrc/syncbn.cpp', sources=['csrc/syncbn.cpp',
'csrc/welford.cu'], 'csrc/welford.cu'],
extra_compile_args= ['-O3'] + version_dependent_macros)) extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
'nvcc':['-O3'] + version_dependent_macros}))
nvcc_args_layer_norm = ['maxrregcount=50', '-O3', '--use_fast_math'] + version_dependent_macros nvcc_args_layer_norm = ['-maxrregcount=50', '-O3', '--use_fast_math'] + version_dependent_macros
hipcc_args_layer_norm = ['-O3'] + version_dependent_macros hipcc_args_layer_norm = ['-O3'] + version_dependent_macros
print ("INFO: Building fused layernorm extension.") print ("INFO: Building fused layernorm extension.")
ext_modules.append( ext_modules.append(
CUDAExtension(name='fused_layer_norm_cuda', CUDAExtension(name='fused_layer_norm_cuda',
sources=['csrc/layer_norm_cuda.cpp', sources=['csrc/layer_norm_cuda.cpp',
'csrc/layer_norm_cuda_kernel.cu'], 'csrc/layer_norm_cuda_kernel.cu'],
extra_compile_args={'cxx': ['-O3'] + version_dependent_macros, extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
'nvcc': nvcc_args_layer_norm if not is_rocm_pytorch else hipcc_args_layer_norm})) 'nvcc': nvcc_args_layer_norm if not IS_ROCM_PYTORCH else hipcc_args_layer_norm}))
print ("INFO: Building the MLP Extension.") print ("INFO: Building the MLP Extension.")
ext_modules.append( ext_modules.append(
CUDAExtension(name='mlp_cuda', CUDAExtension(name='mlp_cuda',
sources=['csrc/mlp.cpp', sources=['csrc/mlp.cpp',
'csrc/mlp_cuda.cu'], 'csrc/mlp_cuda.cu'],
extra_compile_args=['-O3'] + version_dependent_macros)) extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
'nvcc':['-O3'] + version_dependent_macros}))
if "--bnp" in sys.argv: if "--bnp" in sys.argv:
from torch.utils.cpp_extension import CUDAExtension from torch.utils.cpp_extension import CUDAExtension
...@@ -256,18 +255,17 @@ if "--xentropy" in sys.argv: ...@@ -256,18 +255,17 @@ if "--xentropy" in sys.argv:
from torch.utils.cpp_extension import BuildExtension from torch.utils.cpp_extension import BuildExtension
cmdclass['build_ext'] = BuildExtension cmdclass['build_ext'] = BuildExtension
is_rocm_pytorch = check_if_rocm_pytorch() if torch.utils.cpp_extension.CUDA_HOME is None and not IS_ROCM_PYTORCH:
if torch.utils.cpp_extension.CUDA_HOME is None and (not is_rocm_pytorch):
raise RuntimeError("--xentropy was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.") raise RuntimeError("--xentropy was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
else: else:
print ("INFO: Building the xentropy extension.") print ("INFO: Building the xentropy extension.")
ext_modules.append( ext_modules.append(
CUDAExtension(name='xentropy_cuda', CUDAExtension(name='xentropy_cuda',
sources=['apex/contrib/csrc/xentropy/interface.cpp', sources=['apex/contrib/csrc/xentropy/interface.cpp',
'apex/contrib/csrc/xentropy/xentropy_kernel.cu'], 'apex/contrib/csrc/xentropy/xentropy_kernel.cu'],
include_dirs=[os.path.join(this_dir, 'csrc')], include_dirs=[os.path.join(this_dir, 'csrc')],
extra_compile_args=['-O3'] + version_dependent_macros)) extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
'nvcc':['-O3'] + version_dependent_macros}))
if "--deprecated_fused_adam" in sys.argv: if "--deprecated_fused_adam" in sys.argv:
...@@ -277,21 +275,20 @@ if "--deprecated_fused_adam" in sys.argv: ...@@ -277,21 +275,20 @@ if "--deprecated_fused_adam" in sys.argv:
from torch.utils.cpp_extension import BuildExtension from torch.utils.cpp_extension import BuildExtension
cmdclass['build_ext'] = BuildExtension cmdclass['build_ext'] = BuildExtension
is_rocm_pytorch = check_if_rocm_pytorch() if torch.utils.cpp_extension.CUDA_HOME is None and not IS_ROCM_PYTORCH:
if torch.utils.cpp_extension.CUDA_HOME is None and (not is_rocm_pytorch):
raise RuntimeError("--deprecated_fused_adam was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.") raise RuntimeError("--deprecated_fused_adam was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
else: else:
print ("INFO: Building deprecated fused adam extension.") print ("INFO: Building deprecated fused adam extension.")
nvcc_args_fused_adam = ['-O3', '--use_fast_math'] + version_dependent_macros nvcc_args_fused_adam = ['-O3', '--use_fast_math'] + version_dependent_macros
hipcc_args_fused_adam = ['-O3'] + version_dependent_macros hipcc_args_fused_adam = ['-O3'] + version_dependent_macros
ext_modules.append( ext_modules.append(
CUDAExtension(name='fused_adam_cuda', CUDAExtension(name='fused_adam_cuda',
sources=['apex/contrib/csrc/optimizers/fused_adam_cuda.cpp', sources=['apex/contrib/csrc/optimizers/fused_adam_cuda.cpp',
'apex/contrib/csrc/optimizers/fused_adam_cuda_kernel.cu'], 'apex/contrib/csrc/optimizers/fused_adam_cuda_kernel.cu'],
include_dirs=[os.path.join(this_dir, 'csrc')], include_dirs=[os.path.join(this_dir, 'csrc')],
extra_compile_args={'cxx': ['-O3'] + version_dependent_macros, extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
'nvcc' : nvcc_args_fused_adam if not is_rocm_pytorch else hipcc_args_fused_adam})) 'nvcc' : nvcc_args_fused_adam if not IS_ROCM_PYTORCH else hipcc_args_fused_adam}))
if "--deprecated_fused_lamb" in sys.argv: if "--deprecated_fused_lamb" in sys.argv:
from torch.utils.cpp_extension import CUDAExtension from torch.utils.cpp_extension import CUDAExtension
sys.argv.remove("--deprecated_fused_lamb") sys.argv.remove("--deprecated_fused_lamb")
...@@ -299,21 +296,19 @@ if "--deprecated_fused_lamb" in sys.argv: ...@@ -299,21 +296,19 @@ if "--deprecated_fused_lamb" in sys.argv:
from torch.utils.cpp_extension import BuildExtension from torch.utils.cpp_extension import BuildExtension
cmdclass['build_ext'] = BuildExtension cmdclass['build_ext'] = BuildExtension
is_rocm_pytorch = check_if_rocm_pytorch() if torch.utils.cpp_extension.CUDA_HOME is None and not IS_ROCM_PYTORCH:
if torch.utils.cpp_extension.CUDA_HOME is None and (not is_rocm_pytorch):
raise RuntimeError("--deprecated_fused_lamb was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.") raise RuntimeError("--deprecated_fused_lamb was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
else: else:
print ("INFO: Building deprecated fused lamb extension.") print ("INFO: Building deprecated fused lamb extension.")
nvcc_args_fused_lamb = ['-O3', '--use_fast_math'] + version_dependent_macros nvcc_args_fused_lamb = ['-O3', '--use_fast_math'] + version_dependent_macros
hipcc_args_fused_lamb = ['-O3'] + version_dependent_macros hipcc_args_fused_lamb = ['-O3'] + version_dependent_macros
ext_modules.append( ext_modules.append(
CUDAExtension(name='fused_lamb_cuda', CUDAExtension(name='fused_lamb_cuda',
sources=['apex/contrib/csrc/optimizers/fused_lamb_cuda.cpp', sources=['apex/contrib/csrc/optimizers/fused_lamb_cuda.cpp',
'apex/contrib/csrc/optimizers/fused_lamb_cuda_kernel.cu', 'apex/contrib/csrc/optimizers/fused_lamb_cuda_kernel.cu',
'csrc/multi_tensor_l2norm_kernel.cu'], 'csrc/multi_tensor_l2norm_kernel.cu'],
include_dirs=[os.path.join(this_dir, 'csrc')], include_dirs=[os.path.join(this_dir, 'csrc')],
extra_compile_args = nvcc_args_fused_lamb if not is_rocm_pytorch else hipcc_args_fused_lamb)) extra_compile_args = nvcc_args_fused_lamb if not IS_ROCM_PYTORCH else hipcc_args_fused_lamb))
# Check, if ATen/CUDAGenerator.h is found, otherwise use the new ATen/CUDAGeneratorImpl.h, due to breaking change in https://github.com/pytorch/pytorch/pull/36026 # Check, if ATen/CUDAGenerator.h is found, otherwise use the new ATen/CUDAGeneratorImpl.h, due to breaking change in https://github.com/pytorch/pytorch/pull/36026
generator_flag = [] generator_flag = []
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment