Commit 7aa701b1 authored by rusty1s's avatar rusty1s
Browse files

fix

parent 2149e811
...@@ -7,7 +7,6 @@ from sys import argv ...@@ -7,7 +7,6 @@ from sys import argv
import torch import torch
from torch.utils.cpp_extension import CppExtension, CUDAExtension, CUDA_HOME from torch.utils.cpp_extension import CppExtension, CUDAExtension, CUDA_HOME
# Windows users: Edit both of these to contain your VS include path, i.e. # Windows users: Edit both of these to contain your VS include path, i.e.
# cxx_extra_compile_args = ['-I{VISUAL_STUDIO_DIR}\\include'] # cxx_extra_compile_args = ['-I{VISUAL_STUDIO_DIR}\\include']
# nvcc_extra_compile_args = [..., '-I{VISUAL_STUDIO_DIR}\\include'] # nvcc_extra_compile_args = [..., '-I{VISUAL_STUDIO_DIR}\\include']
...@@ -32,10 +31,9 @@ cmdclass = {'build_ext': torch.utils.cpp_extension.BuildExtension} ...@@ -32,10 +31,9 @@ cmdclass = {'build_ext': torch.utils.cpp_extension.BuildExtension}
ext_modules = [] ext_modules = []
exts = [e.split(osp.sep)[-1][:-4] for e in glob(osp.join('cpu', '*.cpp'))] exts = [e.split(osp.sep)[-1][:-4] for e in glob(osp.join('cpu', '*.cpp'))]
ext_modules += [ ext_modules += [
CppExtension( CppExtension(f'torch_scatter.{ext}_cpu', [f'cpu/{ext}.cpp'],
f'torch_scatter.{ext}_cpu', [f'cpu/{ext}.cpp'], extra_compile_args=cxx_extra_compile_args,
extra_compile_args=cxx_extra_compile_args, extra_link_args=cxx_extra_link_args) for ext in exts
extra_link_args=cxx_extra_link_args) for ext in exts
] ]
if CUDA_HOME is not None and '--cpu' not in argv: if CUDA_HOME is not None and '--cpu' not in argv:
...@@ -43,13 +41,13 @@ if CUDA_HOME is not None and '--cpu' not in argv: ...@@ -43,13 +41,13 @@ if CUDA_HOME is not None and '--cpu' not in argv:
ext_modules += [ ext_modules += [
CUDAExtension( CUDAExtension(
f'torch_scatter.{ext}_cuda', f'torch_scatter.{ext}_cuda',
[f'cuda/{ext}.cpp', f'cuda/{ext}_kernel.cu'], [f'cuda/{ext}.cpp', f'cuda/{ext}_kernel.cu'], extra_compile_args={
extra_compile_args={
'cxx': cxx_extra_compile_args, 'cxx': cxx_extra_compile_args,
'nvcc': nvcc_extra_compile_args, 'nvcc': nvcc_extra_compile_args,
}, }, extra_link_args=nvcc_extra_link_args) for ext in exts
extra_link_args=nvcc_extra_link_args) for ext in exts
] ]
if '--cpu' in argv:
argv.remove('--cpu')
__version__ = '1.5.0' __version__ = '1.5.0'
url = 'https://github.com/rusty1s/pytorch_scatter' url = 'https://github.com/rusty1s/pytorch_scatter'
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment