Commit 9b4c68c7 authored by lcskrishna's avatar lcskrishna
Browse files

Update hipify-generated include paths and build sources for apex contrib extensions

parent ef209a74
......@@ -9,7 +9,12 @@
// #include "ATen/Type.h"
#include "ATen/AccumulateType.h"
#include <THC/THCGeneral.h>
#if HIP_VERSION >= 310
#include "multi_tensor_apply_hip.cuh"
#else
#include "multi_tensor_apply.cuh"
#endif
#define BLOCK_SIZE 512
#define ILP 4
......
......@@ -8,8 +8,11 @@
#include <assert.h>
#include "type_shim.h"
#if HIP_VERSION >= 310
#include "multi_tensor_apply_hip.cuh"
#else
#include "multi_tensor_apply.cuh"
#endif
#define BLOCK_SIZE 512
#define ILP 4
......
......@@ -333,12 +333,20 @@ if "--xentropy" in sys.argv:
extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
'nvcc':['-O3'] + version_dependent_macros}))
else:
xentropy_sources_v1_8 = ['apex/contrib/csrc/xentropy/interface.cpp', 'apex/contrib/csrc/xentropy/xentropy_kernel.hip']
xentropy_sources_other = ['apex/contrib/csrc/xentropy/interface.cpp', 'apex/contrib/csrc/xentropy/hip/xentropy_kernel.hip']
ext_modules.append(
CUDAExtension(name='xentropy_cuda',
sources=['apex/contrib/csrc/xentropy/interface.cpp',
'apex/contrib/csrc/xentropy/hip/xentropy_kernel.hip'],
include_dirs=[os.path.join(this_dir, 'csrc/hip')],
extra_compile_args=['-O3'] + version_dependent_macros))
CUDAExtension(name='xentropy_cuda',
sources = xentropy_sources_v1_8 if torch.__version__ >= '1.8' else xentropy_sources_other,
include_dirs=[os.path.join(this_dir, 'csrc') if torch.__version__ >= '1.8' else os.path.join(this_dir, 'csrc/hip')],
extra_compile_args=['-O3'] + version_dependent_macros))
#ext_modules.append(
# CUDAExtension(name='xentropy_cuda',
# sources=['apex/contrib/csrc/xentropy/interface.cpp',
# 'apex/contrib/csrc/xentropy/hip/xentropy_kernel.hip'],
# include_dirs=[os.path.join(this_dir, 'csrc/hip')],
# extra_compile_args=['-O3'] + version_dependent_macros))
if "--deprecated_fused_adam" in sys.argv:
......@@ -364,12 +372,23 @@ if "--deprecated_fused_adam" in sys.argv:
'--use_fast_math'] + version_dependent_macros}))
else:
print ("INFO: Building deprecated fused adam.")
fused_adam_sources_v1_8 = ['apex/contrib/csrc/optimizers/fused_adam_cuda.cpp',
'apex/contrib/csrc/optimizers/fused_adam_hip_kernel.hip']
fused_adam_sources_other = ['apex/contrib/csrc/optimizers/fused_adam_cuda.cpp',
'apex/contrib/csrc/optimizers/hip/fused_adam_hip_kernel.hip']
ext_modules.append(
CUDAExtension(name='fused_adam_cuda',
sources=['apex/contrib/csrc/optimizers/fused_adam_cuda.cpp',
'apex/contrib/csrc/optimizers/hip/fused_adam_hip_kernel.hip'],
include_dirs=[os.path.join(this_dir, 'csrc/hip')],
sources = fused_adam_sources_v1_8 if torch.__version__ >= '1.8' else fused_adam_sources_other,
include_dirs=[os.path.join(this_dir, 'csrc') if torch.__version__ >= '1.8' else os.path.join(this_dir, 'csrc/hip')],
extra_compile_args=['-O3'] + version_dependent_macros))
#ext_modules.append(
# CUDAExtension(name='fused_adam_cuda',
# sources=['apex/contrib/csrc/optimizers/fused_adam_cuda.cpp',
# 'apex/contrib/csrc/optimizers/hip/fused_adam_hip_kernel.hip'],
# include_dirs=[os.path.join(this_dir, 'csrc/hip')],
# extra_compile_args=['-O3'] + version_dependent_macros))
if "--deprecated_fused_lamb" in sys.argv:
from torch.utils.cpp_extension import CUDAExtension
......@@ -395,13 +414,24 @@ if "--deprecated_fused_lamb" in sys.argv:
'--use_fast_math'] + version_dependent_macros}))
else:
print ("INFO: Building deprecated fused lamb.")
fused_lamb_sources_v1_8 = ['apex/contrib/csrc/optimizers/fused_lamb_cuda.cpp',
'apex/contrib/csrc/optimizers/fused_lamb_hip_kernel.hip']
fused_lamb_sources_other = ['apex/contrib/csrc/optimizers/fused_lamb_cuda.cpp',
'apex/contrib/csrc/optimizers/hip/fused_lamb_hip_kernel.hip']
ext_modules.append(
CUDAExtension(name='fused_lamb_cuda',
sources=['apex/contrib/csrc/optimizers/fused_lamb_cuda.cpp',
'apex/contrib/csrc/optimizers/hip/fused_lamb_hip_kernel.hip',
'csrc/hip/multi_tensor_l2norm_kernel.hip'],
include_dirs=[os.path.join(this_dir, 'csrc/hip')],
sources = fused_lamb_sources_v1_8 if torch.__version__ >= '1.8' else fused_lamb_sources_other,
include_dirs = [os.path.join(this_dir, 'csrc') if torch.__version__ >= '1.8' else os.path.join(this_dir, 'csrc/hip')],
extra_compile_args=['-O3'] + version_dependent_macros))
#ext_modules.append(
# CUDAExtension(name='fused_lamb_cuda',
# sources=['apex/contrib/csrc/optimizers/fused_lamb_cuda.cpp',
# 'apex/contrib/csrc/optimizers/hip/fused_lamb_hip_kernel.hip',
# 'csrc/hip/multi_tensor_l2norm_kernel.hip'],
# include_dirs=[os.path.join(this_dir, 'csrc/hip')],
# extra_compile_args=['-O3'] + version_dependent_macros))
# Check, if ATen/CUDAGenerator.h is found, otherwise use the new ATen/CUDAGeneratorImpl.h, due to breaking change in https://github.com/pytorch/pytorch/pull/36026
generator_flag = []
......
Markdown is supported
Attach a file by drag &amp; drop or click to upload.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment