# matmul.py: sparse-sparse matrix multiplication (SciPy-backed CPU path).
import torch
from torch import from_numpy
from scipy.sparse import coo_matrix


class SpSpMM(torch.autograd.Function):
    """Autograd wrapper around the sparse-sparse matrix product ``mm``.

    Each matrix is passed as an ``(index, value, size)`` tuple (the layout
    unpacked by ``to_csr``); ``forward`` returns ``mm``'s ``(index, value)``
    result.
    """

    @staticmethod
    def forward(ctx, matrix1, matrix2):
        # The inputs are tuples, which ``ctx.save_for_backward`` rejects (it
        # only accepts tensors), so keep plain references on ``ctx`` instead.
        ctx.matrix1, ctx.matrix2 = matrix1, matrix2
        return mm(matrix1, matrix2)

    @staticmethod
    def backward(ctx, grad_out):
        # NOTE(review): autograd does not track tensors nested inside tuple
        # inputs, so this is presumably unreachable as written; ``grad_out``
        # would also have to be an ``(index, value, size)`` tuple for ``mm``
        # to accept it -- confirm the intended gradient layout.
        matrix1, matrix2 = ctx.matrix1, ctx.matrix2
        grad_matrix1 = grad_matrix2 = None

        if ctx.needs_input_grad[0]:
            # dL/dA = dL/dC @ B^T
            grad_matrix1 = mm(grad_out, SpSpMM._transpose(matrix2))

        if ctx.needs_input_grad[1]:
            # dL/dB = A^T @ dL/dC
            grad_matrix2 = mm(SpSpMM._transpose(matrix1), grad_out)

        return grad_matrix1, grad_matrix2

    @staticmethod
    def _transpose(matrix):
        """Return the transpose of an ``(index, value, size)`` sparse matrix.

        Row and column coordinates are swapped, values are kept as-is, and
        the size is reversed.
        """
        index, value, size = matrix
        transposed_index = torch.stack([index[1], index[0]], dim=0)
        return transposed_index, value, (size[1], size[0])


spspmm = SpSpMM.apply


def mm(A, B):
    """Multiply two sparse matrices, ``A @ B``.

    Args:
        A, B: sparse matrices as ``(index, value, size)`` tuples, where
            ``index`` is a tensor whose first row holds row coordinates and
            whose second row holds column coordinates (the layout unpacked
            by ``to_csr``).

    Returns:
        ``(index, value)`` of the product, as produced by ``mm_cpu``.

    Raises:
        NotImplementedError: if ``A``'s index tensor is on a CUDA device;
            only the CPU path is implemented.
    """
    if A[0].is_cuda:
        # Fail loudly instead of silently returning ``None`` to the caller.
        raise NotImplementedError(
            'Sparse-sparse matrix multiplication is not implemented for '
            'CUDA tensors.')
    return mm_cpu(A, B)


def mm_cpu(A, B):
    """Multiply two CPU sparse matrices via SciPy.

    Args:
        A, B: sparse matrices as ``(index, value, size)`` tuples on the CPU.

    Returns:
        ``(index, value)`` of ``A @ B``. ``index`` is a ``2 x nnz`` int64
        tensor stacking row and column coordinates, and ``value`` holds the
        matching entries.
    """
    A, B = to_csr(A), to_csr(B)
    C = A.dot(B).tocoo()
    # SciPy typically stores COO coordinates as int32, whereas torch index
    # tensors are conventionally int64 (``torch.long``); cast for consistency.
    row = from_numpy(C.row).long()
    col = from_numpy(C.col).long()
    value = from_numpy(C.data)
    return torch.stack([row, col], dim=0), value


def to_csr(A):
    """Convert an ``(index, value, size)`` sparse matrix to SciPy CSR format.

    Args:
        A: tuple ``(index, value, size)``. ``index`` holds row coordinates in
            its first row and column coordinates in its second, ``value`` the
            matching entries, and ``size`` the ``(rows, cols)`` shape.

    Returns:
        A SciPy CSR matrix of shape ``(size[0], size[1])``.
    """
    (row, col), value, size = A
    # Detach first: ``Tensor.numpy()`` raises on tensors that require grad.
    row = row.detach().numpy()
    col = col.detach().numpy()
    value = value.detach().numpy()
    return coo_matrix((value, (row, col)), shape=(size[0], size[1])).tocsr()