Commit ef209a74 authored by lcskrishna
Browse files

Update setup file for ROCm due to newer hipify changes

parent 7eed38aa
......@@ -150,6 +150,7 @@ if "--cuda_ext" in sys.argv:
with hipify_python.GeneratedFileCleaner(keep_intermediates=True) as clean_ctx:
hipify_python.hipify(project_directory=this_dir, output_directory=this_dir, includes="csrc/*",
show_detailed=True, is_pytorch_extension=True, clean_ctx=clean_ctx)
if torch.__version__ < '1.8':
shutil.copy("csrc/compat.h", "csrc/hip/compat.h")
shutil.copy("csrc/type_shim.h", "csrc/hip/type_shim.h")
......@@ -174,9 +175,22 @@ if "--cuda_ext" in sys.argv:
'--use_fast_math'] + version_dependent_macros}))
else:
print ("INFO: Building Multitensor apply extension")
ext_modules.append(
CUDAExtension(name='amp_C',
sources=['csrc/amp_C_frontend.cpp',
multi_tensor_sources_v1_8 = [
'csrc/amp_C_frontend.cpp',
'csrc/multi_tensor_sgd_kernel.hip',
'csrc/multi_tensor_scale_kernel.hip',
'csrc/multi_tensor_axpby_kernel.hip',
'csrc/multi_tensor_l2norm_kernel.hip',
'csrc/multi_tensor_lamb_stage_1.hip',
'csrc/multi_tensor_lamb_stage_2.hip',
'csrc/multi_tensor_adam.hip',
'csrc/multi_tensor_adagrad.hip',
'csrc/multi_tensor_novograd.hip',
'csrc/multi_tensor_lamb.hip'
]
multi_tensor_sources_other = [
'csrc/amp_C_frontend.cpp',
'csrc/hip/multi_tensor_sgd_kernel.hip',
'csrc/hip/multi_tensor_scale_kernel.hip',
'csrc/hip/multi_tensor_axpby_kernel.hip',
......@@ -186,7 +200,25 @@ if "--cuda_ext" in sys.argv:
'csrc/hip/multi_tensor_adam.hip',
'csrc/hip/multi_tensor_adagrad.hip',
'csrc/hip/multi_tensor_novograd.hip',
'csrc/hip/multi_tensor_lamb.hip'],
'csrc/hip/multi_tensor_lamb.hip',
]
#ext_modules.append(
# CUDAExtension(name='amp_C',
# sources=['csrc/amp_C_frontend.cpp',
# 'csrc/hip/multi_tensor_sgd_kernel.hip',
# 'csrc/hip/multi_tensor_scale_kernel.hip',
# 'csrc/hip/multi_tensor_axpby_kernel.hip',
# 'csrc/hip/multi_tensor_l2norm_kernel.hip',
# 'csrc/hip/multi_tensor_lamb_stage_1.hip',
# 'csrc/hip/multi_tensor_lamb_stage_2.hip',
# 'csrc/hip/multi_tensor_adam.hip',
# 'csrc/hip/multi_tensor_adagrad.hip',
# 'csrc/hip/multi_tensor_novograd.hip',
# 'csrc/hip/multi_tensor_lamb.hip'],
# extra_compile_args=['-O3'] + version_dependent_macros))
ext_modules.append(
CUDAExtension(name='amp_C',
sources=multi_tensor_sources_v1_8 if torch.__version__ >= '1.8' else multi_tensor_sources_other,
extra_compile_args=['-O3'] + version_dependent_macros))
if not is_rocm_pytorch:
......@@ -198,11 +230,17 @@ if "--cuda_ext" in sys.argv:
'nvcc':['-O3'] + version_dependent_macros}))
else:
print ("INFO: Building syncbn extension.")
syncbn_sources_v1_8 = ['csrc/syncbn.cpp', 'csrc/welford.hip']
syncbn_sources_other = ['csrc/syncbn.cpp', 'csrc/hip/welford.hip']
ext_modules.append(
CUDAExtension(name='syncbn',
sources=['csrc/syncbn.cpp',
'csrc/hip/welford.hip'],
sources=syncbn_sources_v1_8 if torch.__version__ >= '1.8' else syncbn_sources_other,
extra_compile_args=['-O3'] + version_dependent_macros))
#ext_modules.append(
# CUDAExtension(name='syncbn',
# sources=['csrc/syncbn.cpp',
# 'csrc/hip/welford.hip'],
# extra_compile_args=['-O3'] + version_dependent_macros))
if not is_rocm_pytorch:
......@@ -216,12 +254,18 @@ if "--cuda_ext" in sys.argv:
'--use_fast_math'] + version_dependent_macros}))
else:
print ("INFO: Building FusedLayerNorm extension.")
layer_norm_sources_v1_8 = ['csrc/layer_norm_cuda.cpp', 'csrc/layer_norm_hip_kernel.hip']
layer_norm_sources_other = ['csrc/layer_norm_cuda.cpp', 'csrc/hip/layer_norm_hip_kernel.hip']
ext_modules.append(
CUDAExtension(name='fused_layer_norm_cuda',
sources=['csrc/layer_norm_cuda.cpp',
'csrc/hip/layer_norm_hip_kernel.hip'],
extra_compile_args={'cxx' : ['-O3'] + version_dependent_macros,
'nvcc' : []}))
sources = layer_norm_sources_v1_8 if torch.__version__ >= '1.8' else layer_norm_sources_other,
extra_compile_args=['-O3'] + version_dependent_macros))
#ext_modules.append(
# CUDAExtension(name='fused_layer_norm_cuda',
# sources=['csrc/layer_norm_cuda.cpp',
# 'csrc/hip/layer_norm_hip_kernel.hip'],
# extra_compile_args={'cxx' : ['-O3'] + version_dependent_macros,
# 'nvcc' : []}))
if not is_rocm_pytorch:
ext_modules.append(
......@@ -232,12 +276,18 @@ if "--cuda_ext" in sys.argv:
'nvcc':['-O3'] + version_dependent_macros}))
else:
print ("INFO: Building MLP extension")
mlp_sources_v1_8 = ['csrc/mlp.cpp', 'csrc/mlp_hip.hip']
mlp_sources_other = ['csrc/mlp.cpp', 'csrc/hip/mlp_hip.hip']
ext_modules.append(
CUDAExtension(name='mlp_cuda',
sources=['csrc/mlp.cpp',
'csrc/hip/mlp_hip.hip'],
extra_compile_args={'cxx' : ['-O3'] + version_dependent_macros,
'nvcc' : []}))
sources = mlp_sources_v1_8 if torch.__version__ >= '1.8' else mlp_sources_other,
extra_compile_args=['-O3'] + version_dependent_macros))
#ext_modules.append(
# CUDAExtension(name='mlp_cuda',
# sources=['csrc/mlp.cpp',
# 'csrc/hip/mlp_hip.hip'],
# extra_compile_args={'cxx' : ['-O3'] + version_dependent_macros,
# 'nvcc' : []}))
if "--bnp" in sys.argv:
from torch.utils.cpp_extension import CUDAExtension
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment