Commit d1c42fbb authored by rusty1s's avatar rusty1s
Browse files

windows support

parent c182d679
import platform
from setuptools import setup, find_packages from setuptools import setup, find_packages
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME
...@@ -11,12 +12,15 @@ ext_modules = [] ...@@ -11,12 +12,15 @@ ext_modules = []
cmdclass = {} cmdclass = {}
if CUDA_HOME is not None: if CUDA_HOME is not None:
if platform.system() == 'Windows':
extra_link_args = ['cusparse.lib'],
else:
extra_link_args = ['-lcusparse', '-l', 'cusparse'],
ext_modules += [ ext_modules += [
CUDAExtension( CUDAExtension(
'spspmm_cuda', 'spspmm_cuda', ['cuda/spspmm.cpp', 'cuda/spspmm_kernel.cu'],
['cuda/spspmm.cpp', 'cuda/spspmm_kernel.cu'], extra_link_args=extra_link_args),
extra_link_args=['-lcusparse', '-l', 'cusparse'],
),
CUDAExtension('unique_cuda', CUDAExtension('unique_cuda',
['cuda/unique.cpp', 'cuda/unique_kernel.cu']), ['cuda/unique.cpp', 'cuda/unique_kernel.cu']),
] ]
...@@ -37,5 +41,4 @@ setup( ...@@ -37,5 +41,4 @@ setup(
tests_require=tests_require, tests_require=tests_require,
ext_modules=ext_modules, ext_modules=ext_modules,
cmdclass=cmdclass, cmdclass=cmdclass,
packages=find_packages(), packages=find_packages(), )
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment