matmul.py 1.51 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
5
import torch
from torch import from_numpy
from scipy.sparse import coo_matrix


rusty1s's avatar
typo  
rusty1s committed
6
class SpSpMM(torch.autograd.Function):
    """Autograd function for the product of two sparse matrices.

    Each operand is given as an ``(index, value, size)`` triple, where
    ``index[0]`` holds row indices, ``index[1]`` holds column indices,
    ``value`` holds the matching entries and ``size`` is the matrix shape.
    This is the form consumed by ``to_csr``.

    ``forward`` returns the ``(index, value)`` pair of the product.
    ``backward`` returns gradients for the two value tensors only. The
    index and size arguments are not differentiable.
    """

    @staticmethod
    def forward(ctx, e1, v1, s1, e2, v2, s2):
        e, v = mm(e1, v1, s1, e2, v2, s2)

        ctx.s1, ctx.s2 = s1, s2
        ctx.save_for_backward(e1, v1, e2, v2, e)

        return e, v

    @staticmethod
    def backward(ctx, grad_e, grad_v):
        # ``grad_e`` is the gradient w.r.t. the integer index output and is
        # meaningless; it is ignored.
        # ``saved_variables`` is a deprecated alias of ``saved_tensors``.
        e1, v1, e2, v2, e = ctx.saved_tensors
        grad_v1 = grad_v2 = None
        # Upstream gradient dL/dC as a sparse matrix on the output pattern.
        grad = (e, grad_v, torch.Size([ctx.s1[0], ctx.s2[1]]))

        if ctx.needs_input_grad[1]:
            # dL/dA = dL/dC @ B^T. Only the entries at A's stored positions
            # (e1) are gradients of v1. The product's own pattern and order
            # differ from e1, so its values must be gathered, not returned
            # as-is.
            e2_t = torch.stack([e2[1], e2[0]], dim=0)
            s2_t = torch.Size([ctx.s2[1], ctx.s2[0]])
            index, value = mm(*grad, e2_t, v2, s2_t)
            grad_v1 = SpSpMM._gather(index, value, ctx.s1, e1).type_as(v1)

        if ctx.needs_input_grad[4]:
            # dL/dB = A^T @ dL/dC, read back at B's stored positions (e2).
            e1_t = torch.stack([e1[1], e1[0]], dim=0)
            s1_t = torch.Size([ctx.s1[1], ctx.s1[0]])
            index, value = mm(e1_t, v1, s1_t, *grad)
            grad_v2 = SpSpMM._gather(index, value, ctx.s2, e2).type_as(v2)

        return None, grad_v1, None, None, grad_v2, None

    @staticmethod
    def _gather(index, value, size, query):
        """Return the entries of a sparse matrix at the requested positions.

        ``(index, value, size)`` describes the matrix in the triple form used
        by ``mm``. ``query`` is assumed to be a ``2 x N`` index tensor, with
        rows in ``query[0]`` and columns in ``query[1]``. Positions that are
        not stored in the matrix read as zero. The lookup runs on the CPU
        through SciPy, like ``mm_cpu``.
        """
        import numpy as np  # SciPy fancy indexing yields NumPy objects.

        if query.size(1) == 0:
            return value.new_zeros(0)
        matrix = to_csr(index, value, size)
        row = query[0].detach().numpy()
        col = query[1].detach().numpy()
        # Element-wise fancy indexing returns a 1 x N matrix; flatten it.
        picked = np.asarray(matrix[row, col]).reshape(-1)
        return from_numpy(picked)
rusty1s's avatar
rusty1s committed
31
32


rusty1s's avatar
typo  
rusty1s committed
33
# Functional entry point: ``e, v = spspmm(e1, v1, s1, e2, v2, s2)``.
# Going through ``SpSpMM.apply`` records the op for autograd, so that
# ``SpSpMM.backward`` can provide gradients for the value tensors v1 and v2.
spspmm = SpSpMM.apply
rusty1s's avatar
rusty1s committed
34
35


rusty1s's avatar
rusty1s committed
36
37
def mm(e1, v1, s1, e2, v2, s2):
    """Multiply two sparse matrices given as ``(index, value, size)`` triples.

    Dispatches on the device of ``e1``. Only the CPU path (``mm_cpu``) is
    implemented.

    Returns:
        tuple: ``(index, value)`` of the product, as produced by ``mm_cpu``.

    Raises:
        NotImplementedError: if ``e1`` is a CUDA tensor.
    """
    if e1.is_cuda:
        # There is no CUDA implementation yet. Falling through here used to
        # return None, which only failed later as an opaque unpacking error.
        raise NotImplementedError(
            'Sparse-sparse matrix multiplication is not implemented for '
            'CUDA tensors')
    return mm_cpu(e1, v1, s1, e2, v2, s2)
rusty1s's avatar
rusty1s committed
41
42


rusty1s's avatar
rusty1s committed
43
44
45
46
47
def mm_cpu(e1, v1, s1, e2, v2, s2):
    """Multiply two sparse matrices on the CPU using SciPy.

    Both operands are ``(index, value, size)`` triples accepted by
    ``to_csr``. The product is formed in CSR layout and converted back to
    COO.

    Returns:
        tuple: ``(index, value)``. ``index`` stacks the product's row and
        column indices (cast to ``long``) along dim 0. ``value`` holds the
        matching entries.
    """
    lhs = to_csr(e1, v1, s1)
    rhs = to_csr(e2, v2, s2)
    coo = lhs.dot(rhs).tocoo()
    index = torch.stack(
        [from_numpy(coo.row).long(), from_numpy(coo.col).long()], dim=0)
    return index, from_numpy(coo.data)
rusty1s's avatar
rusty1s committed
48
49


rusty1s's avatar
rusty1s committed
50
51
52
53
def to_csr(index, value, size):
    """Build a SciPy CSR matrix from an ``(index, value, size)`` triple.

    ``index[0]`` / ``index[1]`` give the row / column of each entry in
    ``value``, and ``size[0]`` x ``size[1]`` is the matrix shape. Both
    tensors are detached and viewed as NumPy arrays. Duplicate coordinates
    are summed by SciPy's COO-to-CSR conversion.
    """
    coords = index.detach().numpy()
    data = value.detach().numpy()
    num_rows, num_cols = size[0], size[1]
    coo = coo_matrix((data, (coords[0], coords[1])), (num_rows, num_cols))
    return coo.tocsr()