"docs/vscode:/vscode.git/clone" did not exist on "d7117b95ab120230bb7dc6e69c7c4c800397fcbf"
Commit 238efb11 authored by rusty1s's avatar rusty1s
Browse files

faster spspmm backward + cleanup

parent 5586d7ae
#include <torch/extension.h>
// Computes the degree of each node, i.e. how often each node index occurs in
// `row`. Returns a tensor of length `num_nodes` with `row`'s dtype/device.
at::Tensor degree(at::Tensor row, int64_t num_nodes) {
  auto count = at::zeros(num_nodes, row.options());
  return count.scatter_add_(0, row, at::ones(row.size(0), row.options()));
}
// Converts a COO (row, col) index pair into CSR form: returns the row-pointer
// tensor (length num_nodes + 1, leading zero prepended) together with `col`.
std::tuple<at::Tensor, at::Tensor> to_csr(at::Tensor row, at::Tensor col,
                                          int64_t num_nodes) {
  // Input is assumed to be coalesced already (sorted by row).
  auto rowptr = degree(row, num_nodes).cumsum(0);
  rowptr = at::cat({at::zeros(1, rowptr.options()), rowptr}, 0);
  return std::make_tuple(rowptr, col);
}
// CPU backward of sparse-sparse matmul. For each entry e = (i, j) in the
// sparsity pattern `index`, computes the dot product of row i of sparse
// matrix A (indexA/valueA) with row j of sparse matrix B (indexB/valueB),
// i.e. value[e] = sum over matching columns c of A[i, c] * B[j, c].
// rowA_max / rowB_max are the number of rows of A and B respectively.
at::Tensor spspmm_bw(at::Tensor index, at::Tensor indexA, at::Tensor valueA,
at::Tensor indexB, at::Tensor valueB, size_t rowA_max,
size_t rowB_max) {
int64_t *index_data = index.data<int64_t>();
// One output value per column of `index` (per nonzero of the pattern).
auto value = at::zeros(index.size(1), valueA.options());
// Convert both operands to CSR for fast per-row slicing.
at::Tensor rowA, colA;
std::tie(rowA, colA) = to_csr(indexA[0], indexA[1], rowA_max);
int64_t *rowA_data = rowA.data<int64_t>();
int64_t *colA_data = colA.data<int64_t>();
at::Tensor rowB, colB;
std::tie(rowB, colB) = to_csr(indexB[0], indexB[1], rowB_max);
int64_t *rowB_data = rowB.data<int64_t>();
int64_t *colB_data = colB.data<int64_t>();
AT_DISPATCH_FLOATING_TYPES(valueA.type(), "spspmm_bw", [&] {
scalar_t *value_data = value.data<scalar_t>();
scalar_t *valueA_data = valueA.data<scalar_t>();
scalar_t *valueB_data = valueB.data<scalar_t>();
for (int64_t e = 0; e < value.size(0); e++) {
// index is stored row-major as [2, numel]: row i at offset e, col j at
// offset numel + e.
int64_t i = index_data[e], j = index_data[value.size(0) + e];
// Walk the nonzeros of row i of A ...
for (ptrdiff_t dA = rowA_data[i]; dA < rowA_data[i + 1]; dA++) {
int64_t cA = colA_data[dA];
// ... and intersect with the nonzeros of row j of B.
for (ptrdiff_t dB = rowB_data[j]; dB < rowB_data[j + 1]; dB++) {
int64_t cB = colB_data[dB];
if (cA == cB) {
value_data[e] += valueA_data[dA] * valueB_data[dB];
}
// Early exit: assumes column indices within a row are sorted
// ascending (coalesced input), so no later cB can match cA.
if (cB >= cA) {
break;
}
}
}
}
});
return value;
}
// Python bindings: exposes the CPU backward kernel as `spspmm_bw`.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("spspmm_bw", &spspmm_bw,
"Sparse-Sparse Matrix Multiplication Backward (CPU)");
}
......@@ -4,11 +4,14 @@
std::tuple<at::Tensor, at::Tensor>
spspmm_cuda(at::Tensor indexA, at::Tensor valueA, at::Tensor indexB,
at::Tensor valueB, int m, int k, int n);
at::Tensor valueB, size_t m, size_t k, size_t n);
at::Tensor spspmm_bw_cuda(at::Tensor index, at::Tensor indexA,
at::Tensor valueA, at::Tensor indexB,
at::Tensor valueB, size_t rowA_max, size_t rowB_max);
std::tuple<at::Tensor, at::Tensor> spspmm(at::Tensor indexA, at::Tensor valueA,
at::Tensor indexB, at::Tensor valueB,
int m, int k, int n) {
size_t m, size_t k, size_t n) {
CHECK_CUDA(indexA);
CHECK_CUDA(valueA);
CHECK_CUDA(indexB);
......@@ -16,6 +19,20 @@ std::tuple<at::Tensor, at::Tensor> spspmm(at::Tensor indexA, at::Tensor valueA,
return spspmm_cuda(indexA, valueA, indexB, valueB, m, k, n);
}
// Python-facing wrapper for the CUDA backward kernel: validates that every
// tensor argument lives on the GPU, then dispatches to spspmm_bw_cuda.
at::Tensor spspmm_bw(at::Tensor index, at::Tensor indexA, at::Tensor valueA,
at::Tensor indexB, at::Tensor valueB, size_t rowA_max,
size_t rowB_max) {
CHECK_CUDA(index);
CHECK_CUDA(indexA);
CHECK_CUDA(valueA);
CHECK_CUDA(indexB);
CHECK_CUDA(valueB);
return spspmm_bw_cuda(index, indexA, valueA, indexB, valueB, rowA_max,
rowB_max);
}
// Python bindings: exposes the CUDA forward (`spspmm`) and backward
// (`spspmm_bw`) entry points.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("spspmm", &spspmm, "Sparse-Sparse Matrix Multiplication (CUDA)");
m.def("spspmm_bw", &spspmm_bw,
"Sparse-Sparse Matrix Multiplication Backward (CUDA)");
}
......@@ -2,6 +2,9 @@
#include <cusparse.h>
#define THREADS 1024
#define BLOCKS(N) (N + THREADS - 1) / THREADS
#define CSRGEMM(TYPE, ...) \
[&] { \
const at::Type &the_type = TYPE; \
......@@ -29,7 +32,7 @@ static void init_cusparse() {
std::tuple<at::Tensor, at::Tensor>
spspmm_cuda(at::Tensor indexA, at::Tensor valueA, at::Tensor indexB,
at::Tensor valueB, int m, int k, int n) {
at::Tensor valueB, size_t m, size_t k, size_t n) {
cudaSetDevice(indexA.get_device());
init_cusparse();
......@@ -90,3 +93,69 @@ spspmm_cuda(at::Tensor indexA, at::Tensor valueA, at::Tensor indexB,
return std::make_tuple(indexC, valueC);
}
// Per-node degree: counts how many times each index in [0, num_nodes)
// appears in `row`.
at::Tensor degree(at::Tensor row, int64_t num_nodes) {
  auto out = at::zeros(num_nodes, row.options());
  auto ones = at::ones(row.size(0), row.options());
  out.scatter_add_(0, row, ones);
  return out;
}
// COO -> CSR conversion. Builds the row-pointer array from per-row counts
// (cumulative sum with a prepended zero) and passes `col` through unchanged.
std::tuple<at::Tensor, at::Tensor> to_csr(at::Tensor row, at::Tensor col,
                                          int64_t num_nodes) {
  // Caller guarantees the COO input is already coalesced.
  auto ptr = degree(row, num_nodes).cumsum(0);
  ptr = at::cat({at::zeros(1, ptr.options()), ptr}, 0);
  return std::make_tuple(ptr, col);
}
// CUDA backward kernel: one pattern entry e = (i, j) per iteration of a
// grid-stride loop. Accumulates into value[e] the products
// valueA[dA] * valueB[dB] for every column index shared between row i of A
// (CSR: rowA/colA/valueA) and row j of B (CSR: rowB/colB/valueB).
template <typename scalar_t>
__global__ void spspmm_bw_kernel(
const int64_t *__restrict__ index, scalar_t *__restrict__ value,
const int64_t *__restrict__ rowA, const int64_t *__restrict__ colA,
const scalar_t *__restrict__ valueA, const int64_t *__restrict__ rowB,
const int64_t *__restrict__ colB, const scalar_t *__restrict__ valueB,
const size_t numel) {
const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
const size_t stride = blockDim.x * gridDim.x;
// Grid-stride loop over the `numel` pattern entries.
for (ptrdiff_t e = idx; e < numel; e += stride) {
// `index` is laid out as [2, numel]: rows first, then columns.
int64_t i = index[e], j = index[numel + e];
for (ptrdiff_t dA = rowA[i]; dA < rowA[i + 1]; dA++) {
int64_t cA = colA[dA];
for (ptrdiff_t dB = rowB[j]; dB < rowB[j + 1]; dB++) {
int64_t cB = colB[dB];
if (cA == cB) {
value[e] += valueA[dA] * valueB[dB];
}
// Columns within a row are assumed sorted (coalesced input), so
// once cB passes cA there can be no further match.
if (cB >= cA) {
break;
}
}
}
}
}
// Host-side launcher for the CUDA backward kernel. Converts both sparse
// operands to CSR, allocates the output (one value per pattern entry), and
// launches spspmm_bw_kernel over all floating-point dtypes of valueA.
at::Tensor spspmm_bw_cuda(at::Tensor index, at::Tensor indexA,
at::Tensor valueA, at::Tensor indexB,
at::Tensor valueB, size_t rowA_max, size_t rowB_max) {
// Run on the device that owns the input tensors.
cudaSetDevice(index.get_device());
auto value = at::zeros(index.size(1), valueA.options());
at::Tensor rowA, colA;
std::tie(rowA, colA) = to_csr(indexA[0], indexA[1], rowA_max);
at::Tensor rowB, colB;
std::tie(rowB, colB) = to_csr(indexB[0], indexB[1], rowB_max);
AT_DISPATCH_FLOATING_TYPES(valueA.type(), "spspmm_bw", [&] {
// BLOCKS/THREADS are defined at the top of this file.
spspmm_bw_kernel<scalar_t><<<BLOCKS(value.numel()), THREADS>>>(
index.data<int64_t>(), value.data<scalar_t>(), rowA.data<int64_t>(),
colA.data<int64_t>(), valueA.data<scalar_t>(), rowB.data<int64_t>(),
colB.data<int64_t>(), valueB.data<scalar_t>(), value.numel());
});
return value;
}
import platform
from setuptools import setup, find_packages
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME
import torch
from torch.utils.cpp_extension import CppExtension, CUDAExtension, CUDA_HOME
__version__ = '0.2.4'
__version__ = '0.3.0'
url = 'https://github.com/rusty1s/pytorch_sparse'
install_requires = ['scipy']
setup_requires = ['pytest-runner']
tests_require = ['pytest', 'pytest-cov']
ext_modules = []
ext_modules = [CppExtension('torch_sparse.spspmm_cpu', ['cpu/spspmm.cpp'])]
cmdclass = {}
if CUDA_HOME is not None:
......@@ -25,7 +26,7 @@ if CUDA_HOME is not None:
CUDAExtension('torch_sparse.unique_cuda',
['cuda/unique.cpp', 'cuda/unique_kernel.cu']),
]
cmdclass['build_ext'] = BuildExtension
cmdclass['build_ext'] = torch.utils.cpp_extension.BuildExtension
setup(
name='torch_sparse',
......
from itertools import product
import pytest
import torch
from torch_sparse import transpose
from torch_sparse import transpose, transpose_matrix
from .utils import dtypes, devices, tensor
def test_transpose():
......@@ -11,3 +16,15 @@ def test_transpose():
index, value = transpose(index, value, m=3, n=2)
assert index.tolist() == [[0, 0, 1, 1], [1, 2, 0, 1]]
assert value.tolist() == [[7, 9], [5, 6], [6, 8], [3, 4]]
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_transpose_matrix(dtype, device):
    """Transposing a 3x2 sparse matrix with flat (1-D) values."""
    index = torch.stack([
        torch.tensor([1, 0, 1, 2], device=device),  # row
        torch.tensor([0, 1, 1, 0], device=device),  # col
    ], dim=0)
    value = tensor([1, 2, 3, 4], dtype, device)

    index, value = transpose_matrix(index, value, m=3, n=2)

    expected_index = [[0, 0, 1, 1], [1, 2, 0, 1]]
    assert index.tolist() == expected_index
    assert value.tolist() == [1, 4, 2, 3]
from .convert import to_scipy, from_scipy
from .coalesce import coalesce
from .transpose import transpose
from .transpose import transpose, transpose_matrix
from .eye import eye
from .spmm import spmm
from .spspmm import spspmm
__version__ = '0.2.4'
__version__ = '0.3.0'
__all__ = [
'__version__',
'to_scipy',
'from_scipy',
'coalesce',
'transpose',
'transpose_matrix',
'eye',
'spmm',
'spspmm',
......
import numpy as np
import scipy.sparse
import torch
from torch import from_numpy
def to_scipy(index, value, m, n):
    """Convert a CPU COO ``(index, value)`` pair into a ``scipy`` COO matrix
    of shape ``(m, n)``."""
    assert not index.is_cuda and not value.is_cuda
    row, col = index.detach()
    return scipy.sparse.coo_matrix((value.detach(), (row, col)), (m, n))
def from_scipy(A):
    """Convert a ``scipy`` sparse matrix into a COO ``(index, value)`` tensor
    pair, with a ``[2, nnz]`` long index tensor."""
    A = A.tocoo()
    row = from_numpy(A.row.astype(np.int64))
    col = from_numpy(A.col.astype(np.int64))
    value = from_numpy(A.data)
    return torch.stack([row, col], dim=0), value
import torch
from torch import from_numpy
import numpy as np
import scipy.sparse
from torch_sparse import transpose
from torch_sparse import transpose_matrix, to_scipy, from_scipy
import torch_sparse.spspmm_cpu
if torch.cuda.is_available():
import torch_sparse.spspmm_cuda
......@@ -38,22 +37,39 @@ class SpSpMM(torch.autograd.Function):
@staticmethod
def backward(ctx, grad_indexC, grad_valueC):
m, k, n = ctx.m, ctx.k, ctx.n
m, k = ctx.m, ctx.k
n = ctx.n
indexA, valueA, indexB, valueB, indexC = ctx.saved_tensors
grad_valueA = grad_valueB = None
if ctx.needs_input_grad[1]:
indexB_T, valueB_T = transpose(indexB, valueB, k, n)
grad_indexA, grad_valueA = mm(indexC, grad_valueC, indexB_T,
valueB_T, m, n, k)
grad_valueA = lift(grad_indexA, grad_valueA, indexA, k)
if ctx.needs_input_grad[3]:
indexA_T, valueA_T = transpose(indexA, valueA, m, k)
grad_indexB, grad_valueB = mm(indexA_T, valueA_T, indexC,
grad_valueC, k, m, n)
grad_valueB = lift(grad_indexB, grad_valueB, indexB, n)
if not grad_valueC.is_cuda:
if ctx.needs_input_grad[1] or ctx.needs_input_grad[1]:
grad_valueC = grad_valueC.clone()
if ctx.needs_input_grad[1]:
grad_valueA = torch_sparse.spspmm_cpu.spspmm_bw(
indexA, indexC.detach(), grad_valueC, indexB.detach(),
valueB, m, k)
if ctx.needs_input_grad[3]:
indexA, valueA = transpose_matrix(indexA, valueA, m, k)
indexC, grad_valueC = transpose_matrix(indexC, grad_valueC, m,
n)
grad_valueB = torch_sparse.spspmm_cpu.spspmm_bw(
indexB, indexA.detach(), valueA, indexC.detach(),
grad_valueC, k, n)
else:
if ctx.needs_input_grad[1]:
grad_valueA = torch_sparse.spspmm_cuda.spspmm_bw(
indexA, indexC.detach(), grad_valueC.clone(),
indexB.detach(), valueB, m, k)
if ctx.needs_input_grad[3]:
indexA_T, valueA_T = transpose_matrix(indexA, valueA, m, k)
grad_indexB, grad_valueB = mm(indexA_T, valueA_T, indexC,
grad_valueC, k, m, n)
grad_valueB = lift(grad_indexB, grad_valueB, indexB, n)
return None, grad_valueA, None, grad_valueB, None, None, None
......@@ -67,23 +83,11 @@ def mm(indexA, valueA, indexB, valueB, m, k, n):
A = to_scipy(indexA, valueA, m, k)
B = to_scipy(indexB, valueB, k, n)
indexC, valueC = from_scipy(A.tocsr().dot(B.tocsr()).tocoo())
C = A.dot(B).tocoo().tocsr().tocoo() # Force coalesce.
indexC, valueC = from_scipy(C)
return indexC, valueC
def to_scipy(index, value, m, n):
(row, col), data = index.detach(), value.detach()
return scipy.sparse.coo_matrix((data, (row, col)), (m, n))
def from_scipy(A):
row, col, value = A.row.astype(np.int64), A.col.astype(np.int64), A.data
row, col, value = from_numpy(row), from_numpy(col), from_numpy(value)
index = torch.stack([row, col], dim=0)
return index, value
def lift(indexA, valueA, indexB, n): # pragma: no cover
idxA = indexA[0] * n + indexA[1]
idxB = indexB[0] * n + indexB[1]
......
import torch
from torch_sparse import coalesce
from torch_sparse import to_scipy, from_scipy, coalesce
def transpose(index, value, m, n):
"""Transposes dimensions 0 and 1 of a sparse matrix.
"""Transposes dimensions 0 and 1 of a sparse tensor.
Args:
index (:class:`LongTensor`): The index tensor of sparse matrix.
......@@ -16,7 +16,29 @@ def transpose(index, value, m, n):
row, col = index
index = torch.stack([col, row], dim=0)
index, value = coalesce(index, value, n, m)
return index, value
def transpose_matrix(index, value, m, n):
    """Transposes dimensions 0 and 1 of a sparse matrix whose :args:`value`
    is one-dimensional.

    Args:
        index (:class:`LongTensor`): The index tensor of sparse matrix.
        value (:class:`Tensor`): The value tensor of sparse matrix.
        m (int): The first dimension of sparse matrix.
        n (int): The second dimension of sparse matrix.

    :rtype: (:class:`LongTensor`, :class:`Tensor`)
    """
    assert value.dim() == 1
    if index.is_cuda:
        return transpose(index, value, m, n)
    # CPU path: round-trip through scipy. Converting to CSC sorts the entries
    # by column — which is exactly row order for the transposed matrix — so
    # swapping the unpacked (row, col) yields a coalesced transpose.
    (col, row), value = from_scipy(to_scipy(index, value, m, n).tocsc())
    return torch.stack([row, col], dim=0), value
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment