Commit 37a8124e authored by rusty1s's avatar rusty1s
Browse files

matmul cuda boilerplate

parent 0ae0e784
#include <torch/torch.h>
#define CHECK_CUDA(x) AT_ASSERT(x.type().is_cuda(), #x " must be a CUDA tensor")
at::SparseTensor spspmm(at::SparseTensor matrix1, at::SparseTensor matrix2) {
return matrix1;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("spspmm", &spspmm, "Sparse-Sparse Matrix Multiplication (CUDA)");
}
import torch import torch
from setuptools import setup, find_packages from setuptools import setup, find_packages
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
__version__ = '1.0.1' __version__ = '1.0.1'
url = 'https://github.com/rusty1s/pytorch_sparse' url = 'https://github.com/rusty1s/pytorch_sparse'
...@@ -11,7 +12,8 @@ ext_modules = [] ...@@ -11,7 +12,8 @@ ext_modules = []
cmdclass = {} cmdclass = {}
if torch.cuda.is_available(): if torch.cuda.is_available():
pass ext_modules += [CUDAExtension('matmul_cuda', ['cuda/matmul.cpp'])]
cmdclass['build_ext'] = BuildExtension
setup( setup(
name='torch_sparse', name='torch_sparse',
......
...@@ -2,6 +2,8 @@ import torch ...@@ -2,6 +2,8 @@ import torch
from torch import from_numpy from torch import from_numpy
from scipy.sparse import coo_matrix from scipy.sparse import coo_matrix
import matmul_cuda
class SpSpMM(torch.autograd.Function): class SpSpMM(torch.autograd.Function):
@staticmethod @staticmethod
...@@ -34,12 +36,25 @@ spspmm = SpSpMM.apply ...@@ -34,12 +36,25 @@ spspmm = SpSpMM.apply
def mm(e1, v1, s1, e2, v2, s2): def mm(e1, v1, s1, e2, v2, s2):
if e1.is_cuda: if v1.is_cuda:
pass return mm_cuda(e1, v1, s1, e2, v2, s2)
else: else:
return mm_cpu(e1, v1, s1, e2, v2, s2) return mm_cpu(e1, v1, s1, e2, v2, s2)
def mm_cuda(e1, v1, s1, e2, v2, s2):
matrix1 = to_sparse(e1, v1, s1)
matrix2 = to_sparse(e2, v2, s2)
out = matmul_cuda.spspmm(matrix1, matrix2)
return out._indices(), out._values()
def to_sparse(index, value, size):
assert value.is_cuda
SparseTensor = getattr(torch.cuda.sparse, value.type().split('.')[-1])
return SparseTensor(index, value, size)
def mm_cpu(e1, v1, s1, e2, v2, s2): def mm_cpu(e1, v1, s1, e2, v2, s2):
matrix1, matrix2, = to_csr(e1, v1, s1), to_csr(e2, v2, s2) matrix1, matrix2, = to_csr(e1, v1, s1), to_csr(e2, v2, s2)
out = matrix1.dot(matrix2).tocoo() out = matrix1.dot(matrix2).tocoo()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment