Commit df5f7063 authored by rusty1s

spmm backward implementation

parent b3187f23
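This commit adds the backward pass for sparse-dense matrix multiplication. For out = A @ mat with A stored in CSR form, out[m, k] = sum_e value[e] * mat[col[e], k] over the nonzeros e of row m, so the value gradient is grad_value[e] = dot(grad_out[m, :], mat[col[e], :]), divided by the row degree for reduce='mean'. Below is a minimal dense-tensor sketch of that formula, ignoring the batch dimension; the helper name is illustrative and not part of the extension.

import torch

def spmm_val_bw_reference(rowptr, col, mat, grad_out, reduce='sum'):
    # One gradient entry per stored nonzero: grad_value[e] = dot(grad_out[m], mat[col[e]])
    # for each nonzero e of row m, scaled by 1 / row degree when reduce == 'mean'.
    grad_value = torch.zeros(col.numel(), dtype=grad_out.dtype)
    for m in range(rowptr.numel() - 1):
        row_start, row_end = int(rowptr[m]), int(rowptr[m + 1])
        for e in range(row_start, row_end):
            val = (grad_out[m] * mat[col[e]]).sum()
            if reduce == 'mean':
                val = val / (row_end - row_start)
            grad_value[e] += val
    return grad_value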
@@ -101,7 +101,6 @@ spmm(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> value_opt,
if (value_opt.has_value())
AT_ASSERTM(value_opt.value().dim() == 1);
AT_ASSERTM(mat.dim() >= 2, "Input mismatch");
AT_ASSERTM(rowptr.numel() - 1 == mat.size(-2), "Input mismatch");
auto sizes = mat.sizes().vec();
sizes[mat.dim() - 2] = rowptr.numel() - 1;
@@ -110,26 +109,26 @@ spmm(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> value_opt,
at::optional<at::Tensor> arg_out = at::nullopt;
int64_t *arg_out_data = nullptr;
if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
arg_out = at::full_like(out, mat.size(-2), rowptr.options());
arg_out = at::full_like(out, -1, rowptr.options());
arg_out_data = arg_out.value().DATA_PTR<int64_t>();
}
auto rowptr_data = rowptr.DATA_PTR<int64_t>();
auto col_data = col.DATA_PTR<int64_t>();
int N = rowptr.numel() - 1;
int M = mat.size(-2);
int K = mat.size(-1);
int B = mat.numel() / (M * K);
auto N = rowptr.numel() - 1;
auto M = mat.size(-2);
auto K = mat.size(-1);
auto B = mat.numel() / (M * K);
AT_DISPATCH_ALL_TYPES(mat.scalar_type(), "spmm", [&] {
scalar_t *value_data = nullptr;
auto mat_data = out.DATA_PTR<scalar_t>();
auto out_data = mat.DATA_PTR<scalar_t>();
auto mat_data = mat.DATA_PTR<scalar_t>();
auto out_data = out.DATA_PTR<scalar_t>();
scalar_t val;
std::vector<scalar_t> vals(K);
int64_t row_start, row_end, col_idx;
int64_t row_start, row_end, c;
std::vector<int64_t> args(K);
AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
@@ -147,18 +146,17 @@ spmm(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> value_opt,
int offset = b * M * K;
for (int e = row_start; e < row_end; e++) {
col_idx = col_data[e];
c = col_data[e];
if (HAS_VAL)
val = value_data[e];
for (int k = 0; k < K; k++) {
if (HAS_VAL)
Reducer<scalar_t, REDUCE>::update(
&vals[k], val * mat_data[offset + col_idx * K + k],
&args[k], e);
&vals[k], val * mat_data[offset + c * K + k], &args[k],
e);
else
Reducer<scalar_t, REDUCE>::update(
&vals[k], mat_data[offset + col_idx * K + k], &args[k],
e);
&vals[k], mat_data[offset + c * K + k], &args[k], e);
}
}
offset = b * N * K + n * K;
@@ -175,6 +173,56 @@ spmm(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> value_opt,
return std::make_tuple(out, arg_out);
}
at::Tensor spmm_val_bw(at::Tensor rowptr, at::Tensor col, at::Tensor mat,
at::Tensor grad, std::string reduce) {
CHECK_CPU(rowptr);
CHECK_CPU(col);
CHECK_CPU(mat);
CHECK_CPU(grad);
mat = mat.contiguous();
auto M = rowptr.numel() - 1;
auto N = mat.size(-2);
auto K = mat.size(-1);
auto B = mat.numel() / (N * K);
auto out = at::zeros(col.sizes(), grad.options());
auto rowptr_data = rowptr.DATA_PTR<int64_t>();
auto col_data = col.DATA_PTR<int64_t>();
AT_DISPATCH_ALL_TYPES(mat.scalar_type(), "spmm_val_bw", [&] {
auto mat_data = mat.DATA_PTR<scalar_t>();
auto grad_data = grad.DATA_PTR<scalar_t>();
auto out_data = out.DATA_PTR<scalar_t>();
scalar_t val;
int64_t row_start, row_end, c;
AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
for (int b = 0; b < B; b++) {
for (int m = 0; m < M; m++) {
row_start = rowptr_data[m], row_end = rowptr_data[m + 1];
for (int e = row_start; e < row_end; e++) {
c = col_data[e], val = (scalar_t)0;
for (int k = 0; k < K; k++) {
val += mat_data[b * N * K + c * K + k] *
grad_data[b * M * K + m * K + k];
}
if (REDUCE == MEAN)
val = val / (scalar_t)(row_end - row_start);
out_data[e] += val;
}
}
}
});
});
return out;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("spmm", &spmm, "Sparse-Dense Matrix Multiplication (CPU)");
m.def("spmm_val_bw", &spmm_val_bw,
"Sparse-Dense Matrix Multiplication Value Backward (CPU)");
}
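The new bindings can be called directly from Python; the autograd wrapper further below imports the CPU extension as torch_sparse.spmm_cpu. A small sketch with illustrative inputs (not part of the test suite):

import torch
from torch_sparse import spmm_cpu

# A 2 x 3 CSR matrix with 3 nonzeros, multiplied with a dense (3, 4) matrix.
rowptr = torch.tensor([0, 2, 3])
col = torch.tensor([0, 2, 1])
value = torch.tensor([1., 2., 3.])
mat = torch.randn(3, 4)

out, arg_out = spmm_cpu.spmm(rowptr, col, value, mat, 'sum')  # arg_out is None for 'sum'
grad_out = torch.randn_like(out)
grad_value = spmm_cpu.spmm_val_bw(rowptr, col, mat, grad_out, 'sum')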
from itertools import product
import pytest
import torch
from torch.autograd import gradcheck
from torch_sparse.matmul import matmul
from torch_sparse.tensor import SparseTensor
import torch_scatter
from .utils import tensor, devices, dtypes
devices = ['cpu']
dtypes = [torch.float]
reductions = ['sum', 'mean', 'min', 'max']
# grad_reductions = ['sum', 'mean']
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_spmm_forward(dtype, device):
src_dense = torch.randn((5, 4), dtype=dtype, device=device)
src = SparseTensor.from_dense(src_dense)
src.requires_grad_()
src_dense = src_dense.clone().requires_grad_()
other = torch.randn((4, 8), dtype=dtype, device=device)
other.requires_grad_()
out1 = matmul(src, other)
grad_out = torch.randn_like(out1)
out1.backward(grad_out)
other.grad = None
out2 = torch.matmul(src_dense, other)
out2.backward(grad_out)
# assert torch.allclose(out1, out2)
# assert torch.allclose(src.storage.value.grad.view(5, 4), src_dense.grad)
@pytest.mark.parametrize('dtype,device,reduce',
product(dtypes, devices, reductions))
def test_spmm(dtype, device, reduce):
src = torch.ones((5, 4), dtype=dtype, device=device)
src[2] = 0
src = SparseTensor.from_dense(src).requires_grad_()
src.set_value_(None)
other = torch.randn((2, 4, 2), dtype=dtype, device=device,
requires_grad=True)
(row, col), value = src.coo()
out1 = other.index_select(-2, col) # * value.unsqueeze(-1)
func = 'add' if reduce == 'sum' else reduce
out1 = getattr(torch_scatter, f'scatter_{func}')(out1, row, dim=-2)
out1 = out1[0] if isinstance(out1, tuple) else out1
grad_out = torch.randn_like(out1)
out1.backward(grad_out)
# grad_value1 = value.grad
# value.grad = None
grad_other1 = other.grad
other.grad = None
print(reduce)
out2 = matmul(src, other, reduce)
out2 = out2[0] if isinstance(out2, tuple) else out2
out2.backward(grad_out)
# grad_value2 = value.grad
# value.grad = None
grad_other2 = other.grad
other.grad = None
# assert torch.allclose(out1, out2)
# assert torch.allclose(grad_value1, grad_value2)
assert torch.allclose(grad_other1, grad_other2)
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_spmm_backward(dtype, device):
src_dense = torch.randn((5, 4), dtype=torch.double, device=device)
src = SparseTensor.from_dense(src_dense)
src.requires_grad_()
other = torch.randn((4, 8), dtype=torch.double, device=device)
other.requires_grad_()
# assert gradcheck(matmul, (src, other, "sum"))
import torch
from torch_sparse import spmm_cpu
from torch_scatter import scatter_add
try:
from torch_sparse import spmm_cuda
except ImportError:
spmm_cuda = None
def spmm(is_cuda):
return spmm_cuda if is_cuda else spmm_cpu
class SPMM(torch.autograd.Function):
@staticmethod
def forward(ctx, index, rowcount, rowptr, colptr, csr2csc, value, mat,
reduce):
out, arg_out = spmm(mat.is_cuda).spmm(rowptr, index[1], value, mat,
reduce)
ctx.reduce = reduce
ctx.save_for_backward(index, rowcount, rowptr, colptr, csr2csc, value,
mat, arg_out)
if reduce == 'min' or reduce == 'max':
return out, arg_out
else:
return out
@staticmethod
def backward(ctx, grad_out, *args):
data = ctx.saved_tensors
index, rowcount, rowptr, colptr, csr2csc, value, mat, arg_out = data
grad_value = None
if ctx.needs_input_grad[5]:
if ctx.reduce in ['sum', 'add']:
grad_value = spmm(grad_out.is_cuda).spmm_val_bw(
rowptr, index[1], mat, grad_out, ctx.reduce)
if ctx.reduce == 'mean':
grad_value = spmm(grad_out.is_cuda).spmm_val_bw(
rowptr, index[1], mat, grad_out, ctx.reduce)
elif ctx.reduce in ['min', 'max']:
col = index[1][arg_out.flatten()].view_as(arg_out)
out = mat.gather(-2, col).mul_(grad_out)
out.masked_fill_(arg_out == -1, 0)
col = col.add_(rowptr[:-1].view(-1, 1))
grad_value = scatter_add(out.flatten(), col.flatten(), dim=0,
dim_size=value.numel())
grad_mat = None
if ctx.needs_input_grad[6]:
if ctx.reduce in ['sum', 'add']:
row = index[0][csr2csc]
value = value[csr2csc] if value is not None else value
grad_mat, _ = spmm(grad_out.is_cuda).spmm(
colptr, row, value, grad_out, 'sum')
elif ctx.reduce == 'mean':
count = rowcount[index[0]].to(mat.dtype).clamp_(min=1)
value = count.pow_(-1) if value is None else value / count
row = index[0][csr2csc]
value = value[csr2csc] if value is not None else value
grad_mat, _ = spmm(grad_out.is_cuda).spmm(
colptr, row, value, grad_out, 'sum')
elif ctx.reduce in ['min', 'max']:
if value is not None:
value = value[arg_out.flatten()].view_as(arg_out)
value = value.mul_(grad_out)
else:
value = grad_out
value.masked_fill_(arg_out == -1, 0)
col = index[1][arg_out.flatten()].view_as(arg_out)
grad_mat = scatter_add(value, col, dim=-2,
dim_size=mat.size(-2))
return None, None, None, None, None, grad_value, grad_mat, None
def matmul(src, other, reduce='sum'):
assert src.dim() == 2 and src.size(-1) == other.size(-2)
def matmul(src, other, reduce='add'):
if torch.is_tensor(other):
pass
if isinstance(other, src.__class__):
if reduce != 'add':
raise NotImplementedError(
(f'Reduce argument "{reduce}" not implemented for sparse-'
f'sparse matrix multiplication'))
assert reduce in ['sum', 'add', 'mean', 'min', 'max']
(index, value), rowptr = src.coo(), src.storage.rowptr
csr2csc = colptr = None
if other.requires_grad and reduce in ['sum', 'add', 'mean']:
csr2csc, colptr = src.storage.csr2csc, src.storage.colptr
rowcount = None
if other.requires_grad and reduce in ['mean']:
rowcount = src.storage.rowcount
return SPMM.apply(index, rowcount, rowptr, colptr, csr2csc, value,
other, reduce)
elif isinstance(other, src.__class__):
assert reduce in ['sum', 'add']
raise ValueError
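A short usage sketch for the autograd-enabled wrapper above, assuming the SparseTensor API from this repository (shapes are illustrative):

import torch
from torch_sparse.tensor import SparseTensor
from torch_sparse.matmul import matmul

src = SparseTensor.from_dense(torch.randn(5, 4)).requires_grad_()
other = torch.randn(4, 8, requires_grad=True)

out = matmul(src, other, reduce='sum')  # dense (5, 8) result
out.sum().backward()                    # gradients reach the sparse values and `other`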
@@ -3,15 +3,14 @@ import torch_scatter
from torch_scatter import segment_csr
def reduction(src, dim=None, reduce='add', deterministic=False):
assert reduce in ['add', 'mean', 'min', 'max']
def reduction(src, dim=None, reduce='sum', deterministic=False):
assert reduce in ['sum', 'mean', 'min', 'max']
if dim is None and src.has_value():
func = getattr(torch, 'sum' if reduce == 'add' else reduce)
return func(src.storage.value)
return getattr(torch, reduce)(src.storage.value)
if dim is None and not src.has_value():
value = src.nnz() if reduce == 'add' else 1
value = src.nnz() if reduce == 'sum' else 1
return torch.tensor(value, device=src.device)
dims = [dim] if isinstance(dim, int) else dim
@@ -24,25 +23,22 @@ def reduction(src, dim=None, reduce='add', deterministic=False):
dense_dims = tuple(set([d - 1 for d in dims if d > 1]))
if len(sparse_dims) == 2 and src.has_value():
func = getattr(torch, 'sum' if reduce == 'add' else reduce)
return func(value, dim=(0, ) + dense_dims)
return getattr(torch, reduce)(value, dim=(0, ) + dense_dims)
if len(sparse_dims) == 2 and not src.has_value():
value = src.nnz() if reduce == 'add' else 1
value = src.nnz() if reduce == 'sum' else 1
return torch.tensor(value, device=src.device)
if len(dense_dims) > 0 and len(sparse_dims) == 0: # src.has_value()
func = getattr(torch, 'sum' if reduce == 'add' else reduce)
dense_dims = dense_dims[0] if len(dense_dims) == 1 else dense_dims
value = func(value, dim=dense_dims)
value = getattr(torch, reduce)(value, dim=dense_dims)
if isinstance(value, tuple):
return (src.set_value(value[0], layout='csr'), ) + value[1:]
return src.set_value(value, layout='csr')
if len(dense_dims) > 0 and len(sparse_dims) > 0:
func = getattr(torch, 'sum' if reduce == 'add' else reduce)
dense_dims = dense_dims[0] if len(dense_dims) == 1 else dense_dims
value = func(value, dim=dense_dims)
value = getattr(torch, reduce)(value, dim=dense_dims)
value = value[0] if isinstance(value, tuple) else value
if sparse_dims[0] == 1 and src.has_value():
@@ -51,7 +47,7 @@ def reduction(src, dim=None, reduce='add', deterministic=False):
return out
if sparse_dims[0] == 1 and not src.has_value():
if reduce == 'add':
if reduce == 'sum':
return src.storage.rowcount.to(torch.get_default_dtype())
elif reduce == 'min' or 'max':
# Return an additional `None` arg(min|max) tensor for consistency.
@@ -68,13 +64,14 @@ def reduction(src, dim=None, reduce='add', deterministic=False):
return out
if sparse_dims[0] == 0 and src.has_value():
reduce = 'add' if reduce == 'sum' else reduce
func = getattr(torch_scatter, f'scatter_{reduce}')
out = func(value, col, dim=0, dim_size=src.sparse_size(1))
out = out[0] if len(dense_dims) > 0 and isinstance(out, tuple) else out
return out
if sparse_dims[0] == 0 and not src.has_value():
if reduce == 'add':
if reduce == 'sum':
return src.storage.colcount.to(torch.get_default_dtype())
elif reduce == 'min' or 'max':
# Return an additional `None` arg(min|max) tensor for consistency.
@@ -84,7 +81,7 @@ def reduction(src, dim=None, reduce='add', deterministic=False):
def sum(src, dim=None, deterministic=False):
return reduction(src, dim, reduce='add', deterministic=deterministic)
return reduction(src, dim, reduce='sum', deterministic=deterministic)
def mean(src, dim=None, deterministic=False):
......
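The reduction changes above standardize on 'sum' (rather than 'add') as the canonical reduce name. A hedged sketch, assuming these helpers live in the torch_sparse.reduce module diffed here:

import torch
from torch_sparse.tensor import SparseTensor
from torch_sparse.reduce import reduction

src = SparseTensor.from_dense(torch.tensor([[1., 2.], [0., 3.]]))
total = reduction(src, reduce='sum')  # dim=None reduces all stored values -> tensor(6.)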
@@ -164,8 +164,9 @@ class SparseStorage(object):
value = torch.full((self.nnz(), ), device=self.index.device)
elif torch.is_tensor(value) and get_layout(layout) == 'csc':
value = value[self.csc2csr]
assert value.device == self.index.device
assert value.size(0) == self.index.size(1)
if torch.is_tensor(value):
assert value.device == self.index.device
assert value.size(0) == self.index.size(1)
self._value = value
return self
@@ -268,7 +269,7 @@ class SparseStorage(object):
@cached_property
def colptr(self):
if self._csr2csc:
if self.has_csr2csc():
func = rowptr_cuda if self.index.is_cuda else rowptr_cpu
return func.rowptr(self.col[self.csr2csc], self.sparse_size(1))
else:
......
@@ -214,6 +214,15 @@ class SparseTensor(object):
def detach(self):
return self.from_storage(self.storage.apply(lambda x: x.detach()))
@property
def requires_grad(self):
return self.storage.value.requires_grad if self.has_value() else False
def requires_grad_(self, requires_grad=True):
if self.has_value():
self.storage.value.requires_grad_(requires_grad)
return self
def pin_memory(self):
return self.from_storage(self.storage.apply(lambda x: x.pin_memory()))
......