Unverified Commit 50338df6 authored by Deyu Fu's avatar Deyu Fu Committed by GitHub
Browse files

change include_dirs to abs path (#719)

parent 5b71d369
...@@ -6,6 +6,9 @@ import sys ...@@ -6,6 +6,9 @@ import sys
import warnings import warnings
import os import os
# ninja build does not work unless include_dirs are abs path
this_dir = os.path.dirname(os.path.abspath(__file__))
if not torch.cuda.is_available(): if not torch.cuda.is_available():
# https://github.com/NVIDIA/apex/issues/486 # https://github.com/NVIDIA/apex/issues/486
# Extension builds after https://github.com/pytorch/pytorch/pull/23408 attempt to query torch.cuda.get_device_capability(), # Extension builds after https://github.com/pytorch/pytorch/pull/23408 attempt to query torch.cuda.get_device_capability(),
...@@ -148,7 +151,7 @@ if "--bnp" in sys.argv: ...@@ -148,7 +151,7 @@ if "--bnp" in sys.argv:
'apex/contrib/csrc/groupbn/ipc.cu', 'apex/contrib/csrc/groupbn/ipc.cu',
'apex/contrib/csrc/groupbn/interface.cpp', 'apex/contrib/csrc/groupbn/interface.cpp',
'apex/contrib/csrc/groupbn/batch_norm_add_relu.cu'], 'apex/contrib/csrc/groupbn/batch_norm_add_relu.cu'],
include_dirs=['csrc'], include_dirs=[os.path.join(this_dir, 'csrc')],
extra_compile_args={'cxx': [] + version_dependent_macros, extra_compile_args={'cxx': [] + version_dependent_macros,
'nvcc':['-DCUDA_HAS_FP16=1', 'nvcc':['-DCUDA_HAS_FP16=1',
'-D__CUDA_NO_HALF_OPERATORS__', '-D__CUDA_NO_HALF_OPERATORS__',
...@@ -169,7 +172,7 @@ if "--xentropy" in sys.argv: ...@@ -169,7 +172,7 @@ if "--xentropy" in sys.argv:
CUDAExtension(name='xentropy_cuda', CUDAExtension(name='xentropy_cuda',
sources=['apex/contrib/csrc/xentropy/interface.cpp', sources=['apex/contrib/csrc/xentropy/interface.cpp',
'apex/contrib/csrc/xentropy/xentropy_kernel.cu'], 'apex/contrib/csrc/xentropy/xentropy_kernel.cu'],
include_dirs=['csrc'], include_dirs=[os.path.join(this_dir, 'csrc')],
extra_compile_args={'cxx': ['-O3'] + version_dependent_macros, extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
'nvcc':['-O3'] + version_dependent_macros})) 'nvcc':['-O3'] + version_dependent_macros}))
...@@ -187,7 +190,7 @@ if "--deprecated_fused_adam" in sys.argv: ...@@ -187,7 +190,7 @@ if "--deprecated_fused_adam" in sys.argv:
CUDAExtension(name='fused_adam_cuda', CUDAExtension(name='fused_adam_cuda',
sources=['apex/contrib/csrc/optimizers/fused_adam_cuda.cpp', sources=['apex/contrib/csrc/optimizers/fused_adam_cuda.cpp',
'apex/contrib/csrc/optimizers/fused_adam_cuda_kernel.cu'], 'apex/contrib/csrc/optimizers/fused_adam_cuda_kernel.cu'],
include_dirs=['csrc'], include_dirs=[os.path.join(this_dir, 'csrc')],
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros, extra_compile_args={'cxx': ['-O3',] + version_dependent_macros,
'nvcc':['-O3', 'nvcc':['-O3',
'--use_fast_math'] + version_dependent_macros})) '--use_fast_math'] + version_dependent_macros}))
...@@ -206,7 +209,7 @@ if "--fast_multihead_attn" in sys.argv: ...@@ -206,7 +209,7 @@ if "--fast_multihead_attn" in sys.argv:
subprocess.run(["git", "submodule", "update", "--init", "apex/contrib/csrc/multihead_attn/cutlass"]) subprocess.run(["git", "submodule", "update", "--init", "apex/contrib/csrc/multihead_attn/cutlass"])
ext_modules.append( ext_modules.append(
CUDAExtension(name='fast_self_multihead_attn', CUDAExtension(name='fast_self_multihead_attn',
sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn.cpp', sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn.cpp',
'apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu'], 'apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu'],
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros, extra_compile_args={'cxx': ['-O3',] + version_dependent_macros,
'nvcc':['-O3', 'nvcc':['-O3',
...@@ -219,7 +222,7 @@ if "--fast_multihead_attn" in sys.argv: ...@@ -219,7 +222,7 @@ if "--fast_multihead_attn" in sys.argv:
'--use_fast_math'] + version_dependent_macros})) '--use_fast_math'] + version_dependent_macros}))
ext_modules.append( ext_modules.append(
CUDAExtension(name='fast_self_multihead_attn_norm_add', CUDAExtension(name='fast_self_multihead_attn_norm_add',
sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add.cpp', sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add.cpp',
'apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu'], 'apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu'],
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros, extra_compile_args={'cxx': ['-O3',] + version_dependent_macros,
'nvcc':['-O3', 'nvcc':['-O3',
...@@ -232,7 +235,7 @@ if "--fast_multihead_attn" in sys.argv: ...@@ -232,7 +235,7 @@ if "--fast_multihead_attn" in sys.argv:
'--use_fast_math'] + version_dependent_macros})) '--use_fast_math'] + version_dependent_macros}))
ext_modules.append( ext_modules.append(
CUDAExtension(name='fast_encdec_multihead_attn', CUDAExtension(name='fast_encdec_multihead_attn',
sources=['apex/contrib/csrc/multihead_attn/encdec_multihead_attn.cpp', sources=['apex/contrib/csrc/multihead_attn/encdec_multihead_attn.cpp',
'apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu'], 'apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu'],
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros, extra_compile_args={'cxx': ['-O3',] + version_dependent_macros,
'nvcc':['-O3', 'nvcc':['-O3',
...@@ -245,7 +248,7 @@ if "--fast_multihead_attn" in sys.argv: ...@@ -245,7 +248,7 @@ if "--fast_multihead_attn" in sys.argv:
'--use_fast_math'] + version_dependent_macros})) '--use_fast_math'] + version_dependent_macros}))
ext_modules.append( ext_modules.append(
CUDAExtension(name='fast_encdec_multihead_attn_norm_add', CUDAExtension(name='fast_encdec_multihead_attn_norm_add',
sources=['apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add.cpp', sources=['apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add.cpp',
'apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu'], 'apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu'],
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros, extra_compile_args={'cxx': ['-O3',] + version_dependent_macros,
'nvcc':['-O3', 'nvcc':['-O3',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment