Commit bd49e20a authored by rusty1s

spmm backward implementation

parent df5f7063
@@ -109,7 +109,7 @@ spmm(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> value_opt,
   at::optional<at::Tensor> arg_out = at::nullopt;
   int64_t *arg_out_data = nullptr;
   if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
-    arg_out = at::full_like(out, -1, rowptr.options());
+    arg_out = at::full_like(out, col.numel(), rowptr.options());
     arg_out_data = arg_out.value().DATA_PTR<int64_t>();
   }
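Note: the sentinel for "row has no entries" in the argmin/argmax output changes from `-1` to `col.numel()`, i.e. one past the last valid edge index. The backward pass can then scatter every gradient into `nnz + 1` buckets and drop the trailing bucket, instead of masking out `-1` indices first. A minimal sketch of the trick in Python (names and sizes are illustrative, not the library's API):

```python
import torch

nnz = 5                            # number of non-zero entries
grad = torch.randn(3)              # one gradient per output entry
arg = torch.tensor([0, 4, nnz])    # winning edge per entry; nnz means "empty"

# Scatter into nnz + 1 buckets: gradients of empty entries land in the
# trailing bucket, which is simply dropped; no masking of -1 indices needed.
grad_value = torch.zeros(nnz + 1).scatter_add_(0, arg, grad)[:-1]
```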
@@ -2,91 +2,53 @@ from itertools import product

 import pytest
 import torch
-from torch.autograd import gradcheck
 from torch_sparse.matmul import matmul
 from torch_sparse.tensor import SparseTensor
 import torch_scatter

-from .utils import tensor, devices, dtypes
+from .utils import devices, grad_dtypes

-devices = ['cpu']
-dtypes = [torch.float]
-grad_dtypes = [torch.float]
 reductions = ['sum', 'mean', 'min', 'max']
-# grad_reductions = ['sum', 'mean']


-@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
-def test_spmm_forward(dtype, device):
-    src_dense = torch.randn((5, 4), dtype=dtype, device=device)
-    src = SparseTensor.from_dense(src_dense)
-    src.requires_grad_()
-    src_dense = src_dense.clone().requires_grad_()
-    other = torch.randn((4, 8), dtype=dtype, device=device)
-    other.requires_grad_()
-
-    out1 = matmul(src, other)
-    grad_out = torch.randn_like(out1)
-    out1.backward(grad_out)
-    other.grad = None
-
-    out2 = torch.matmul(src_dense, other)
-    out2.backward(grad_out)
-
-    # assert torch.allclose(out1, out2)
-    # assert torch.allclose(src.storage.value.grad.view(5, 4), src_dense.grad)
-
-reductions = ['min']
-
 @pytest.mark.parametrize('dtype,device,reduce',
-                         product(dtypes, devices, reductions))
+                         product(grad_dtypes, devices, reductions))
 def test_spmm(dtype, device, reduce):
-    src = torch.ones((5, 4), dtype=dtype, device=device)
-    src[2] = 0
+    src = torch.randn((10, 8), dtype=dtype, device=device)
+    src[2, :] = 0  # Delete one row...
+    src[:, 2:4] = 0  # Delete two cols...
     src = SparseTensor.from_dense(src).requires_grad_()
-    src.set_value_(None)
-    (row, col), value = src.coo()
-    other = torch.randn((2, 4, 2), dtype=dtype, device=device,
+    other = torch.randn((2, 8, 2), dtype=dtype, device=device,
                         requires_grad=True)
+    (row, col), value = src.coo()

-    out1 = other.index_select(-2, col)  # * value.unsqueeze(-1)
+    src_col = other.index_select(-2, col) * value.unsqueeze(-1)
     func = 'add' if reduce == 'sum' else reduce
-    out1 = getattr(torch_scatter, f'scatter_{func}')(out1, row, dim=-2)
-    out1 = out1[0] if isinstance(out1, tuple) else out1
-    grad_out = torch.randn_like(out1)
-    out1.backward(grad_out)
-    # grad_value1 = value.grad
-    # value.grad = None
-    grad_other1 = other.grad
+    expected = getattr(torch_scatter, f'scatter_{func}')(src_col, row, dim=-2)
+    expected = expected[0] if isinstance(expected, tuple) else expected
+    if reduce == 'min':
+        expected[expected > 1000] = 0
+    if reduce == 'max':
+        expected[expected < 1000] = 0
+
+    grad_out = torch.randn_like(expected)
+
+    expected.backward(grad_out)
+    expected_grad_value = value.grad
+    value.grad = None
+    expected_grad_other = other.grad
     other.grad = None
-    print(reduce)

-    out2 = matmul(src, other, reduce)
-    out2 = out2[0] if isinstance(out2, tuple) else out2
-    out2.backward(grad_out)
-    # grad_value2 = value.grad
-    # value.grad = None
-    grad_other2 = other.grad
-    other.grad = None
-
-    # assert torch.allclose(out1, out2)
-    # assert torch.allclose(grad_value1, grad_value2)
-    assert torch.allclose(grad_other1, grad_other2)
-
-
-@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
-def test_spmm_backward(dtype, device):
-    src_dense = torch.randn((5, 4), dtype=torch.double, device=device)
-    src = SparseTensor.from_dense(src_dense)
-    src.requires_grad_()
-    other = torch.randn((4, 8), dtype=torch.double, device=device)
-    other.requires_grad_()
-
-    # assert gradcheck(matmul, (src, other, "sum"))
+    out = matmul(src, other, reduce)
+    out = out[0] if isinstance(out, tuple) else out
+    out.backward(grad_out)
+
+    assert torch.allclose(expected, out)
+    assert torch.allclose(expected_grad_value, value.grad)
+    assert torch.allclose(expected_grad_other, other.grad)
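Note: the rewritten test builds its reference by hand with `torch_scatter` (gather the matched rows of `other`, scale by the edge values, scatter-reduce by row) and now checks the gradients w.r.t. both the sparse values and the dense matrix. In the simplest setting the contract under test is plain matrix multiplication; a small usage sketch, assuming only the imports the test already uses:

```python
import torch
from torch_sparse.tensor import SparseTensor
from torch_sparse.matmul import matmul

dense = torch.randn(5, 4)
src = SparseTensor.from_dense(dense)  # sparse view of the same data
other = torch.randn(4, 8)

# With reduce='sum', spmm coincides with dense matrix multiplication.
assert torch.allclose(matmul(src, other, 'sum'), dense @ other, atol=1e-6)
```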
@@ -34,6 +34,12 @@ class SPMM(torch.autograd.Function):
         data = ctx.saved_tensors
         index, rowcount, rowptr, colptr, csr2csc, value, mat, arg_out = data

+        invalid_arg_mask = arg_out_ind = None
+        if ctx.reduce in ['min', 'max'] and (ctx.needs_input_grad[5]
+                                             or ctx.needs_input_grad[6]):
+            invalid_arg_mask = arg_out == index.size(1)
+            arg_out_ind = arg_out.masked_fill(invalid_arg_mask, -1)
+
         grad_value = None
         if ctx.needs_input_grad[5]:
             if ctx.reduce in ['sum', 'add']:
@@ -45,12 +51,12 @@ class SPMM(torch.autograd.Function):
                     rowptr, index[1], mat, grad_out, ctx.reduce)
             elif ctx.reduce in ['min', 'max']:
-                col = index[1][arg_out.flatten()].view_as(arg_out)
+                col = index[1][arg_out_ind.flatten()].view_as(arg_out)
                 out = mat.gather(-2, col).mul_(grad_out)
-                out.masked_fill_(arg_out == -1, 0)
-                col = col.add_(rowptr[:-1].view(-1, 1))
-                grad_value = scatter_add(out.flatten(), col.flatten(), dim=0,
-                                         dim_size=value.numel())
+                out.masked_fill_(invalid_arg_mask, 0)
+                grad_value = scatter_add(out.flatten(), arg_out.flatten(),
+                                         dim=0, dim_size=value.numel() + 1)
+                grad_value = grad_value[:-1]

         grad_mat = None
         if ctx.needs_input_grad[6]:
@@ -70,12 +76,12 @@ class SPMM(torch.autograd.Function):
             elif ctx.reduce in ['min', 'max']:
                 if value is not None:
-                    value = value[arg_out.flatten()].view_as(arg_out)
+                    value = value[arg_out_ind.flatten()].view_as(arg_out)
                     value = value.mul_(grad_out)
                 else:
                     value = grad_out
-                value.masked_fill_(arg_out == -1, 0)
-                col = index[1][arg_out.flatten()].view_as(arg_out)
+                value.masked_fill_(invalid_arg_mask, 0)
+                col = index[1][arg_out_ind.flatten()].view_as(arg_out)
                 grad_mat = scatter_add(value, col, dim=-2,
                                        dim_size=mat.size(-2))
@@ -89,14 +95,14 @@ def matmul(src, other, reduce='sum'):
     assert reduce in ['sum', 'add', 'mean', 'min', 'max']
     (index, value), rowptr = src.coo(), src.storage.rowptr

-    csr2csc = colptr = None
-    if other.requires_grad and reduce in ['sum', 'add', 'mean']:
-        csr2csc, colptr = src.storage.csr2csc, src.storage.colptr
-
     rowcount = None
     if other.requires_grad and reduce in ['mean']:
         rowcount = src.storage.rowcount

+    csr2csc = colptr = None
+    if other.requires_grad and reduce in ['sum', 'add', 'mean']:
+        csr2csc, colptr = src.storage.csr2csc, src.storage.colptr
+
     return SPMM.apply(index, rowcount, rowptr, colptr, csr2csc, value,
                       other, reduce)
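Note: for `min`/`max`, each output entry is produced by a single winning edge, `out[i, k] = value[e] * mat[col[e], k]` with `e = arg_out[i, k]`, so the gradient routes entirely through that edge: `value[e]` receives `mat[col[e], k] * grad_out[i, k]`, and `mat[col[e], k]` receives `value[e] * grad_out[i, k]`. A self-contained sketch of this routing under the sentinel convention above (`minmax_backward` is a hypothetical helper, not the module's API; it assumes `value` is present):

```python
import torch
from torch_scatter import scatter_add

def minmax_backward(arg_out, col, value, mat, grad_out):
    # arg_out holds the winning edge per output entry; the sentinel
    # col.numel() (== nnz) marks entries whose row had no edges at all.
    invalid = arg_out == col.numel()
    arg = arg_out.masked_fill(invalid, -1)  # -1 wraps to a valid index

    # Gradient w.r.t. the sparse values: gather the winning rows of `mat`,
    # zero the empty entries, scatter into nnz + 1 buckets, drop the last.
    col_sel = col[arg.flatten()].view_as(arg_out)
    g = mat.gather(-2, col_sel).mul(grad_out).masked_fill(invalid, 0)
    grad_value = scatter_add(g.flatten(), arg_out.flatten(), dim=0,
                             dim_size=value.numel() + 1)[:-1]

    # Gradient w.r.t. the dense matrix: send grad_out, scaled by the
    # winning value, back to the rows of `mat` that won each entry.
    g = value[arg.flatten()].view_as(arg_out).mul(grad_out)
    grad_mat = scatter_add(g.masked_fill(invalid, 0), col_sel, dim=-2,
                           dim_size=mat.size(-2))
    return grad_value, grad_mat
```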