Unverified commit 719215bd, authored by Hubert Lu, committed by GitHub
Browse files

Make index_mul_2d extension backward compatible for Atomic header include (#96)



* Make index_mul_2d extension backward compatible for Atomic header include

* Typo
Co-authored-by: Jithun Nair <37884920+jithunnair-amd@users.noreply.github.com>
parent 89f5722c
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Atomic.cuh>
#ifdef ATEN_ATOMIC_HEADER
#include <ATen/cuda/Atomic.cuh>
#else
#include <THC/THCAtomics.cuh>
#endif
__global__ void index_mul_2d_float_dim64(
......
......@@ -31,6 +31,9 @@ if os.path.exists(context_file):
found_Backward_Pass_Guard = True
break
# ATEN_ATOMIC_HEADER makes the kernel source include <ATen/cuda/Atomic.cuh>,
# so probe for that exact header (under ATen/cuda/). When it is absent the
# source falls back to <THC/THCAtomics.cuh>.
found_aten_atomic_header = os.path.exists(
    os.path.join(torch_dir, "include", "ATen", "cuda", "Atomic.cuh")
)
def get_cuda_bare_metal_version(cuda_dir):
raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
......@@ -358,6 +361,13 @@ if "--focal_loss" in sys.argv or "--cuda_ext" in sys.argv:
if "--index_mul_2d" in sys.argv or "--cuda_ext" in sys.argv:
if "--index_mul_2d" in sys.argv:
sys.argv.remove("--index_mul_2d")
args_index_mul_2d = ['-O3']
if not IS_ROCM_PYTORCH:
args_index_mul_2d += ['--use_fast_math', '--ftz=false']
if found_aten_atomic_header:
args_index_mul_2d += ['-DATEN_ATOMIC_HEADER']
ext_modules.append(
CUDAExtension(
name='fused_index_mul_2d',
......@@ -368,7 +378,7 @@ if "--index_mul_2d" in sys.argv or "--cuda_ext" in sys.argv:
include_dirs=[os.path.join(this_dir, 'csrc')],
extra_compile_args={
'cxx': ['-O3'] + version_dependent_macros,
'nvcc':(['-O3', '--use_fast_math', '--ftz=false'] if not IS_ROCM_PYTORCH else ['-O3']) + version_dependent_macros,
'nvcc': args_index_mul_2d + version_dependent_macros,
},
)
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment