Commit 6a6e86c7 authored by rusty1s
parents 2219f43f d33d29b2
 [report]
 exclude_lines =
     pragma: no cover
+    def backward
     cuda
#include <torch/torch.h>
#define CHECK_CUDA(x) AT_ASSERT(x.type().is_cuda(), #x " must be a CUDA tensor")
std::tuple<at::Tensor, at::Tensor> spspmm_cuda(at::Tensor A, at::Tensor B);
std::tuple<at::Tensor, at::Tensor> spspmm(at::Tensor A, at::Tensor B) {
CHECK_CUDA(A);
CHECK_CUDA(B);
return spspmm_cuda(A, B);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("spspmm", &spspmm, "Sparse-Sparse Matrix Multiplication (CUDA)");
}
#include <torch/torch.h>
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be CUDA tensor")
std::tuple<at::Tensor, at::Tensor>
spspmm_cuda(at::Tensor indexA, at::Tensor valueA, at::Tensor indexB,
at::Tensor valueB, int m, int k, int n);
std::tuple<at::Tensor, at::Tensor> spspmm(at::Tensor indexA, at::Tensor valueA,
at::Tensor indexB, at::Tensor valueB,
int m, int k, int n) {
CHECK_CUDA(indexA);
CHECK_CUDA(valueA);
CHECK_CUDA(indexB);
CHECK_CUDA(valueB);
return spspmm_cuda(indexA, valueA, indexB, valueB, m, k, n);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("spspmm", &spspmm, "Sparse-Sparse Matrix Multiplication (CUDA)");
}
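For orientation, here is a minimal sketch of how the compiled extension is driven from Python. The module name spspmm_cuda and the call signature follow setup.py and torch_sparse/spspmm.py below; the concrete tensors and expected outputs mirror test_spspmm and are illustrative only.

# Sketch only: assumes the CUDAExtension 'spspmm_cuda' from setup.py has been
# built and that a CUDA device is available.
import torch
import spspmm_cuda

device = torch.device('cuda')
indexA = torch.tensor([[0, 0, 1, 2, 2], [1, 2, 0, 0, 1]], device=device)
valueA = torch.tensor([1., 2., 3., 4., 5.], device=device)
indexB = torch.tensor([[0, 2], [1, 0]], device=device)
valueB = torch.tensor([2., 4.], device=device)

# C = A @ B for a 3 x 3 and a 3 x 2 sparse matrix in COO form.
indexC, valueC = spspmm_cuda.spspmm(indexA, valueA, indexB, valueB, 3, 3, 2)
print(indexC.tolist())  # [[0, 1, 2], [0, 1, 1]]
print(valueC.tolist())  # [8.0, 6.0, 8.0]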
@@ -27,28 +27,32 @@ static void init_cusparse() {
   }
 }
 
-std::tuple<at::Tensor, at::Tensor> spspmm_cuda(at::Tensor A, at::Tensor B) {
+std::tuple<at::Tensor, at::Tensor>
+spspmm_cuda(at::Tensor indexA, at::Tensor valueA, at::Tensor indexB,
+            at::Tensor valueB, int m, int k, int n) {
   init_cusparse();
 
-  auto m = A.size(0);
-  auto k = A.size(1);
-  auto n = B.size(1);
+  indexA = indexA.contiguous();
+  valueA = valueA.contiguous();
+  indexB = indexB.contiguous();
+  valueB = valueB.contiguous();
 
-  auto nnzA = A._nnz();
-  auto nnzB = B._nnz();
+  auto nnzA = valueA.size(0);
+  auto nnzB = valueB.size(0);
 
-  auto valueA = A._values();
-  auto indexA = A._indices().toType(at::kInt);
-  auto row_ptrA = at::empty(indexA.type(), {m + 1});
+  indexA = indexA.toType(at::kInt);
+  indexB = indexB.toType(at::kInt);
+
+  // Convert A to CSR format.
+  auto row_ptrA = at::empty(m + 1, indexA.type());
   cusparseXcoo2csr(cusparse_handle, indexA[0].data<int>(), nnzA, k,
                    row_ptrA.data<int>(), CUSPARSE_INDEX_BASE_ZERO);
   auto colA = indexA[1];
   cudaMemcpy(row_ptrA.data<int>() + m, &nnzA, sizeof(int),
              cudaMemcpyHostToDevice);
 
-  auto valueB = B._values();
-  auto indexB = B._indices().toType(at::kInt);
-  auto row_ptrB = at::empty(indexB.type(), {k + 1});
+  // Convert B to CSR format.
+  auto row_ptrB = at::empty(k + 1, indexB.type());
   cusparseXcoo2csr(cusparse_handle, indexB[0].data<int>(), nnzB, k,
                    row_ptrB.data<int>(), CUSPARSE_INDEX_BASE_ZERO);
   auto colB = indexB[1];
@@ -61,14 +65,14 @@ std::tuple<at::Tensor, at::Tensor> spspmm_cuda(at::Tensor A, at::Tensor B) {
   cusparseSetMatIndexBase(descr, CUSPARSE_INDEX_BASE_ZERO);
 
   int nnzC;
-  auto row_ptrC = at::empty(indexA.type(), {m + 1});
+  auto row_ptrC = at::empty(m + 1, indexB.type());
   cusparseXcsrgemmNnz(cusparse_handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
                       CUSPARSE_OPERATION_NON_TRANSPOSE, m, n, k, descr, nnzA,
                       row_ptrA.data<int>(), colA.data<int>(), descr, nnzB,
                       row_ptrB.data<int>(), colB.data<int>(), descr,
                       row_ptrC.data<int>(), &nnzC);
 
-  auto colC = at::empty(indexA.type(), {nnzC});
-  auto valueC = at::empty(valueA.type(), {nnzC});
+  auto colC = at::empty(nnzC, indexA.type());
+  auto valueC = at::empty(nnzC, valueA.type());
   CSRGEMM(valueC.type(), cusparse_handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
           CUSPARSE_OPERATION_NON_TRANSPOSE, m, n, k, descr, nnzA,
@@ -77,7 +81,7 @@ std::tuple<at::Tensor, at::Tensor> spspmm_cuda(at::Tensor A, at::Tensor B) {
           colB.data<int>(), descr, valueC.data<scalar_t>(),
           row_ptrC.data<int>(), colC.data<int>());
 
-  auto rowC = at::empty(indexA.type(), {nnzC});
+  auto rowC = at::empty(nnzC, indexA.type());
   cusparseXcsr2coo(cusparse_handle, row_ptrC.data<int>(), nnzC, m,
                    rowC.data<int>(), CUSPARSE_INDEX_BASE_ZERO);
...
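The two cusparseXcoo2csr calls above compress sorted COO row indices into CSR row pointers before the csrgemm call. A CPU-side sketch of that conversion (the helper name and values are illustrative; it assumes the input is coalesced, i.e. row-sorted):

import torch

def coo_to_csr_rowptr(row, m):
    counts = torch.bincount(row, minlength=m)       # nonzeros per row
    row_ptr = torch.zeros(m + 1, dtype=torch.long)
    row_ptr[1:] = torch.cumsum(counts, dim=0)       # row_ptr[m] == nnz
    return row_ptr

row = torch.tensor([0, 0, 1, 2, 2])   # row indices of a 3 x 3 COO matrix
print(coo_to_csr_rowptr(row, 3))      # tensor([0, 2, 3, 5])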
@@ -5,7 +5,7 @@ from torch.utils.cpp_extension import BuildExtension, CUDAExtension
 __version__ = '0.2.0'
 url = 'https://github.com/rusty1s/pytorch_sparse'
-install_requires = ['numpy', 'scipy']
+install_requires = ['scipy']
 setup_requires = ['pytest-runner']
 tests_require = ['pytest', 'pytest-cov']
 ext_modules = []
@@ -13,8 +13,8 @@ cmdclass = {}
 if torch.cuda.is_available():
     ext_modules += [
-        CUDAExtension('matmul_cuda',
-                      ['cuda/matmul.cpp', 'cuda/matmul_cuda.cu'])
+        CUDAExtension('spspmm_cuda',
+                      ['cuda/spspmm.cpp', 'cuda/spspmm_kernel.cu'])
     ]
 cmdclass['build_ext'] = BuildExtension
...
@@ -8,6 +8,6 @@ def test_coalesce():
     index = torch.stack([row, col], dim=0)
     value = torch.tensor([[1, 2], [2, 3], [3, 4], [4, 5], [5, 6], [6, 7]])
-    index, value = coalesce(index, value, torch.Size([4, 2]))
+    index, value = coalesce(index, value, m=3, n=2)
     assert index.tolist() == [[0, 1, 1, 2], [1, 0, 1, 0]]
     assert value.tolist() == [[6, 8], [7, 9], [3, 4], [5, 6]]
from itertools import product
import pytest
import torch
from torch_sparse import sparse_coo_tensor, spspmm, to_value
from .utils import dtypes, devices, tensor
tests = [{
'name': 'Test coalesced input',
'indexA': [[0, 0, 1, 2, 2], [1, 2, 0, 0, 1]],
'valueA': [1, 2, 3, 4, 5],
'sizeA': [3, 3],
'indexB': [[0, 2], [1, 0]],
'valueB': [2, 4],
'sizeB': [3, 2],
}, {
'name': 'Test uncoalesced input',
'indexA': [[2, 2, 1, 0, 2, 0], [1, 1, 0, 2, 0, 1]],
'valueA': [3, 2, 3, 2, 4, 1],
'sizeA': [3, 3],
'indexB': [[2, 0, 2], [0, 1, 0]],
'valueB': [2, 2, 2],
'sizeB': [3, 2],
}]
@pytest.mark.parametrize('test,dtype,device', product(tests, dtypes, devices))
def test_spspmm(test, dtype, device):
indexA = torch.tensor(test['indexA'], device=device)
valueA = tensor(test['valueA'], dtype, device, requires_grad=True)
sizeA = torch.Size(test['sizeA'])
A = sparse_coo_tensor(indexA, valueA, sizeA)
denseA = A.detach().to_dense().requires_grad_()
indexB = torch.tensor(test['indexB'], device=device)
valueB = tensor(test['valueB'], dtype, device, requires_grad=True)
sizeB = torch.Size(test['sizeB'])
B = sparse_coo_tensor(indexB, valueB, sizeB)
denseB = B.detach().to_dense().requires_grad_()
C = spspmm(A, B)
denseC = torch.matmul(denseA, denseB)
assert C.detach().to_dense().tolist() == denseC.tolist()
to_value(C).sum().backward()
denseC.sum().backward()
assert valueA.grad.tolist() == denseA.grad[indexA[0], indexA[1]].tolist()
import torch
from torch_sparse import spmm
def test_spmm():
row = torch.tensor([0, 0, 1, 2, 2])
col = torch.tensor([0, 2, 1, 0, 1])
index = torch.stack([row, col], dim=0)
value = torch.tensor([1, 2, 4, 1, 3])
matrix = torch.tensor([[1, 4], [2, 5], [3, 6]])
out = spmm(index, value, 3, matrix)
assert out.tolist() == [[7, 16], [8, 20], [7, 19]]
from itertools import product
import pytest
import torch
from torch_sparse import spspmm
from .utils import dtypes, devices, tensor
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_spspmm(dtype, device):
indexA = torch.tensor([[0, 0, 1, 2, 2], [1, 2, 0, 0, 1]], device=device)
valueA = tensor([1, 2, 3, 4, 5], dtype, device)
sizeA = torch.Size([3, 3])
indexB = torch.tensor([[0, 2], [1, 0]], device=device)
valueB = tensor([2, 4], dtype, device)
sizeB = torch.Size([3, 2])
indexC, valueC = spspmm(indexA, valueA, indexB, valueB, 3, 3, 2)
assert indexC.tolist() == [[0, 1, 2], [0, 1, 1]]
assert valueC.tolist() == [8, 6, 8]
A = torch.sparse_coo_tensor(indexA, valueA, sizeA, device=device)
A = A.to_dense().requires_grad_()
B = torch.sparse_coo_tensor(indexB, valueB, sizeB, device=device)
B = B.to_dense().requires_grad_()
torch.matmul(A, B).sum().backward()
valueA = valueA.requires_grad_()
valueB = valueB.requires_grad_()
indexC, valueC = spspmm(indexA, valueA, indexB, valueB, 3, 3, 2)
valueC.sum().backward()
assert valueA.grad.tolist() == A.grad[indexA[0], indexA[1]].tolist()
assert valueB.grad.tolist() == B.grad[indexB[0], indexB[1]].tolist()
import torch
from torch_sparse import transpose
def test_transpose():
row = torch.tensor([1, 0, 1, 0, 2, 1])
col = torch.tensor([0, 1, 1, 1, 0, 0])
index = torch.stack([row, col], dim=0)
value = torch.tensor([[1, 2], [2, 3], [3, 4], [4, 5], [5, 6], [6, 7]])
index, value = transpose(index, value, m=3, n=2)
assert index.tolist() == [[0, 0, 1, 1], [1, 2, 0, 1]]
assert value.tolist() == [[7, 9], [5, 6], [6, 8], [3, 4]]
@@ -3,10 +3,9 @@ import torch
 dtypes = [torch.float, torch.double]
 devices = [torch.device('cpu')]
-if torch.cuda.is_available():  # pragma: no cover
+if torch.cuda.is_available():
     devices += [torch.device('cuda:{}'.format(torch.cuda.current_device()))]
 
-def tensor(x, dtype, device, requires_grad=False):
-    return torch.tensor(
-        x, dtype=dtype, device=device, requires_grad=requires_grad)
+def tensor(x, dtype, device):
+    return torch.tensor(x, dtype=dtype, device=device)
 from .coalesce import coalesce
 from .transpose import transpose
-from .matmul import spspmm
+from .spmm import spmm
+from .spspmm import spspmm
 
 __version__ = '0.2.0'
@@ -8,5 +9,6 @@ __all__ = [
     '__version__',
     'coalesce',
     'transpose',
+    'spmm',
     'spspmm',
 ]
@@ -3,6 +3,8 @@ import torch_scatter
 
 def coalesce(index, value, m, n, op='add', fill_value=0):
+    """Row-wise reorders and removes duplicate entries in a sparse matrix."""
     row, col = index
     unique, inv = torch.unique(row * n + col, sorted=True, return_inverse=True)
...
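The body of coalesce is truncated above. The sketch below reconstructs the core idea suggested by the visible lines: linearize the (row, col) pairs, deduplicate them, and reduce duplicate values with a scatter operation. The scatter_add call and the index reconstruction are assumptions for illustration, not the committed implementation (which also honors op and fill_value); on the inputs of test_coalesce this sketch reproduces the asserted output.

# Sketch under assumptions -- not the committed implementation.
import torch
from torch_scatter import scatter_add

def coalesce_sketch(index, value, m, n):
    row, col = index
    unique, inv = torch.unique(row * n + col, sorted=True, return_inverse=True)
    index = torch.stack([unique // n, unique % n], dim=0)   # back to (row, col)
    value = scatter_add(value, inv, dim=0, dim_size=unique.size(0))
    return index, value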
import torch
class SparseCooTensor(torch.autograd.Function):
"""Constructs Sparse matrix with autograd capabilities w.r.t. to value."""
@staticmethod
def forward(ctx, index, value, size):
ctx.size = size
ctx.save_for_backward(index)
return torch.sparse_coo_tensor(index, value, size, device=value.device)
@staticmethod
def backward(ctx, grad_out):
index = ctx.saved_variables[0]
grad_in = None
if ctx.needs_input_grad[1]:
value = grad_out._values()
id1 = index[0] * ctx.size[1] + index[1]
index = grad_out._indices()
id2 = index[0] * ctx.size[1] + index[1]
grad_in = value.new_zeros(id1.max().item() + 1)
grad_in[id2] = value
grad_in = grad_in[id1]
return None, grad_in, None
sparse_coo_tensor = SparseCooTensor.apply
class ToValue(torch.autograd.Function):
"""Extract values of sparse tensors with autograd support."""
@staticmethod
def forward(ctx, A):
ctx.save_for_backward(A)
return A._values()
@staticmethod
def backward(ctx, grad_out):
A = ctx.saved_variables[0]
grad_in = None
if ctx.needs_input_grad[0]:
grad_in = torch.sparse_coo_tensor(
A._indices(), grad_out, A.size(), device=grad_out.device)
return grad_in
to_value = ToValue.apply
from torch_scatter import scatter_add
def spmm(index, value, m, matrix):
"""Matrix product of sparse matrix with dense matrix."""
row, col = index
matrix = matrix if matrix.dim() > 1 else matrix.unsqueeze(-1)
out = matrix[col]
out = out * value.unsqueeze(-1)
out = scatter_add(out, row, dim=0, dim_size=m)
return out
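spmm gathers the rows of the dense matrix selected by col, scales them by the sparse values, and scatter-adds the results into the m output rows. Worked through with the numbers from test_spmm above:

import torch
from torch_sparse import spmm

index = torch.tensor([[0, 0, 1, 2, 2],    # row
                      [0, 2, 1, 0, 1]])   # col
value = torch.tensor([1, 2, 4, 1, 3])
matrix = torch.tensor([[1, 4], [2, 5], [3, 6]])

out = spmm(index, value, 3, matrix)
# Row 0: 1 * [1, 4] + 2 * [3, 6] = [7, 16]
# Row 1: 4 * [2, 5]              = [8, 20]
# Row 2: 1 * [1, 4] + 3 * [2, 5] = [7, 19]
print(out.tolist())  # [[7, 16], [8, 20], [7, 19]]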
 import torch
+from torch import from_numpy
 import scipy.sparse
 
 from torch_sparse import transpose
 
 if torch.cuda.is_available():
-    import matmul_cuda
+    import spspmm_cuda
 
 class SpSpMM(torch.autograd.Function):
-    """Sparse matrix product of two sparse tensors with autograd support."""
+    """Sparse matrix product of two sparse matrices with autograd support."""
 
     @staticmethod
     def forward(ctx, indexA, valueA, indexB, valueB, m, k, n):
@@ -24,14 +25,16 @@ class SpSpMM(torch.autograd.Function):
         grad_valueA = grad_valueB = None
 
         if ctx.needs_input_grad[1]:
-            indexB, valueB = transpose(indexB, valueB, k, n)
-            _, grad_valueA = mm(indexC, grad_valueC, indexB, valueB, m, n, k)
-            # TODO: Filter values.
+            indexB_T, valueB_T = transpose(indexB, valueB, k, n)
+            grad_indexA, grad_valueA = mm(indexC, grad_valueC, indexB_T,
+                                          valueB_T, m, n, k)
+            grad_valueA = lift(grad_indexA, grad_valueA, indexA, k)
 
-        if ctx.needs_input_grad[4]:
-            indexA, valueA = transpose(indexA, valueA, m, k)
-            _, grad_valueB = mm(indexA, valueA, indexC, grad_valueC, k, m, n)
-            # TODO: Filter values.
+        if ctx.needs_input_grad[3]:
+            indexA_T, valueA_T = transpose(indexA, valueA, m, k)
+            grad_indexB, grad_valueB = mm(indexA_T, valueA_T, indexC,
+                                          grad_valueC, k, m, n)
+            grad_valueB = lift(grad_indexB, grad_valueB, indexB, n)
 
         return None, grad_valueA, None, grad_valueB, None, None, None
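The rewritten backward applies the dense identities dL/dA = (dL/dC) B^T and dL/dB = A^T (dL/dC), then uses lift (added further down) to restrict each result to the sparsity pattern of A and B respectively, replacing the previous TODOs. A small dense sanity check of those identities, illustrative and not part of the commit:

import torch

A = torch.randn(3, 3, requires_grad=True)   # m x k
B = torch.randn(3, 2, requires_grad=True)   # k x n
(A @ B).sum().backward()

grad_C = torch.ones(3, 2)                       # dL/dC for L = C.sum()
assert torch.allclose(A.grad, grad_C @ B.t())   # dL/dA = dL/dC @ B^T
assert torch.allclose(B.grad, A.t() @ grad_C)   # dL/dB = A^T @ dL/dC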
@@ -43,7 +46,7 @@ def mm(indexA, valueA, indexB, valueB, m, k, n):
     assert valueA.dtype == valueB.dtype
 
     if indexA.is_cuda:
-        return matmul_cuda.spspmm(indexA, valueA, indexB, valueB, m, k, n)
+        return spspmm_cuda.spspmm(indexA, valueA, indexB, valueB, m, k, n)
 
     A = to_scipy(indexA, valueA, m, k)
     B = to_scipy(indexB, valueB, k, n)
@@ -58,6 +61,17 @@ def to_scipy(index, value, m, n):
 
 def from_scipy(A):
-    row, col, value = A.row, A.col, A.data
+    row, col, value = from_numpy(A.row), from_numpy(A.col), from_numpy(A.data)
     index = torch.stack([row, col], dim=0).to(torch.long)
     return index, value
 
+
+def lift(indexA, valueA, indexB, n):
+    indexA = indexA[0] * n + indexA[1]
+    indexB = indexB[0] * n + indexB[1]
+    value = valueA.new_zeros(indexB.max().item() + 1)
+    value[indexA] = valueA
+    value = value[indexB]
+    return value
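lift projects the entries of (indexA, valueA) onto the sparsity pattern indexB: both index sets are linearized, the known values are scattered into a temporary buffer, and the buffer is gathered at the target pattern, so every target entry receives its matching value or zero. A tiny example with made-up numbers (the import path assumes the module layout from __init__.py):

import torch
from torch_sparse.spspmm import lift

indexA = torch.tensor([[0, 1], [1, 0]])        # source entries (0, 1) and (1, 0)
valueA = torch.tensor([10., 20.])
indexB = torch.tensor([[0, 1, 1], [1, 0, 1]])  # target pattern, n = 2 columns

print(lift(indexA, valueA, indexB, 2).tolist())  # [10.0, 20.0, 0.0]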
@@ -3,9 +3,11 @@ from torch_sparse import coalesce
 
 def transpose(index, value, m, n):
+    """Transpose of a sparse matrix."""
     row, col = index
     index = torch.stack([col, row], dim=0)
-    index, value = coalesce(index, value, m, n)
+    index, value = coalesce(index, value, n, m)
     return index, value