Unverified Commit 17fbbf91 authored by Chaitanya Sri Krishna Lolla's avatar Chaitanya Sri Krishna Lolla Committed by GitHub
Browse files

[contrib] Support optimizers on rocm. (#33)

* enable deprecated fused adam optimizer

* enable deprecated fused lamb

* reset the compiler arguments

* syntax error

* aligning the compiler arguments
parent d2f6d04a
......@@ -87,6 +87,14 @@ def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
"https://github.com/NVIDIA/apex/pull/323#discussion_r287021798. "
"You can try commenting out this check (at your own risk).")
def check_if_rocm_pytorch():
is_rocm_pytorch = False
if torch.__version__ >= '1.5':
from torch.utils.cpp_extension import ROCM_HOME
is_rocm_pytorch = True if ((torch.version.hip is not None) and (ROCM_HOME is not None)) else False
return is_rocm_pytorch
# Set up macros for forward/backward compatibility hack around
# https://github.com/pytorch/pytorch/commit/4404762d7dd955383acee92e6f06b48144a0742e
# and
......@@ -279,17 +287,28 @@ if "--deprecated_fused_adam" in sys.argv:
from torch.utils.cpp_extension import BuildExtension
cmdclass['build_ext'] = BuildExtension
if torch.utils.cpp_extension.CUDA_HOME is None:
is_rocm_pytorch = check_if_rocm_pytorch()
if torch.utils.cpp_extension.CUDA_HOME is None and (not is_rocm_pytorch):
raise RuntimeError("--deprecated_fused_adam was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
else:
ext_modules.append(
CUDAExtension(name='fused_adam_cuda',
sources=['apex/contrib/csrc/optimizers/fused_adam_cuda.cpp',
'apex/contrib/csrc/optimizers/fused_adam_cuda_kernel.cu'],
include_dirs=[os.path.join(this_dir, 'csrc')],
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros,
'nvcc':['-O3',
'--use_fast_math'] + version_dependent_macros}))
if not is_rocm_pytorch:
ext_modules.append(
CUDAExtension(name='fused_adam_cuda',
sources=['apex/contrib/csrc/optimizers/fused_adam_cuda.cpp',
'apex/contrib/csrc/optimizers/fused_adam_cuda_kernel.cu'],
include_dirs=[os.path.join(this_dir, 'csrc')],
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros,
'nvcc':['-O3',
'--use_fast_math'] + version_dependent_macros}))
else:
print ("INFO: Building deprecated fused adam.")
ext_modules.append(
CUDAExtension(name='fused_adam_cuda',
sources=['apex/contrib/csrc/optimizers/fused_adam_cuda.cpp',
'apex/contrib/csrc/optimizers/hip/fused_adam_hip_kernel.hip'],
include_dirs=[os.path.join(this_dir, 'csrc/hip')],
extra_compile_args=['-O3'] + version_dependent_macros))
if "--deprecated_fused_lamb" in sys.argv:
from torch.utils.cpp_extension import CUDAExtension
......@@ -298,18 +317,30 @@ if "--deprecated_fused_lamb" in sys.argv:
from torch.utils.cpp_extension import BuildExtension
cmdclass['build_ext'] = BuildExtension
if torch.utils.cpp_extension.CUDA_HOME is None:
is_rocm_pytorch = check_if_rocm_pytorch()
if torch.utils.cpp_extension.CUDA_HOME is None and (not is_rocm_pytorch):
raise RuntimeError("--deprecated_fused_lamb was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
else:
ext_modules.append(
CUDAExtension(name='fused_lamb_cuda',
sources=['apex/contrib/csrc/optimizers/fused_lamb_cuda.cpp',
'apex/contrib/csrc/optimizers/fused_lamb_cuda_kernel.cu',
'csrc/multi_tensor_l2norm_kernel.cu'],
include_dirs=[os.path.join(this_dir, 'csrc')],
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros,
'nvcc':['-O3',
'--use_fast_math'] + version_dependent_macros}))
if not is_rocm_pytorch:
ext_modules.append(
CUDAExtension(name='fused_lamb_cuda',
sources=['apex/contrib/csrc/optimizers/fused_lamb_cuda.cpp',
'apex/contrib/csrc/optimizers/fused_lamb_cuda_kernel.cu',
'csrc/multi_tensor_l2norm_kernel.cu'],
include_dirs=[os.path.join(this_dir, 'csrc')],
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros,
'nvcc':['-O3',
'--use_fast_math'] + version_dependent_macros}))
else:
print ("INFO: Building deprecated fused lamb.")
ext_modules.append(
CUDAExtension(name='fused_lamb_cuda',
sources=['apex/contrib/csrc/optimizers/fused_lamb_cuda.cpp',
'apex/contrib/csrc/optimizers/hip/fused_lamb_hip_kernel.hip',
'csrc/hip/multi_tensor_l2norm_kernel.hip'],
include_dirs=[os.path.join(this_dir, 'csrc/hip')],
extra_compile_args=['-O3'] + version_dependent_macros))
# Check, if ATen/CUDAGenerator.h is found, otherwise use the new ATen/CUDAGeneratorImpl.h, due to breaking change in https://github.com/pytorch/pytorch/pull/36026
generator_flag = []
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment