Unverified Commit 267e696d authored by Jeff Daily's avatar Jeff Daily Committed by GitHub
Browse files

Fix compile args, adding version_dependent_macros. (#12)

parent b2da92fc
......@@ -101,7 +101,7 @@ version_dependent_macros = version_ge_1_1 + version_ge_1_3 + version_ge_1_5
if "--cuda_ext" in sys.argv:
from torch.utils.cpp_extension import CUDAExtension
sys.argv.remove("--cuda_ext")
is_rocm_pytorch = False
if torch.__version__ >= '1.5':
from torch.utils.cpp_extension import ROCM_HOME
......@@ -155,8 +155,7 @@ if "--cuda_ext" in sys.argv:
'csrc/hip/multi_tensor_adagrad.hip',
'csrc/hip/multi_tensor_novograd.hip',
'csrc/hip/multi_tensor_lamb.hip'],
extra_compile_args={'cxx' : ['-O3'] + version_dependent_macros,
'nvcc': []}))
extra_compile_args=['-O3'] + version_dependent_macros))
if not is_rocm_pytorch:
ext_modules.append(
......@@ -168,7 +167,7 @@ if "--cuda_ext" in sys.argv:
else:
print ("INFO: Skipping syncbn extension.")
if not is_rocm_pytorch:
ext_modules.append(
CUDAExtension(name='fused_layer_norm_cuda',
......@@ -277,7 +276,7 @@ if "--deprecated_fused_lamb" in sys.argv:
'nvcc':['-O3',
'--use_fast_math'] + version_dependent_macros}))
# Check, if ATen/CUDAGenerator.h is found, otherwise use the new ATen/CUDAGeneratorImpl.h, due to breaking change in https://github.com/pytorch/pytorch/pull/36026
# Check, if ATen/CUDAGenerator.h is found, otherwise use the new ATen/CUDAGeneratorImpl.h, due to breaking change in https://github.com/pytorch/pytorch/pull/36026
generator_flag = []
torch_dir = torch.__path__[0]
if os.path.exists(os.path.join(torch_dir, 'include', 'ATen', 'CUDAGenerator.h')):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment