Commit 06933f89 authored by rusty1s's avatar rusty1s
Browse files

cuda sparse sparse mm implementation

parent 41458598
import torch
from torch.autograd import Variable

# Scratch script: build two 2x2 sparse COO matrices to exercise the CUDA
# sparse-sparse matmul extension. Both operands must live on the GPU.
size = torch.Size([2, 2])

# A: anti-diagonal ones, constructed directly on the GPU.
index = torch.tensor([[0, 1], [1, 0]], dtype=torch.long).cuda()
value = torch.tensor([1, 1], dtype=torch.float).cuda()
A = torch.cuda.sparse.FloatTensor(index, value, size)

# B: diagonal ones. Fix: the original built B on the CPU while A lives on
# the GPU, so any A @ B would fail with a device mismatch — move B's index
# and value tensors to CUDA and use the CUDA sparse constructor like A.
index = torch.tensor([[0, 1], [0, 1]], dtype=torch.long).cuda()
value = torch.tensor([1, 1], dtype=torch.float).cuda()
B = torch.cuda.sparse.FloatTensor(index, value, size)
...@@ -2,12 +2,12 @@ ...@@ -2,12 +2,12 @@
// NOTE(review): this span was extraction-garbled (old and new diff columns
// fused); reconstructed here as the post-commit version.
#define CHECK_CUDA(x) AT_ASSERT(x.type().is_cuda(), #x " must be a CUDA tensor")

// Implemented in the .cu translation unit; returns the COO (index, value)
// pair of the sparse-sparse product A @ B.
std::tuple<at::Tensor, at::Tensor> spspmm_cuda(at::Tensor A, at::Tensor B);

