Unverified Commit 719215bd authored by Hubert Lu's avatar Hubert Lu Committed by GitHub
Browse files

Make index_mul_2d extension backward compatible for Atomic header include (#96)



* Make index_mul_2d extension backward compatible for Atomic header include

* Typo
Co-authored-by: default avatarJithun Nair <37884920+jithunnair-amd@users.noreply.github.com>
parent 89f5722c
#include <ATen/ATen.h> #include <ATen/ATen.h>
#include <ATen/AccumulateType.h> #include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Atomic.cuh> #ifdef ATEN_ATOMIC_HEADER
#include <ATen/cuda/Atomic.cuh>
#else
#include <THC/THCAtomics.cuh>
#endif
__global__ void index_mul_2d_float_dim64( __global__ void index_mul_2d_float_dim64(
......
...@@ -31,6 +31,9 @@ if os.path.exists(context_file): ...@@ -31,6 +31,9 @@ if os.path.exists(context_file):
found_Backward_Pass_Guard = True found_Backward_Pass_Guard = True
break break
found_aten_atomic_header = False
if os.path.exists(os.path.join(torch_dir, "include", "ATen", "Atomic.cuh")):
found_aten_atomic_header = True
def get_cuda_bare_metal_version(cuda_dir): def get_cuda_bare_metal_version(cuda_dir):
raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
...@@ -358,6 +361,13 @@ if "--focal_loss" in sys.argv or "--cuda_ext" in sys.argv: ...@@ -358,6 +361,13 @@ if "--focal_loss" in sys.argv or "--cuda_ext" in sys.argv:
if "--index_mul_2d" in sys.argv or "--cuda_ext" in sys.argv: if "--index_mul_2d" in sys.argv or "--cuda_ext" in sys.argv:
if "--index_mul_2d" in sys.argv: if "--index_mul_2d" in sys.argv:
sys.argv.remove("--index_mul_2d") sys.argv.remove("--index_mul_2d")
args_index_mul_2d = ['-O3']
if not IS_ROCM_PYTORCH:
args_index_mul_2d += ['--use_fast_math', '--ftz=false']
if found_aten_atomic_header:
args_index_mul_2d += ['-DATEN_ATOMIC_HEADER']
ext_modules.append( ext_modules.append(
CUDAExtension( CUDAExtension(
name='fused_index_mul_2d', name='fused_index_mul_2d',
...@@ -368,7 +378,7 @@ if "--index_mul_2d" in sys.argv or "--cuda_ext" in sys.argv: ...@@ -368,7 +378,7 @@ if "--index_mul_2d" in sys.argv or "--cuda_ext" in sys.argv:
include_dirs=[os.path.join(this_dir, 'csrc')], include_dirs=[os.path.join(this_dir, 'csrc')],
extra_compile_args={ extra_compile_args={
'cxx': ['-O3'] + version_dependent_macros, 'cxx': ['-O3'] + version_dependent_macros,
'nvcc':(['-O3', '--use_fast_math', '--ftz=false'] if not IS_ROCM_PYTORCH else ['-O3']) + version_dependent_macros, 'nvcc': args_index_mul_2d + version_dependent_macros,
}, },
) )
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment