Commit 572227be authored by rusty1s's avatar rusty1s
Browse files

typo

parent d92fb90b
from .matmul import sparse_sparse_matmul
from .matmul import spspmm
__all__ = [
'sparse_sparse_matmul',
'spspmm',
]
......@@ -3,11 +3,11 @@ from torch import from_numpy
from scipy.sparse import coo_matrix
class SparseSparseMatmul(torch.autograd.Function):
class SpSpMM(torch.autograd.Function):
@staticmethod
def forward(ctx, matrix1, matrix2):
ctx.save_for_backawrd(matrix1, matrix2)
return matmul(matrix1, matrix2)
return mm(matrix1, matrix2)
@staticmethod
def backward(ctx, grad_out):
......@@ -15,25 +15,25 @@ class SparseSparseMatmul(torch.autograd.Function):
grad_matrix1 = grad_matrix2 = None
if ctx.needs_input_grad[0]:
grad_matrix1 = matmul(grad_out, matrix2.t())
grad_matrix1 = mm(grad_out, matrix2.t())
if ctx.needs_input_grad[0]:
grad_matrix2 = matmul(matrix1.t(), grad_out)
grad_matrix2 = mm(matrix1.t(), grad_out)
return grad_matrix1, grad_matrix2
sparse_sparse_matmul = SparseSparseMatmul.apply
spspmm = SpSpMM.apply
def matmul(A, B):
def mm(A, B):
if A[0].is_cuda:
pass
else:
return matmul_cpu(A, B)
return mm_cpu(A, B)
def matmul_cpu(A, B):
def mm_cpu(A, B):
A, B, = to_csr(A), to_csr(B)
C = A.dot(B).tocoo()
row, col, value = from_numpy(C.row), from_numpy(C.col), from_numpy(C.data)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment