Commit d92fb90b authored by rusty1s's avatar rusty1s
Browse files

backward pass

parent a339bb4f
from .matmul import sparse_sparse_matmul
__all__ = [
'sparse_sparse_matmul',
]
import torch
from torch import from_numpy
from scipy.sparse import coo_matrix
class SparseSparseMatmul(torch.autograd.Function):
@staticmethod
def forward(ctx, matrix1, matrix2):
ctx.save_for_backawrd(matrix1, matrix2)
return matmul(matrix1, matrix2)
@staticmethod
def backward(ctx, grad_out):
matrix1, matrix2 = ctx.saved_variables
grad_matrix1 = grad_matrix2 = None
if ctx.needs_input_grad[0]:
grad_matrix1 = matmul(grad_out, matrix2.t())
if ctx.needs_input_grad[0]:
grad_matrix2 = matmul(matrix1.t(), grad_out)
return grad_matrix1, grad_matrix2
sparse_sparse_matmul = SparseSparseMatmul.apply
def matmul(A, B):
if A[0].is_cuda:
pass
else:
return matmul_cpu(A, B)
def matmul_cpu(A, B):
A, B, = to_csr(A), to_csr(B)
C = A.dot(B).tocoo()
row, col, value = from_numpy(C.row), from_numpy(C.col), from_numpy(C.data)
return torch.stack([row, col], dim=0), value
def to_csr(A):
(row, col), value, size = A
row, col, value = row.numpy(), col.numpy(), value.numpy()
return coo_matrix((value, (row, col)), shape=(size[0], size[1])).tocsr()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment