Unverified Commit 8124df13 authored by Chaitanya Sri Krishna Lolla, committed by GitHub

Enable Apex on ROCm and add multi-tensor support (#1)

* Initial commit to hipify all CUDA code

* enable multi_tensor_apply extension

* added GeneratedFileCleaner to handle nested hip files
parent 1f2aa915
@@ -115,8 +115,13 @@ __device__ __forceinline__ T reduce_block_into_lanes
     // __SYNCWARP();
 
     #pragma unroll
-    for(int i = 16; i >= lanes; i >>= 1)
+    for(int i = 16; i >= lanes; i >>= 1) {
+#ifdef __HIP_PLATFORM_HCC__
+      final = final + __shfl_down(0xffffffff, final, i);
+#else
       final = final + __shfl_down_sync(0xffffffff, final, i);
+#endif
+    }
   }
 
   if(share_result)
@@ -165,8 +170,13 @@ __device__ __forceinline__ T reduce_block_into_lanes_max_op
     // __SYNCWARP();
 
     #pragma unroll
-    for(int i = 16; i >= lanes; i >>= 1)
+    for(int i = 16; i >= lanes; i >>= 1) {
+#ifdef __HIP_PLATFORM_HCC__
+      final = fmaxf(fabsf(final), fabsf(__shfl_down(0xffffffff, final, i)));
+#else
       final = fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i)));
+#endif
+    }
   }
 
   if(share_result)
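The #ifdef introduced above is needed because the CUDA 9 *_sync shuffle intrinsics do not exist in HIP: __shfl_down_sync(mask, var, delta) is CUDA-only, while HIP exposes __shfl_down(var, delta) with no mask argument. As a rough, self-contained illustration of the same guard (the kernel name, launch shape, and the mask-free HIP call are assumptions for this sketch, not code from the commit):

    // Illustrative warp-sum kernel using the same platform guard.
    // Assumes a single 32-thread block; on AMD's 64-wide wavefronts the
    // upper lanes are simply inactive and do not feed lane 0's result.
    #ifdef __HIP_PLATFORM_HCC__
    #include <hip/hip_runtime.h>
    #endif

    __global__ void warp_sum(const float* in, float* out)
    {
      float val = in[threadIdx.x];
      #pragma unroll
      for (int i = 16; i >= 1; i >>= 1)
      {
    #ifdef __HIP_PLATFORM_HCC__
        // HIP: __shfl_down(var, lane_delta) -- no mask parameter.
        val = val + __shfl_down(val, i);
    #else
        // CUDA: __shfl_down_sync(mask, var, lane_delta).
        val = val + __shfl_down_sync(0xffffffff, val, i);
    #endif
      }
      if (threadIdx.x == 0)
        out[0] = val;  // lane 0 now holds the sum over the 32 lanes
    }

Launched as warp_sum<<<1, 32>>>(d_in, d_out), lane 0 writes the sum of the 32 input values; the only line that differs between the nvcc and hipcc builds is the shuffle call.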
@@ -6,6 +6,8 @@ import sys
 import warnings
 import os
+from torch.utils.hipify import hipify_python
+
 # ninja build does not work unless include_dirs are abs path
 this_dir = os.path.dirname(os.path.abspath(__file__))
@@ -99,51 +101,93 @@ version_dependent_macros = version_ge_1_1 + version_ge_1_3 + version_ge_1_5
 if "--cuda_ext" in sys.argv:
     from torch.utils.cpp_extension import CUDAExtension
     sys.argv.remove("--cuda_ext")
 
-    if torch.utils.cpp_extension.CUDA_HOME is None:
+    is_rocm_pytorch = False
+    if torch.__version__ >= '1.5':
+        from torch.utils.cpp_extension import ROCM_HOME
+        is_rocm_pytorch = True if ((torch.version.hip is not None) and (ROCM_HOME is not None)) else False
+
+    if torch.utils.cpp_extension.CUDA_HOME is None and (not is_rocm_pytorch):
         raise RuntimeError("--cuda_ext was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
     else:
-        check_cuda_torch_binary_vs_bare_metal(torch.utils.cpp_extension.CUDA_HOME)
-
-        ext_modules.append(
-            CUDAExtension(name='amp_C',
-                          sources=['csrc/amp_C_frontend.cpp',
-                                   'csrc/multi_tensor_sgd_kernel.cu',
-                                   'csrc/multi_tensor_scale_kernel.cu',
-                                   'csrc/multi_tensor_axpby_kernel.cu',
-                                   'csrc/multi_tensor_l2norm_kernel.cu',
-                                   'csrc/multi_tensor_lamb_stage_1.cu',
-                                   'csrc/multi_tensor_lamb_stage_2.cu',
-                                   'csrc/multi_tensor_adam.cu',
-                                   'csrc/multi_tensor_novograd.cu',
-                                   'csrc/multi_tensor_lamb.cu'],
-                          extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
-                                              'nvcc':['-lineinfo',
-                                                      '-O3',
-                                                      # '--resource-usage',
-                                                      '--use_fast_math'] + version_dependent_macros}))
-        ext_modules.append(
-            CUDAExtension(name='syncbn',
-                          sources=['csrc/syncbn.cpp',
-                                   'csrc/welford.cu'],
-                          extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
-                                              'nvcc':['-O3'] + version_dependent_macros}))
-
-        ext_modules.append(
-            CUDAExtension(name='fused_layer_norm_cuda',
-                          sources=['csrc/layer_norm_cuda.cpp',
-                                   'csrc/layer_norm_cuda_kernel.cu'],
-                          extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
-                                              'nvcc':['-maxrregcount=50',
-                                                      '-O3',
-                                                      '--use_fast_math'] + version_dependent_macros}))
-
-        ext_modules.append(
-            CUDAExtension(name='mlp_cuda',
-                          sources=['csrc/mlp.cpp',
-                                   'csrc/mlp_cuda.cu'],
-                          extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
-                                              'nvcc':['-O3'] + version_dependent_macros}))
+        if not is_rocm_pytorch:
+            check_cuda_torch_binary_vs_bare_metal(torch.utils.cpp_extension.CUDA_HOME)
+
+        if is_rocm_pytorch:
+            import shutil
+            with hipify_python.GeneratedFileCleaner(keep_intermediates=True) as clean_ctx:
+                hipify_python.hipify(project_directory=this_dir, output_directory=this_dir, includes="csrc/*",
+                                     show_detailed=True, is_pytorch_extension=True, clean_ctx=clean_ctx)
+                shutil.copy("csrc/compat.h", "csrc/hip/compat.h")
+                shutil.copy("csrc/type_shim.h", "csrc/hip/type_shim.h")
+
+        if not is_rocm_pytorch:
+            ext_modules.append(
+                CUDAExtension(name='amp_C',
+                              sources=['csrc/amp_C_frontend.cpp',
+                                       'csrc/multi_tensor_sgd_kernel.cu',
+                                       'csrc/multi_tensor_scale_kernel.cu',
+                                       'csrc/multi_tensor_axpby_kernel.cu',
+                                       'csrc/multi_tensor_l2norm_kernel.cu',
+                                       'csrc/multi_tensor_lamb_stage_1.cu',
+                                       'csrc/multi_tensor_lamb_stage_2.cu',
+                                       'csrc/multi_tensor_adam.cu',
+                                       'csrc/multi_tensor_novograd.cu',
+                                       'csrc/multi_tensor_lamb.cu'],
+                              extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
+                                                  'nvcc':['-lineinfo',
+                                                          '-O3',
+                                                          # '--resource-usage',
+                                                          '--use_fast_math'] + version_dependent_macros}))
+        else:
+            print ("INFO: Building Multitensor apply extension")
+            ext_modules.append(
+                CUDAExtension(name='amp_C',
+                              sources=['csrc/amp_C_frontend.cpp',
+                                       'csrc/hip/multi_tensor_sgd_kernel.hip',
+                                       'csrc/hip/multi_tensor_scale_kernel.hip',
+                                       'csrc/hip/multi_tensor_axpby_kernel.hip',
+                                       'csrc/hip/multi_tensor_l2norm_kernel.hip',
+                                       'csrc/hip/multi_tensor_lamb_stage_1.hip',
+                                       'csrc/hip/multi_tensor_lamb_stage_2.hip',
+                                       'csrc/hip/multi_tensor_adam.hip',
+                                       'csrc/hip/multi_tensor_novograd.hip',
+                                       'csrc/hip/multi_tensor_lamb.hip'],
+                              extra_compile_args={'cxx' : ['-O3'] + version_dependent_macros,
+                                                  'nvcc': []}))
+
+        if not is_rocm_pytorch:
+            ext_modules.append(
+                CUDAExtension(name='syncbn',
+                              sources=['csrc/syncbn.cpp',
+                                       'csrc/welford.cu'],
+                              extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
+                                                  'nvcc':['-O3'] + version_dependent_macros}))
+        else:
+            print ("INFO: Skipping syncbn extension.")
+
+        if not is_rocm_pytorch:
+            ext_modules.append(
+                CUDAExtension(name='fused_layer_norm_cuda',
+                              sources=['csrc/layer_norm_cuda.cpp',
+                                       'csrc/layer_norm_cuda_kernel.cu'],
+                              extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
+                                                  'nvcc':['-maxrregcount=50',
+                                                          '-O3',
+                                                          '--use_fast_math'] + version_dependent_macros}))
+        else:
+            print ("INFO: Skipping FusedLayerNorm extension.")
+
+        if not is_rocm_pytorch:
+            ext_modules.append(
+                CUDAExtension(name='mlp_cuda',
+                              sources=['csrc/mlp.cpp',
+                                       'csrc/mlp_cuda.cu'],
+                              extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
+                                                  'nvcc':['-O3'] + version_dependent_macros}))
+        else:
+            print ("INFO: Skipping MLP extension")
if "--bnp" in sys.argv: if "--bnp" in sys.argv:
from torch.utils.cpp_extension import CUDAExtension from torch.utils.cpp_extension import CUDAExtension
......
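For readers porting similar extensions, the ROCm branch above reduces to four steps: detect a ROCm build of PyTorch, run the hipify translator over the CUDA sources, copy the shared headers next to the generated files, and register the extension against the csrc/hip/*.hip sources with the nvcc-specific flags dropped. A minimal standalone sketch of the detection and hipify steps, assuming PyTorch >= 1.5 and an Apex-style csrc/ layout (the error message is illustrative, and the exact hipify keyword arguments can vary between PyTorch releases):

    import os
    import shutil

    import torch
    from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME
    from torch.utils.hipify import hipify_python

    this_dir = os.path.dirname(os.path.abspath(__file__))

    # A ROCm build of PyTorch reports a HIP version and ships ROCM_HOME.
    is_rocm_pytorch = (torch.version.hip is not None) and (ROCM_HOME is not None)

    if is_rocm_pytorch:
        # Translate csrc/*.cu into csrc/hip/*.hip.  GeneratedFileCleaner tracks
        # every file hipify writes, including ones reached through nested
        # includes; keep_intermediates=True leaves them on disk for the
        # extension build that follows.
        with hipify_python.GeneratedFileCleaner(keep_intermediates=True) as clean_ctx:
            hipify_python.hipify(project_directory=this_dir,
                                 output_directory=this_dir,
                                 includes="csrc/*",
                                 show_detailed=True,
                                 is_pytorch_extension=True,
                                 clean_ctx=clean_ctx)
        # Make the shared headers visible next to the generated sources,
        # mirroring the build-script change above.
        shutil.copy("csrc/compat.h", "csrc/hip/compat.h")
        shutil.copy("csrc/type_shim.h", "csrc/hip/type_shim.h")
    elif CUDA_HOME is None:
        raise RuntimeError("Neither nvcc nor a ROCm build of PyTorch was found.")

On the CUDA path the original csrc/*.cu sources and nvcc flags are used unchanged; on ROCm the amp_C extension is built from the generated csrc/hip/*.hip files and passes an empty 'nvcc' flag list.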