Commit 684c064c authored by Soumith Chintala's avatar Soumith Chintala
Browse files

optional nvcc flags

parent 5f417b5d
......@@ -88,10 +88,20 @@ def get_extensions():
define_macros = []
extra_compile_args = {}
if torch.cuda.is_available() and CUDA_HOME is not None:
extension = CUDAExtension
sources += source_cuda
define_macros += [('WITH_CUDA', None)]
nvcc_flags = os.getenv('NVCC_FLAGS', '')
if nvcc_flags == '':
nvcc_flags = []
else:
nvcc_flags = nvcc_flags.split(' ')
extra_compile_args = {
'cxx': ['-O0'],
'nvcc': nvcc_flags,
}
sources = [os.path.join(extensions_dir, s) for s in sources]
......@@ -103,6 +113,7 @@ def get_extensions():
sources,
include_dirs=include_dirs,
define_macros=define_macros,
extra_compile_args=extra_compile_args,
)
]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment