Commit 0ae0e784 authored by rusty1s's avatar rusty1s
Browse files

backward implementation

parent 572227be
import torch
from torch_sparse import spspmm
def test_spspmm():
    """Multiply two sparse matrices, check values, and run backward.

    A is 3x3, B is 3x2; the product A @ B is checked densely, then the
    sum of output values is backpropagated to exercise the autograd path.
    """
    index_a = torch.tensor([[0, 0, 1, 2, 2], [1, 2, 0, 0, 1]])
    value_a = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float, requires_grad=True)
    matrix_a = (index_a, value_a, torch.Size([3, 3]))

    index_b = torch.tensor([[0, 2], [1, 0]])
    value_b = torch.tensor([2, 4], dtype=torch.float, requires_grad=True)
    matrix_b = (index_b, value_b, torch.Size([3, 2]))

    index, value = spspmm(*matrix_a, *matrix_b)

    out = torch.sparse.FloatTensor(index, value, torch.Size([3, 2])).to_dense()
    assert out.tolist() == [[8, 0], [0, 6], [0, 8]]

    # Gradients should flow back through both sparse inputs without error.
    value.sum().backward()
@@ -5,42 +5,49 @@ from scipy.sparse import coo_matrix
class SpSpMM(torch.autograd.Function):
    """Sparse-sparse matrix multiplication with autograd support.

    Each matrix is a COO triple passed flattened into the signature:
    ``forward(ctx, e1, v1, s1, e2, v2, s2)`` where ``e*`` is a 2 x nnz
    index tensor, ``v*`` the nnz values, and ``s*`` a ``torch.Size``.
    """

    @staticmethod
    def forward(ctx, e1, v1, s1, e2, v2, s2):
        e, v = mm(e1, v1, s1, e2, v2, s2)
        # Sizes are plain Python objects; stash them on ctx directly.
        # Tensors must go through save_for_backward.
        ctx.s1, ctx.s2 = s1, s2
        ctx.save_for_backward(e1, v1, e2, v2, e)
        return e, v

    @staticmethod
    def backward(ctx, grad_e, grad_v):
        e1, v1, e2, v2, e = ctx.saved_variables
        grad_v1 = grad_v2 = None

        # Incoming gradient viewed as a sparse matrix on the output's
        # index set, with the output's shape (s1[0] x s2[1]).
        grad = (e, grad_v, torch.Size([ctx.s1[0], ctx.s2[1]]))

        if ctx.needs_input_grad[1]:
            # dL/dA = dL/dC @ B^T; transpose B by swapping its index rows.
            e2 = torch.stack([e2[1], e2[0]], dim=0)
            _, grad_v1 = mm(*grad, e2, v2, torch.Size([ctx.s2[1], ctx.s2[0]]))

        if ctx.needs_input_grad[4]:
            # dL/dB = A^T @ dL/dC; transpose A the same way.
            e1 = torch.stack([e1[1], e1[0]], dim=0)
            _, grad_v2 = mm(e1, v1, torch.Size([ctx.s1[1], ctx.s1[0]]), *grad)

        # NOTE(review): grad_v1/grad_v2 follow the sparsity pattern of the
        # freshly computed products, which need not align with the patterns
        # of v1/v2 — verify index alignment before relying on these grads.
        return None, grad_v1, None, None, grad_v2, None


spspmm = SpSpMM.apply
def mm(e1, v1, s1, e2, v2, s2):
    """Dispatch sparse-sparse matmul to a device-specific implementation.

    Returns an ``(index, value)`` pair for the product in COO form.
    Raises ``NotImplementedError`` for CUDA tensors: no kernel exists yet,
    and failing loudly beats the previous silent ``None`` return.
    """
    if e1.is_cuda:
        raise NotImplementedError('`spspmm` is not yet implemented on CUDA')
    return mm_cpu(e1, v1, s1, e2, v2, s2)
def mm_cpu(e1, v1, s1, e2, v2, s2):
    """Multiply two COO matrices on CPU via SciPy's CSR dot product.

    Converts both operands to ``scipy.sparse`` CSR, multiplies, and maps
    the COO result back to torch: a 2 x nnz long index tensor plus values.
    """
    mat1 = to_csr(e1, v1, s1)
    mat2 = to_csr(e2, v2, s2)
    out = mat1.dot(mat2).tocoo()
    # SciPy yields int32 row/col arrays; torch indices must be int64.
    row = torch.from_numpy(out.row).long()
    col = torch.from_numpy(out.col).long()
    return torch.stack([row, col], dim=0), torch.from_numpy(out.data)
def to_csr(index, value, size):
    """Convert a COO triple ``(index, value, size)`` to a SciPy CSR matrix.

    ``detach()`` before ``numpy()`` so tensors that require grad convert
    without raising; autograd is handled by the caller (SpSpMM), not here.
    """
    index = index.detach().numpy()
    value = value.detach().numpy()
    shape = (size[0], size[1])
    return coo_matrix((value, (index[0], index[1])), shape).tocsr()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment