"""Build script for the ``torch_scatter`` PyTorch extension.

Always builds the CPU extension. The CUDA extensions are added only when
``CUDA_HOME`` is set and ``--cpu`` was not passed on the command line.
"""
import platform
from setuptools import setup, find_packages
from sys import argv

import torch
from torch.utils.cpp_extension import CppExtension, CUDAExtension, CUDA_HOME

TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])

# Flags shared by every extension. The version macro is defined for
# PyTorch releases newer than 1.2.
extra_compile_args = []
if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 2):
    extra_compile_args += ['-DVERSION_GE_1_3']

# Flags for the CPU extension: the shared flags on every platform, plus a
# warning suppression that is only passed on non-Windows systems.
# A bare `a + b if cond else []` would parse as `(a + b) if cond else []`
# and drop the shared flags on Windows, so the list is built in two steps.
cpu_compile_args = list(extra_compile_args)
if platform.system() != 'Windows':
    cpu_compile_args += ['-Wno-unused-variable']

ext_modules = [
    CppExtension(
        'torch_scatter.scatter_cpu',
        ['cpu/scatter.cpp'],
        extra_compile_args=cpu_compile_args,
    ),
]
cmdclass = {'build_ext': torch.utils.cpp_extension.BuildExtension}

# `--cpu` is a custom switch that disables the CUDA extensions. Every
# occurrence is stripped before `setup()` runs. The slice assignment
# mutates the imported list in place (rather than rebinding the name) and
# avoids removing elements from a list while iterating over it.
GPU = '--cpu' not in argv
argv[:] = [arg for arg in argv if arg != '--cpu']

if CUDA_HOME is not None and GPU:
    ext_modules += [
        CUDAExtension(
            'torch_scatter.scatter_cuda',
            ['cuda/scatter.cpp', 'cuda/scatter_kernel.cu'],
            extra_compile_args=extra_compile_args,
        ),
        CUDAExtension(
            'torch_scatter.segment_cuda',
            ['cuda/segment.cpp', 'cuda/segment_kernel.cu'],
            extra_compile_args=extra_compile_args,
        ),
    ]

__version__ = '1.4.0'
url = 'https://github.com/rusty1s/pytorch_scatter'

install_requires = []
setup_requires = ['pytest-runner']
tests_require = ['pytest', 'pytest-cov']

setup(
    name='torch_scatter',
    version=__version__,
    description='PyTorch Extension Library of Optimized Scatter Operations',
    author='Matthias Fey',
    author_email='matthias.fey@tu-dortmund.de',
    url=url,
    download_url='{}/archive/{}.tar.gz'.format(url, __version__),
    keywords=[
        'pytorch',
        'scatter',
    ],
    install_requires=install_requires,
    setup_requires=setup_requires,
    tests_require=tests_require,
    ext_modules=ext_modules,
    cmdclass=cmdclass,
    packages=find_packages(),
)