Commit e3b49faf authored by rusty1s's avatar rusty1s
Browse files

enable O3 build

parent 9a651d91
...@@ -44,7 +44,7 @@ def get_extensions(): ...@@ -44,7 +44,7 @@ def get_extensions():
if sys.platform == 'win32': if sys.platform == 'win32':
define_macros += [('torchscatter_EXPORTS', None)] define_macros += [('torchscatter_EXPORTS', None)]
extra_compile_args = {'cxx': ['-O2']} extra_compile_args = {'cxx': ['-O3']}
if not os.name == 'nt': # Not on Windows: if not os.name == 'nt': # Not on Windows:
extra_compile_args['cxx'] += ['-Wno-sign-compare'] extra_compile_args['cxx'] += ['-Wno-sign-compare']
extra_link_args = [] if WITH_SYMBOLS else ['-s'] extra_link_args = [] if WITH_SYMBOLS else ['-s']
...@@ -69,14 +69,14 @@ def get_extensions(): ...@@ -69,14 +69,14 @@ def get_extensions():
define_macros += [('WITH_CUDA', None)] define_macros += [('WITH_CUDA', None)]
nvcc_flags = os.getenv('NVCC_FLAGS', '') nvcc_flags = os.getenv('NVCC_FLAGS', '')
nvcc_flags = [] if nvcc_flags == '' else nvcc_flags.split(' ') nvcc_flags = [] if nvcc_flags == '' else nvcc_flags.split(' ')
if torch.version.hip:
nvcc_flags += ['-O3'] nvcc_flags += ['-O3']
# USE_ROCM was added to later versons of rocm pytorch if torch.version.hip:
# define here to support older pytorch versions # USE_ROCM was added to later versions of PyTorch.
# Define here to support older PyTorch versions as well:
define_macros += [('USE_ROCM', None)] define_macros += [('USE_ROCM', None)]
undef_macros += ['__HIP_NO_HALF_CONVERSIONS__'] undef_macros += ['__HIP_NO_HALF_CONVERSIONS__']
else: else:
nvcc_flags += ['--expt-relaxed-constexpr', '-O2'] nvcc_flags += ['--expt-relaxed-constexpr']
extra_compile_args['nvcc'] = nvcc_flags extra_compile_args['nvcc'] = nvcc_flags
name = main.split(os.sep)[-1][:-4] name = main.split(os.sep)[-1][:-4]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment