Commit 91003340 authored by lcskrishna

Refactor based on the latest hipify revamp

parent 539bad24
@@ -10,11 +10,7 @@
#include "ATen/AccumulateType.h"
#include <THC/THCGeneral.h>
#if HIP_VERSION >= 310
#include "multi_tensor_apply_hip.cuh"
#else
#include "multi_tensor_apply.cuh"
#endif
#define BLOCK_SIZE 512
#define ILP 4
......
@@ -8,11 +8,8 @@
#include <assert.h>
#include "type_shim.h"
#if HIP_VERSION >= 310
#include "multi_tensor_apply_hip.cuh"
#else
#include "multi_tensor_apply.cuh"
#endif
#define BLOCK_SIZE 512
#define ILP 4
......
@@ -145,16 +145,9 @@ if "--cuda_ext" in sys.argv:
if not is_rocm_pytorch:
check_cuda_torch_binary_vs_bare_metal(torch.utils.cpp_extension.CUDA_HOME)
if is_rocm_pytorch:
import shutil
with hipify_python.GeneratedFileCleaner(keep_intermediates=True) as clean_ctx:
hipify_python.hipify(project_directory=this_dir, output_directory=this_dir, includes="csrc/*",
show_detailed=True, is_pytorch_extension=True, clean_ctx=clean_ctx)
if torch.__version__ < '1.8':
shutil.copy("csrc/compat.h", "csrc/hip/compat.h")
shutil.copy("csrc/type_shim.h", "csrc/hip/type_shim.h")
if not is_rocm_pytorch:
print ("INFO: Building the multi-tensor apply extension.")
nvcc_args_multi_tensor = ['-lineinfo', '-O3', '--use_fast_math'] + version_dependent_macros
hipcc_args_multi_tensor = ['-O3'] + version_dependent_macros
ext_modules.append(
CUDAExtension(name='amp_C',
sources=['csrc/amp_C_frontend.cpp',
@@ -168,93 +161,30 @@ if "--cuda_ext" in sys.argv:
'csrc/multi_tensor_adagrad.cu',
'csrc/multi_tensor_novograd.cu',
'csrc/multi_tensor_lamb.cu'],
extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
'nvcc':['-lineinfo',
'-O3',
# '--resource-usage',
'--use_fast_math'] + version_dependent_macros}))
else:
print ("INFO: Building Multitensor apply extension")
multi_tensor_sources_v1_8 = [
'csrc/amp_C_frontend.cpp',
'csrc/multi_tensor_sgd_kernel.hip',
'csrc/multi_tensor_scale_kernel.hip',
'csrc/multi_tensor_axpby_kernel.hip',
'csrc/multi_tensor_l2norm_kernel.hip',
'csrc/multi_tensor_lamb_stage_1.hip',
'csrc/multi_tensor_lamb_stage_2.hip',
'csrc/multi_tensor_adam.hip',
'csrc/multi_tensor_adagrad.hip',
'csrc/multi_tensor_novograd.hip',
'csrc/multi_tensor_lamb.hip'
]
multi_tensor_sources_other = [
'csrc/amp_C_frontend.cpp',
'csrc/hip/multi_tensor_sgd_kernel.hip',
'csrc/hip/multi_tensor_scale_kernel.hip',
'csrc/hip/multi_tensor_axpby_kernel.hip',
'csrc/hip/multi_tensor_l2norm_kernel.hip',
'csrc/hip/multi_tensor_lamb_stage_1.hip',
'csrc/hip/multi_tensor_lamb_stage_2.hip',
'csrc/hip/multi_tensor_adam.hip',
'csrc/hip/multi_tensor_adagrad.hip',
'csrc/hip/multi_tensor_novograd.hip',
'csrc/hip/multi_tensor_lamb.hip',
]
ext_modules.append(
CUDAExtension(name='amp_C',
sources=multi_tensor_sources_v1_8 if torch.__version__ >= '1.8' else multi_tensor_sources_other,
extra_compile_args=['-O3'] + version_dependent_macros))
extra_compile_args = nvcc_args_multi_tensor if not is_rocm_pytorch else hipcc_args_multi_tensor))
if not is_rocm_pytorch:
print ("INFO: Builidng syncbn extension.")
ext_modules.append(
CUDAExtension(name='syncbn',
sources=['csrc/syncbn.cpp',
'csrc/welford.cu'],
extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
'nvcc':['-O3'] + version_dependent_macros}))
else:
print ("INFO: Building syncbn extension.")
syncbn_sources_v1_8 = ['csrc/syncbn.cpp', 'csrc/welford.hip']
syncbn_sources_other = ['csrc/syncbn.cpp', 'csrc/hip/welford.hip']
ext_modules.append(
CUDAExtension(name='syncbn',
sources=syncbn_sources_v1_8 if torch.__version__ >= '1.8' else syncbn_sources_other,
extra_compile_args=['-O3'] + version_dependent_macros))
extra_compile_args= ['-O3'] + version_dependent_macros))
if not is_rocm_pytorch:
nvcc_args_layer_norm = ['-maxrregcount=50', '-O3', '--use_fast_math'] + version_dependent_macros
hipcc_args_layer_norm = ['-O3'] + version_dependent_macros
print ("INFO: Building fused layernorm extension.")
ext_modules.append(
CUDAExtension(name='fused_layer_norm_cuda',
sources=['csrc/layer_norm_cuda.cpp',
'csrc/layer_norm_cuda_kernel.cu'],
extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
'nvcc':['-maxrregcount=50',
'-O3',
'--use_fast_math'] + version_dependent_macros}))
else:
print ("INFO: Building FusedLayerNorm extension.")
layer_norm_sources_v1_8 = ['csrc/layer_norm_cuda.cpp', 'csrc/layer_norm_hip_kernel.hip']
layer_norm_sources_other = ['csrc/layer_norm_cuda.cpp', 'csrc/hip/layer_norm_hip_kernel.hip']
ext_modules.append(
CUDAExtension(name='fused_layer_norm_cuda',
sources = layer_norm_sources_v1_8 if torch.__version__ >= '1.8' else layer_norm_sources_other,
extra_compile_args=['-O3'] + version_dependent_macros))
'nvcc': nvcc_args_layer_norm if not is_rocm_pytorch else hipcc_args_layer_norm}))
if not is_rocm_pytorch:
print ("INFO: Building the MLP Extension.")
ext_modules.append(
CUDAExtension(name='mlp_cuda',
sources=['csrc/mlp.cpp',
'csrc/mlp_cuda.cu'],
extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
'nvcc':['-O3'] + version_dependent_macros}))
else:
print ("INFO: Building MLP extension")
mlp_sources_v1_8 = ['csrc/mlp.cpp', 'csrc/mlp_hip.hip']
mlp_sources_other = ['csrc/mlp.cpp', 'csrc/hip/mlp_hip.hip']
ext_modules.append(
CUDAExtension(name='mlp_cuda',
sources = mlp_sources_v1_8 if torch.__version__ >= '1.8' else mlp_sources_other,
extra_compile_args=['-O3'] + version_dependent_macros))
if "--bnp" in sys.argv:
@@ -292,22 +222,12 @@ if "--xentropy" in sys.argv:
if torch.utils.cpp_extension.CUDA_HOME is None and (not is_rocm_pytorch):
raise RuntimeError("--xentropy was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
else:
if not is_rocm_pytorch:
print ("INFO: Building the xentropy extension.")
ext_modules.append(
CUDAExtension(name='xentropy_cuda',
sources=['apex/contrib/csrc/xentropy/interface.cpp',
'apex/contrib/csrc/xentropy/xentropy_kernel.cu'],
include_dirs=[os.path.join(this_dir, 'csrc')],
extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
'nvcc':['-O3'] + version_dependent_macros}))
else:
xentropy_sources_v1_8 = ['apex/contrib/csrc/xentropy/interface.cpp', 'apex/contrib/csrc/xentropy/xentropy_kernel.hip']
xentropy_sources_other = ['apex/contrib/csrc/xentropy/interface.cpp', 'apex/contrib/csrc/xentropy/hip/xentropy_kernel.hip']
ext_modules.append(
CUDAExtension(name='xentropy_cuda',
sources = xentropy_sources_v1_8 if torch.__version__ >= '1.8' else xentropy_sources_other,
include_dirs=[os.path.join(this_dir, 'csrc') if torch.__version__ >= '1.8' else os.path.join(this_dir, 'csrc/hip')],
extra_compile_args=['-O3'] + version_dependent_macros))
@@ -323,29 +243,16 @@ if "--deprecated_fused_adam" in sys.argv:
if torch.utils.cpp_extension.CUDA_HOME is None and (not is_rocm_pytorch):
raise RuntimeError("--deprecated_fused_adam was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
else:
if not is_rocm_pytorch:
print ("INFO: Building deprecated fused adam extension.")
nvcc_args_fused_adam = ['-O3', '--use_fast_math'] + version_dependent_macros
hipcc_args_fused_adam = ['-O3'] + version_dependent_macros
ext_modules.append(
CUDAExtension(name='fused_adam_cuda',
sources=['apex/contrib/csrc/optimizers/fused_adam_cuda.cpp',
'apex/contrib/csrc/optimizers/fused_adam_cuda_kernel.cu'],
include_dirs=[os.path.join(this_dir, 'csrc')],
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros,
'nvcc':['-O3',
'--use_fast_math'] + version_dependent_macros}))
else:
print ("INFO: Building deprecated fused adam.")
fused_adam_sources_v1_8 = ['apex/contrib/csrc/optimizers/fused_adam_cuda.cpp',
'apex/contrib/csrc/optimizers/fused_adam_hip_kernel.hip']
fused_adam_sources_other = ['apex/contrib/csrc/optimizers/fused_adam_cuda.cpp',
'apex/contrib/csrc/optimizers/hip/fused_adam_hip_kernel.hip']
ext_modules.append(
CUDAExtension(name='fused_adam_cuda',
sources = fused_adam_sources_v1_8 if torch.__version__ >= '1.8' else fused_adam_sources_other,
include_dirs=[os.path.join(this_dir, 'csrc') if torch.__version__ >= '1.8' else os.path.join(this_dir, 'csrc/hip')],
extra_compile_args=['-O3'] + version_dependent_macros))
extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
'nvcc' : nvcc_args_fused_adam if not is_rocm_pytorch else hipcc_args_fused_adam}))
if "--deprecated_fused_lamb" in sys.argv:
from torch.utils.cpp_extension import CUDAExtension
sys.argv.remove("--deprecated_fused_lamb")
@@ -358,29 +265,16 @@ if "--deprecated_fused_lamb" in sys.argv:
if torch.utils.cpp_extension.CUDA_HOME is None and (not is_rocm_pytorch):
raise RuntimeError("--deprecated_fused_lamb was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
else:
if not is_rocm_pytorch:
print ("INFO: Building deprecated fused lamb extension.")
nvcc_args_fused_lamb = ['-O3', '--use_fast_math'] + version_dependent_macros
hipcc_args_fused_lamb = ['-O3'] + version_dependent_macros
ext_modules.append(
CUDAExtension(name='fused_lamb_cuda',
sources=['apex/contrib/csrc/optimizers/fused_lamb_cuda.cpp',
'apex/contrib/csrc/optimizers/fused_lamb_cuda_kernel.cu',
'csrc/multi_tensor_l2norm_kernel.cu'],
include_dirs=[os.path.join(this_dir, 'csrc')],
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros,
'nvcc':['-O3',
'--use_fast_math'] + version_dependent_macros}))
else:
print ("INFO: Building deprecated fused lamb.")
fused_lamb_sources_v1_8 = ['apex/contrib/csrc/optimizers/fused_lamb_cuda.cpp',
'apex/contrib/csrc/optimizers/fused_lamb_hip_kernel.hip']
fused_lamb_sources_other = ['apex/contrib/csrc/optimizers/fused_lamb_cuda.cpp',
'apex/contrib/csrc/optimizers/hip/fused_lamb_hip_kernel.hip']
ext_modules.append(
CUDAExtension(name='fused_lamb_cuda',
sources = fused_lamb_sources_v1_8 if torch.__version__ >= '1.8' else fused_lamb_sources_other,
include_dirs = [os.path.join(this_dir, 'csrc') if torch.__version__ >= '1.8' else os.path.join(this_dir, 'csrc/hip')],
extra_compile_args=['-O3'] + version_dependent_macros))
extra_compile_args = nvcc_args_fused_lamb if not is_rocm_pytorch else hipcc_args_fused_lamb))
# Check if ATen/CUDAGenerator.h is found; otherwise use the new ATen/CUDAGeneratorImpl.h, due to the breaking change in https://github.com/pytorch/pytorch/pull/36026
generator_flag = []
......
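Note on the recurring change in the setup.py hunks above: the separate CUDA and ROCm CUDAExtension blocks (with csrc/hip/*.hip source lists gated on torch.__version__) are collapsed into a single block, apparently because the revamped hipify pass handles the original .cu sources directly, and only the compiler flags are switched on is_rocm_pytorch. A minimal sketch of that pattern follows; it assumes version_dependent_macros and is_rocm_pytorch are computed earlier as in the full setup.py, and the source list is shortened for illustration.

```python
# Minimal sketch (not the full setup.py): one CUDAExtension shared by CUDA and
# ROCm builds, with the compiler flags switched on is_rocm_pytorch.
from torch.utils.cpp_extension import CUDAExtension

# Assumed to be computed earlier in the real setup.py.
version_dependent_macros = []
is_rocm_pytorch = False

# nvcc gets the extra flags; hipcc builds keep only -O3.
nvcc_args_multi_tensor = ['-lineinfo', '-O3', '--use_fast_math'] + version_dependent_macros
hipcc_args_multi_tensor = ['-O3'] + version_dependent_macros

ext_modules = []
ext_modules.append(
    CUDAExtension(
        name='amp_C',
        # One shared .cu source list (shortened here for illustration); no
        # separate csrc/hip/*.hip list gated on torch.__version__.
        sources=['csrc/amp_C_frontend.cpp',
                 'csrc/multi_tensor_scale_kernel.cu'],
        extra_compile_args=(nvcc_args_multi_tensor if not is_rocm_pytorch
                            else hipcc_args_multi_tensor)))
```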