Commit 75a3899f authored by rusty1s

year up, restricted coverage, nested extensions

parent da8f675e
.coveragerc
+[run]
+source=torch_sparse
 [report]
 exclude_lines =
     pragma: no cover
 ...
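The new [run] block is the "restricted coverage" part of the commit title: source=torch_sparse tells coverage.py to measure only files inside the torch_sparse package. A minimal sketch of the effect via the coverage.py API (illustrative usage, not part of this commit):

    import coverage

    # Picks up [run] source=torch_sparse and [report] exclude_lines from .coveragerc.
    cov = coverage.Coverage(config_file='.coveragerc')
    cov.start()
    import torch_sparse  # only statements under torch_sparse/ are recorded
    cov.stop()
    cov.report()  # files outside the package (e.g. test scripts) no longer appear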
LICENSE
-Copyright (c) 2018 Matthias Fey <matthias.fey@tu-dortmund.de>
+Copyright (c) 2019 Matthias Fey <matthias.fey@tu-dortmund.de>
 Permission is hereby granted, free of charge, to any person obtaining a copy
 of this software and associated documentation files (the "Software"), to deal
 ...
cuda/spspmm.cpp
-#include <torch/torch.h>
+#include <torch/extension.h>
 #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be CUDA tensor")
 ...
cuda/unique.cpp
-#include <torch/torch.h>
+#include <torch/extension.h>
 #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be CUDA tensor")
 ...
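Both CUDA bindings swap <torch/torch.h> for <torch/extension.h>, the canonical header for Python-bound C++ extensions since the PyTorch 1.0 extension API (it pulls in ATen and pybind11, while <torch/torch.h> is reserved for the standalone C++ frontend). Illustrative only: the same sources could also be JIT-compiled during development, which likewise expects <torch/extension.h>:

    from torch.utils.cpp_extension import load

    # Compiles and loads the binding on the fly; useful while iterating on kernels.
    spspmm_cuda = load(name='spspmm_cuda',
                       sources=['cuda/spspmm.cpp', 'cuda/spspmm_kernel.cu'])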
setup.py
@@ -13,15 +13,16 @@ cmdclass = {}
 if CUDA_HOME is not None:
     if platform.system() == 'Windows':
-        extra_link_args = 'cusparse.lib'
+        extra_link_args = ['cusparse.lib']
     else:
         extra_link_args = ['-lcusparse', '-l', 'cusparse']
     ext_modules += [
         CUDAExtension(
-            'spspmm_cuda', ['cuda/spspmm.cpp', 'cuda/spspmm_kernel.cu'],
+            'torch_sparse.spspmm_cuda',
+            ['cuda/spspmm.cpp', 'cuda/spspmm_kernel.cu'],
             extra_link_args=extra_link_args),
-        CUDAExtension('unique_cuda',
+        CUDAExtension('torch_sparse.unique_cuda',
                       ['cuda/unique.cpp', 'cuda/unique_kernel.cu']),
     ]
     cmdclass['build_ext'] = BuildExtension
@@ -29,8 +30,8 @@ if CUDA_HOME is not None:
 setup(
     name='torch_sparse',
     version=__version__,
-    description='PyTorch Extension Library of Optimized Autograd Sparse '
-    'Matrix Operations',
+    description=('PyTorch Extension Library of Optimized Autograd Sparse '
+                 'Matrix Operations'),
     author='Matthias Fey',
     author_email='matthias.fey@tu-dortmund.de',
     url=url,
@@ -41,4 +42,5 @@ setup(
     tests_require=tests_require,
     ext_modules=ext_modules,
     cmdclass=cmdclass,
-    packages=find_packages(), )
+    packages=find_packages(),
+)
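Two fixes sit in the setup.py hunks above. The Windows extra_link_args was a bare string; distutils treats that argument as a sequence of flags, so a string would be split into single characters, and it must be a list. And prefixing the extension names with torch_sparse. (the "nested extensions" of the commit title) makes setuptools install the compiled modules inside the package instead of as top-level modules. A condensed sketch of the resulting pattern (simplified from the file above; the non-Windows flag list is shortened here):

    import platform
    from torch.utils.cpp_extension import BuildExtension, CUDAExtension

    # Always a list: distutils iterates this value flag by flag.
    if platform.system() == 'Windows':
        extra_link_args = ['cusparse.lib']
    else:
        extra_link_args = ['-lcusparse']

    ext_modules = [
        # The dotted name nests the built module inside the package, so it is
        # imported as torch_sparse.spspmm_cuda rather than a global spspmm_cuda.
        CUDAExtension('torch_sparse.spspmm_cuda',
                      ['cuda/spspmm.cpp', 'cuda/spspmm_kernel.cu'],
                      extra_link_args=extra_link_args),
    ]
    cmdclass = {'build_ext': BuildExtension}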
torch_sparse/spspmm.py
@@ -4,7 +4,7 @@ import scipy.sparse
 from torch_sparse import transpose

 if torch.cuda.is_available():
-    import spspmm_cuda
+    import torch_sparse.spspmm_cuda


 def spspmm(indexA, valueA, indexB, valueB, m, k, n):
@@ -60,7 +60,8 @@ def mm(indexA, valueA, indexB, valueB, m, k, n):
     assert valueA.dtype == valueB.dtype

     if indexA.is_cuda:
-        return spspmm_cuda.spspmm(indexA, valueA, indexB, valueB, m, k, n)
+        return torch_sparse.spspmm_cuda.spspmm(indexA, valueA, indexB, valueB,
+                                               m, k, n)

     A = to_scipy(indexA, valueA, m, k)
     B = to_scipy(indexB, valueB, k, n)
 ...
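For reference, the call convention of spspmm: COO indices and values of A (m x k) and B (k x n), returning the COO form of A @ B. The numbers below are just an illustration:

    import torch
    from torch_sparse import spspmm

    # A: 3 x 3 with entries (0,1)=1, (0,2)=2, (1,0)=3, (2,0)=4, (2,1)=5
    indexA = torch.tensor([[0, 0, 1, 2, 2], [1, 2, 0, 0, 1]])
    valueA = torch.tensor([1., 2., 3., 4., 5.])
    # B: 3 x 2 with entries (0,1)=2, (2,0)=4
    indexB = torch.tensor([[0, 2], [1, 0]])
    valueB = torch.tensor([2., 4.])

    indexC, valueC = spspmm(indexA, valueA, indexB, valueB, 3, 3, 2)
    # indexC == tensor([[0, 1, 2], [0, 1, 1]]), valueC == tensor([8., 6., 8.])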
torch_sparse/unique.py
@@ -2,14 +2,14 @@ import torch
 import numpy as np

 if torch.cuda.is_available():
-    import unique_cuda
+    import torch_sparse.unique_cuda


 def unique(src):
     src = src.contiguous().view(-1)

     if src.is_cuda:
-        out, perm = unique_cuda.unique(src)
+        out, perm = torch_sparse.unique_cuda.unique(src)
     else:
         out, perm = np.unique(src.numpy(), return_index=True)
         out, perm = torch.from_numpy(out), torch.from_numpy(perm)
 ...
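And the contract of the unique() helper on the CPU path shown above (the import path is assumed from this commit's package layout):

    import torch
    from torch_sparse.unique import unique

    src = torch.tensor([[2, 0, 2], [1, 0, 1]])
    out, perm = unique(src)
    # out  == tensor([0, 1, 2])   -- sorted unique values of the flattened input
    # perm == tensor([1, 3, 0])   -- first occurrence of each value in src.view(-1)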