Commit 572227be authored by rusty1s's avatar rusty1s
Browse files

typo

parent d92fb90b
# Public API of the package: re-export the sparse-sparse matmul entry point.
# (Reconstructed post-commit version; the scrape fused old and new diff
# columns onto each line.)
from .matmul import spspmm

__all__ = [
    'spspmm',
]
...@@ -3,11 +3,11 @@ from torch import from_numpy ...@@ -3,11 +3,11 @@ from torch import from_numpy
from scipy.sparse import coo_matrix from scipy.sparse import coo_matrix
class SpSpMM(torch.autograd.Function):
    """Sparse-sparse matrix multiplication with autograd support.

    Forward computes ``C = A @ B`` via :func:`mm`; backward applies the
    matmul product rule: ``dA = dC @ B^T`` and ``dB = A^T @ dC``.
    """

    @staticmethod
    def forward(ctx, matrix1, matrix2):
        # FIX: the original called the misspelled ``ctx.save_for_backawrd``,
        # which would raise AttributeError on every forward pass.
        ctx.save_for_backward(matrix1, matrix2)
        return mm(matrix1, matrix2)

    @staticmethod
    def backward(ctx, grad_out):
        # NOTE(review): the diff hunk break hides the line that restores the
        # saved inputs; reconstructed here — confirm against the full file.
        matrix1, matrix2 = ctx.saved_tensors
        grad_matrix1 = grad_matrix2 = None
        if ctx.needs_input_grad[0]:
            grad_matrix1 = mm(grad_out, matrix2.t())
        # FIX: the original guarded the second gradient with
        # ``needs_input_grad[0]`` again; the gradient for the second input
        # must be gated on ``needs_input_grad[1]``.
        if ctx.needs_input_grad[1]:
            grad_matrix2 = mm(matrix1.t(), grad_out)
        return grad_matrix1, grad_matrix2


spspmm = SpSpMM.apply
def mm(A, B):
    """Dispatch sparse-sparse matmul to a device-specific implementation.

    ``A[0]`` is inspected for its device (presumably the index tensor of a
    sparse (index, value) pair — TODO confirm layout against callers).

    Raises:
        NotImplementedError: for CUDA inputs, which have no kernel yet.
            The original fell through a ``pass`` and silently returned
            ``None``, which would surface later as a confusing TypeError.
    """
    if A[0].is_cuda:
        raise NotImplementedError(
            'spspmm is not yet implemented for CUDA tensors')
    return mm_cpu(A, B)
def matmul_cpu(A, B): def mm_cpu(A, B):
A, B, = to_csr(A), to_csr(B) A, B, = to_csr(A), to_csr(B)
C = A.dot(B).tocoo() C = A.dot(B).tocoo()
row, col, value = from_numpy(C.row), from_numpy(C.col), from_numpy(C.data) row, col, value = from_numpy(C.row), from_numpy(C.col), from_numpy(C.data)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment