Commit bd49e20a authored by rusty1s

spmm backward implementation

parent df5f7063
@@ -109,7 +109,7 @@ spmm(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> value_opt,
   at::optional<at::Tensor> arg_out = at::nullopt;
   int64_t *arg_out_data = nullptr;
   if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
-    arg_out = at::full_like(out, -1, rowptr.options());
+    arg_out = at::full_like(out, col.numel(), rowptr.options());
     arg_out_data = arg_out.value().DATA_PTR<int64_t>();
   }
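The one-line kernel change swaps the "no entry" marker for min/max rows: instead of -1, arg_out is now initialized to col.numel(), i.e. one past the last edge. A toy sketch of why that sentinel is convenient (hypothetical tensors, not the actual kernel):

```python
import torch

# Argmin over CSR rows where row 1 is empty. The sentinel E == col.numel()
# can never collide with a real edge position, and (unlike -1) it later
# doubles as a valid scatter bucket in the backward pass.
rowptr = torch.tensor([0, 2, 2, 3])  # row 1 has no entries
col = torch.tensor([0, 3, 1])
vals = torch.tensor([4.0, 1.0, 7.0])
E = col.numel()

arg_out = torch.full((3,), E, dtype=torch.long)
for r in range(3):
    lo, hi = rowptr[r].item(), rowptr[r + 1].item()
    if hi > lo:
        arg_out[r] = lo + vals[lo:hi].argmin()

print(arg_out)  # tensor([1, 3, 2]): row 1 keeps the sentinel E == 3
```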
@@ -2,91 +2,53 @@ from itertools import product
 import pytest
 import torch
-from torch.autograd import gradcheck
 from torch_sparse.matmul import matmul
 from torch_sparse.tensor import SparseTensor
 import torch_scatter

-from .utils import tensor, devices, dtypes
+from .utils import devices, grad_dtypes

 devices = ['cpu']
-dtypes = [torch.float]
+grad_dtypes = [torch.float]
 reductions = ['sum', 'mean', 'min', 'max']
-# grad_reductions = ['sum', 'mean']
+reductions = ['min']
-
-
-@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
-def test_spmm_forward(dtype, device):
-    src_dense = torch.randn((5, 4), dtype=dtype, device=device)
-    src = SparseTensor.from_dense(src_dense)
-    src.requires_grad_()
-    src_dense = src_dense.clone().requires_grad_()
-    other = torch.randn((4, 8), dtype=dtype, device=device)
-    other.requires_grad_()
-
-    out1 = matmul(src, other)
-    grad_out = torch.randn_like(out1)
-    out1.backward(grad_out)
-    other.grad = None
-
-    out2 = torch.matmul(src_dense, other)
-    out2.backward(grad_out)
-
-    # assert torch.allclose(out1, out2)
-    # assert torch.allclose(src.storage.value.grad.view(5, 4), src_dense.grad)
-

 @pytest.mark.parametrize('dtype,device,reduce',
-                         product(dtypes, devices, reductions))
+                         product(grad_dtypes, devices, reductions))
 def test_spmm(dtype, device, reduce):
-    src = torch.ones((5, 4), dtype=dtype, device=device)
-    src[2] = 0
+    src = torch.randn((10, 8), dtype=dtype, device=device)
+    src[2, :] = 0  # Delete one row...
+    src[:, 2:4] = 0  # Delete one col...
     src = SparseTensor.from_dense(src).requires_grad_()
-    src.set_value_(None)
+    (row, col), value = src.coo()

-    other = torch.randn((2, 4, 2), dtype=dtype, device=device,
+    other = torch.randn((2, 8, 2), dtype=dtype, device=device,
                         requires_grad=True)

-    (row, col), value = src.coo()
-    out1 = other.index_select(-2, col)  # * value.unsqueeze(-1)
+    src_col = other.index_select(-2, col) * value.unsqueeze(-1)
     func = 'add' if reduce == 'sum' else reduce
-    out1 = getattr(torch_scatter, f'scatter_{func}')(out1, row, dim=-2)
-    out1 = out1[0] if isinstance(out1, tuple) else out1
+    expected = getattr(torch_scatter, f'scatter_{func}')(src_col, row, dim=-2)
+    expected = expected[0] if isinstance(expected, tuple) else expected
+    if reduce == 'min':
+        expected[expected > 1000] = 0
+    if reduce == 'max':
+        expected[expected < -1000] = 0

-    grad_out = torch.randn_like(out1)
-    out1.backward(grad_out)
-    # grad_value1 = value.grad
-    # value.grad = None
-    grad_other1 = other.grad
+    grad_out = torch.randn_like(expected)
+
+    expected.backward(grad_out)
+    expected_grad_value = value.grad
+    value.grad = None
+    expected_grad_other = other.grad
     other.grad = None

-    print(reduce)
-    out2 = matmul(src, other, reduce)
-    out2 = out2[0] if isinstance(out2, tuple) else out2
-    out2.backward(grad_out)
-    # grad_value2 = value.grad
-    # value.grad = None
-    grad_other2 = other.grad
-    other.grad = None
-
-    # assert torch.allclose(out1, out2)
-    # assert torch.allclose(grad_value1, grad_value2)
-    assert torch.allclose(grad_other1, grad_other2)
-
-
-@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
-def test_spmm_backward(dtype, device):
-    src_dense = torch.randn((5, 4), dtype=torch.double, device=device)
-    src = SparseTensor.from_dense(src_dense)
-    src.requires_grad_()
-    other = torch.randn((4, 8), dtype=torch.double, device=device)
-    other.requires_grad_()
-
-    # assert gradcheck(matmul, (src, other, "sum"))
+    out = matmul(src, other, reduce)
+    out = out[0] if isinstance(out, tuple) else out
+    out.backward(grad_out)
+
+    assert torch.allclose(expected, out)
+    assert torch.allclose(expected_grad_value, value.grad)
+    assert torch.allclose(expected_grad_other, other.grad)
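The rewritten test builds its reference result by gathering the dense rows selected by col, scaling by the non-zero values, and scatter-reducing by row. A minimal standalone version of that pattern for reduce='sum' (hypothetical small inputs, not the test itself):

```python
import torch
from torch_scatter import scatter_add

# Sparse 3x4 matrix in COO form: entries (0,1)=1, (0,2)=2, (2,0)=3.
row = torch.tensor([0, 0, 2])
col = torch.tensor([1, 2, 0])
value = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
other = torch.randn(4, 5, requires_grad=True)  # dense right-hand side

# Gather the rows each edge points at, scale by the edge value,
# then sum all edges that share an output row.
src_col = other.index_select(0, col) * value.unsqueeze(-1)
expected = scatter_add(src_col, row, dim=0, dim_size=3)

# Every op above is differentiable, so autograd yields reference
# gradients for both `value` and `other` to compare against matmul's.
expected.sum().backward()
```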
@@ -34,6 +34,12 @@ class SPMM(torch.autograd.Function):
         data = ctx.saved_tensors
         index, rowcount, rowptr, colptr, csr2csc, value, mat, arg_out = data

+        invalid_arg_mask = arg_out_ind = None
+        if ctx.reduce in ['min', 'max'] and (ctx.needs_input_grad[5]
+                                             or ctx.needs_input_grad[6]):
+            invalid_arg_mask = arg_out == index.size(1)
+            arg_out_ind = arg_out.masked_fill(invalid_arg_mask, -1)
+
         grad_value = None
         if ctx.needs_input_grad[5]:
             if ctx.reduce in ['sum', 'add']:
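The backward pass first derives two helpers from arg_out: a mask of positions where the reduction saw no entry, and an index tensor in which the out-of-range sentinel is rewritten to -1 so that advanced indexing into index[1] stays in bounds (-1 wraps to the last edge, and the bogus values it produces are zeroed through the mask afterwards). A toy version of that idiom (hypothetical tensors):

```python
import torch

col = torch.tensor([0, 3, 1])        # plays the role of index[1]
E = col.numel()
arg_out = torch.tensor([1, E, 2])    # E marks an empty output row

invalid = arg_out == E                       # invalid_arg_mask in the diff
arg_ind = arg_out.masked_fill(invalid, -1)   # -1 wraps to the last edge
picked = col[arg_ind]                        # in-bounds gather: tensor([3, 1, 1])
picked = picked.masked_fill(invalid, 0)      # scrub the placeholder entries
```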
@@ -45,12 +51,12 @@ class SPMM(torch.autograd.Function):
                     rowptr, index[1], mat, grad_out, ctx.reduce)
             elif ctx.reduce in ['min', 'max']:
-                col = index[1][arg_out.flatten()].view_as(arg_out)
+                col = index[1][arg_out_ind.flatten()].view_as(arg_out)
                 out = mat.gather(-2, col).mul_(grad_out)
-                out.masked_fill_(arg_out == -1, 0)
-                col = col.add_(rowptr[:-1].view(-1, 1))
-                grad_value = scatter_add(out.flatten(), col.flatten(), dim=0,
-                                         dim_size=value.numel())
+                out.masked_fill_(invalid_arg_mask, 0)
+                grad_value = scatter_add(out.flatten(), arg_out.flatten(),
+                                         dim=0, dim_size=value.numel() + 1)
+                grad_value = grad_value[:-1]

         grad_mat = None
         if ctx.needs_input_grad[6]:
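grad_value no longer reconstructs per-row column offsets from rowptr: because invalid positions already carry the sentinel value.numel() (== col.numel()), the contributions can be scatter-added directly by arg_out into one extra overflow slot that is then dropped. A toy version (hypothetical tensors):

```python
import torch
from torch_scatter import scatter_add

E = 3                                    # number of non-zero values
arg_out = torch.tensor([1, E, 2, 0])     # E marks "no winning edge"
contrib = torch.tensor([0.5, 9.9, -1.0, 2.0])
contrib = contrib.masked_fill(arg_out == E, 0.0)  # zero invalid contributions

grad_value = scatter_add(contrib, arg_out, dim=0, dim_size=E + 1)
grad_value = grad_value[:-1]             # drop the overflow bucket
print(grad_value)                        # tensor([2.0, 0.5, -1.0])
```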
@@ -70,12 +76,12 @@ class SPMM(torch.autograd.Function):
         elif ctx.reduce in ['min', 'max']:
             if value is not None:
-                value = value[arg_out.flatten()].view_as(arg_out)
+                value = value[arg_out_ind.flatten()].view_as(arg_out)
                 value = value.mul_(grad_out)
             else:
                 value = grad_out
-            value.masked_fill_(arg_out == -1, 0)
-            col = index[1][arg_out.flatten()].view_as(arg_out)
+            value.masked_fill_(invalid_arg_mask, 0)
+            col = index[1][arg_out_ind.flatten()].view_as(arg_out)
             grad_mat = scatter_add(value, col, dim=-2,
                                    dim_size=mat.size(-2))
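For grad_mat under min/max, each output gradient (optionally scaled by the winning non-zero value) flows back only to the dense row that won the reduction, via scatter_add along the row dimension. A toy sketch (hypothetical tensors):

```python
import torch
from torch_scatter import scatter_add

grad_out = torch.tensor([[1.0], [2.0], [3.0]])  # 3 output rows, 1 feature
col = torch.tensor([[1], [0], [1]])             # winning dense row per output
grad_mat = scatter_add(grad_out, col, dim=-2, dim_size=4)
print(grad_mat)  # [[2.], [4.], [0.], [0.]] -- dense row 1 gets 1.0 + 3.0
```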
@@ -89,14 +95,14 @@ def matmul(src, other, reduce='sum'):
     assert reduce in ['sum', 'add', 'mean', 'min', 'max']
     (index, value), rowptr = src.coo(), src.storage.rowptr

-    csr2csc = colptr = None
-    if other.requires_grad and reduce in ['sum', 'add', 'mean']:
-        csr2csc, colptr = src.storage.csr2csc, src.storage.colptr
-
     rowcount = None
     if other.requires_grad and reduce in ['mean']:
         rowcount = src.storage.rowcount

+    csr2csc = colptr = None
+    if other.requires_grad and reduce in ['sum', 'add', 'mean']:
+        csr2csc, colptr = src.storage.csr2csc, src.storage.colptr
+
     return SPMM.apply(index, rowcount, rowptr, colptr, csr2csc, value,
                       other, reduce)
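With the reordered bookkeeping in matmul, a hedged end-to-end usage sketch (shapes are arbitrary; the tuple check mirrors the test, since the kernel may also return the arg tensor for 'min'/'max'):

```python
import torch
from torch_sparse.tensor import SparseTensor
from torch_sparse.matmul import matmul

src = SparseTensor.from_dense(torch.randn(5, 4)).requires_grad_()
other = torch.randn(4, 8, requires_grad=True)

out = matmul(src, other, reduce='max')
out = out[0] if isinstance(out, tuple) else out  # strip arg output if present
out.sum().backward()  # gradients reach both the sparse values and `other`
```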