// Python-facing entry point: validate that both sparse operands live on the
// GPU, then dispatch to the cuSPARSE-backed implementation.
std::tuple<at::Tensor, at::Tensor> spspmm(at::Tensor A, at::Tensor B) {
  CHECK_CUDA(A);
  CHECK_CUDA(B);
  return spspmm_cuda(A, B);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
......
...@@ -2,6 +2,23 @@ ...@@ -2,6 +2,23 @@
#include <cusparse.h> #include <cusparse.h>
// CSRGEMM(TYPE, ...): runtime dtype dispatch for cuSPARSE's csrgemm.
// Switches on the ATen scalar type of TYPE and forwards the remaining
// arguments to cusparseScsrgemm (float) or cusparseDcsrgemm (double).
// The `using scalar_t = ...` alias is in scope when __VA_ARGS__ expands
// inside the lambda, so call sites may write e.g. value.data<scalar_t>().
// Any other scalar type raises an ATen error at runtime.
#define CSRGEMM(TYPE, ...) \
  [&] { \
    const at::Type &the_type = TYPE; \
    switch (the_type.scalarType()) { \
    case at::ScalarType::Float: { \
      using scalar_t = float; \
      return cusparseScsrgemm(__VA_ARGS__); \
    } \
    case at::ScalarType::Double: { \
      using scalar_t = double; \
      return cusparseDcsrgemm(__VA_ARGS__); \
    } \
    default: \
      AT_ERROR("Not implemented for '%s'", the_type.toString()); \
    } \
  }()
static cusparseHandle_t cusparse_handle = 0; static cusparseHandle_t cusparse_handle = 0;
static void init_cusparse() { static void init_cusparse() {
...@@ -10,51 +27,57 @@ static void init_cusparse() { ...@@ -10,51 +27,57 @@ static void init_cusparse() {
} }
} }
at::Tensor spspmm_cuda(at::Tensor matrix1, at::Tensor matrix2) { std::tuple<at::Tensor, at::Tensor> spspmm_cuda(at::Tensor A, at::Tensor B) {
init_cusparse(); init_cusparse();
auto nnz = matrix1._nnz(); auto m = A.size(0);
auto inDim = matrix1.size(1); auto n = B.size(1);
auto k = A.size(1);
auto row = matrix1._indices()[0].toType(at::kInt);
auto row_ptrs = at::empty(row.type(), {inDim + 1}); auto nnzA = A._nnz();
auto nnzB = B._nnz();
cusparseXcoo2csr(cusparse_handle, row.data<int>(), nnz, inDim,
row_ptrs.data<int>(), CUSPARSE_INDEX_BASE_ZERO); auto valueA = A._values();
auto indexA = A._indices().toType(at::kInt);
printf("%lli\n", nnz); auto row_ptrA = at::empty(indexA.type(), {m + 1});
printf("%lli\n", inDim); cusparseXcoo2csr(cusparse_handle, indexA[0].data<int>(), nnzA, k,
row_ptrA.data<int>(), CUSPARSE_INDEX_BASE_ZERO);
/* colbuf at::empty(nnz); */ auto colA = indexA[1];
/* auto colPtrs = at::empty(inDim + 1, at::kInt); */
auto valueB = B._values();
/* auto row = matrix1._indices(); */ auto indexB = B._indices().toType(at::kInt);
/* for (int i = 0; i < 5; i++) { */ auto row_ptrB = at::empty(indexB.type(), {k + 1});
/* row_buf.data<int>()[i] = (int)row.data<int64_t>()[i]; */ cusparseXcoo2csr(cusparse_handle, indexB[0].data<int>(), nnzB, k,
/* } */ row_ptrB.data<int>(), CUSPARSE_INDEX_BASE_ZERO);
/* printf("%lli\n", row.numel()); */ auto colB = indexB[1];
return matrix1; cusparseMatDescr_t descr = 0;
cusparseCreateMatDescr(&descr);
cusparseSetMatType(descr, CUSPARSE_MATRIX_TYPE_GENERAL);
cusparseSetMatIndexBase(descr, CUSPARSE_INDEX_BASE_ZERO);
int nnzC;
auto row_ptrC = at::empty(indexA.type(), {m + 1});
cusparseXcsrgemmNnz(cusparse_handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
CUSPARSE_OPERATION_NON_TRANSPOSE, m, n, k, descr, nnzA,
row_ptrA.data<int>(), colA.data<int>(), descr, nnzB,
row_ptrB.data<int>(), colB.data<int>(), descr,
row_ptrC.data<int>(), &nnzC);
auto colC = at::empty(indexA.type(), {nnzC});
auto valueC = at::empty(valueA.type(), {nnzC});
CSRGEMM(valueC.type(), cusparse_handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
CUSPARSE_OPERATION_NON_TRANSPOSE, m, n, k, descr, nnzA,
valueA.data<scalar_t>(), row_ptrA.data<int>(), colA.data<int>(),
descr, nnzB, valueB.data<scalar_t>(), row_ptrB.data<int>(),
colB.data<int>(), descr, valueC.data<scalar_t>(),
row_ptrC.data<int>(), colC.data<int>());
auto rowC = at::empty(indexA.type(), {nnzC});
cusparseXcsr2coo(cusparse_handle, row_ptrC.data<int>(), nnzC, m,
rowC.data<int>(), CUSPARSE_INDEX_BASE_ZERO);
auto indexC = at::stack({rowC, colC}, 0).toType(at::kLong);
return std::make_tuple(indexC, valueC);
} }
/* #include <ATen/SparseTensorImpl.h> */
/* namespace at { */
/* namespace native { */
/* using SparseTensor = Tensor; */
/* namespace { */
/* at::SparseTensor spspmm_cuda(at::SparseTensor matrix1, */
/* at::SparseTensor matrix2) { */
/* return matrix1; */
/* } */
/* } // namespace */
/* } // namespace native */
/* } // namespace at */
// defined in aten/src/THCUNN/SparseLinear.cu as
/* cusparseXcoo2csr(cusparse_handle, THCudaIntTensor_data(state, colbuf), nnz,
*/
/* inDim, THCudaIntTensor_data(state, colPtrs), */
/* CUSPARSE_INDEX_BASE_ONE); */
...@@ -2,31 +2,24 @@ from itertools import product ...@@ -2,31 +2,24 @@ from itertools import product
import pytest import pytest
import torch import torch
from torch_sparse import spspmm from torch_sparse import spspmm, SparseTensor
from .utils import dtypes, devices, tensor from .utils import dtypes, devices, tensor
devices = [torch.device('cuda')]
dtypes = [torch.float]
# NOTE(review): this span was extraction-garbled (old and new diff columns
# fused); reconstructed as the post-commit version of the test.
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_spspmm(dtype, device):
    """spspmm of a 3x3 by a 3x2 sparse matrix matches the dense product."""
    index = torch.tensor([[0, 0, 1, 2, 2], [1, 2, 0, 0, 1]], device=device)
    value = tensor([1, 2, 3, 4, 5], dtype, device)
    A = (index, value, torch.Size([3, 3]))

    index = torch.tensor([[0, 2], [1, 0]], device=device)
    value = tensor([2, 4], dtype, device)
    B = (index, value, torch.Size([3, 2]))

    index, value = spspmm(*A, *B)
    out = SparseTensor(index, value, torch.Size([3, 2]))
    # Expected dense product of A (5 nnz) and B (2 nnz).
    assert out.to_dense().tolist() == [[8, 0], [0, 6], [0, 8]]

    # TODO TEST backward
    # value.sum().backward()
...@@ -46,8 +46,7 @@ def mm(e1, v1, s1, e2, v2, s2): ...@@ -46,8 +46,7 @@ def mm(e1, v1, s1, e2, v2, s2):
def mm_cuda(e1, v1, s1, e2, v2, s2):
    """Sparse-sparse matmul on the GPU via the CUDA extension.

    Wraps the two (index, value, size) triples into SparseTensors and
    returns the (index, value) pair produced by ``matmul_cuda.spspmm``
    (the extension now returns the tuple directly, so no ``_indices()``/
    ``_values()`` unpacking is needed).
    """
    matrix1 = SparseTensor(e1, v1, s1)
    matrix2 = SparseTensor(e2, v2, s2)
    return matmul_cuda.spspmm(matrix1, matrix2)
def mm_cpu(e1, v1, s1, e2, v2, s2): def mm_cpu(e1, v1, s1, e2, v2, s2):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment