matmul.py 1.76 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
import torch
from torch import from_numpy
from scipy.sparse import coo_matrix

rusty1s's avatar
to csr  
rusty1s committed
5
from torch_sparse import SparseTensor
rusty1s's avatar
rusty1s committed
6
7
import matmul_cuda

rusty1s's avatar
rusty1s committed
8

rusty1s's avatar
typo  
rusty1s committed
9
def _sample_at(index, value, num_cols, at):
    """Read the COO matrix ``(index, value)`` at the coordinates ``at``.

    Coordinates in ``at`` that are missing from ``index`` read as zero and
    repeated coordinates in ``index`` are summed, so the result always has
    exactly one entry per column of ``at``, in the same order as ``at``.
    """
    split = index.size(1)
    coords = torch.cat([index.long(), at.long()], dim=1)
    # Flatten (row, col) into one key so both coordinate sets share a single
    # id space after `unique`.
    keys = coords[0] * num_cols + coords[1]
    uniq, slot = torch.unique(keys, return_inverse=True)
    buf = value.new_zeros(uniq.numel())
    buf.index_add_(0, slot[:split], value)
    return buf[slot[split:]]


class SpSpMM(torch.autograd.Function):
    """Sparse-sparse matrix product on COO operands, differentiable in the values.

    Each operand is an ``(index, value, size)`` triple: row 0 / row 1 of
    ``index`` hold the row / column coordinates, ``value`` holds the matching
    entries and ``size`` is indexable as ``(rows, cols)``.
    """

    @staticmethod
    def forward(ctx, e1, v1, s1, e2, v2, s2):
        """Return ``(e, v)``, the index and values produced by ``mm``."""
        e, v = mm(e1, v1, s1, e2, v2, s2)

        # Sizes are not tensors, so they are stashed on ctx directly.
        ctx.s1, ctx.s2 = s1, s2
        ctx.save_for_backward(e1, v1, e2, v2, e)

        return e, v

    @staticmethod
    def backward(ctx, grad_e, grad_v):
        """Return gradients for ``v1`` and ``v2``; every other input gets None.

        With C = A @ B and G the upstream gradient laid out on C's index
        ``e``: dL/dA = G @ B^T and dL/dB = A^T @ G.  Only the entries that sit
        on A's (resp. B's) own coordinates are gradients of ``v1`` (resp.
        ``v2``), so each product is sampled at the operand's index instead of
        being returned wholesale -- the raw product generally has a different
        sparsity pattern and ordering than the operand.
        """
        # `saved_variables` is the deprecated spelling of `saved_tensors`.
        e1, v1, e2, v2, e = ctx.saved_tensors
        grad_v1 = grad_v2 = None
        grad = (e, grad_v, torch.Size([ctx.s1[0], ctx.s2[1]]))

        if ctx.needs_input_grad[1]:
            # Swapping the coordinate rows transposes B.
            e2_t = torch.stack([e2[1], e2[0]], dim=0)
            s2_t = torch.Size([ctx.s2[1], ctx.s2[0]])
            idx, val = mm(*grad, e2_t, v2, s2_t)
            # G @ B^T has as many columns as B has rows.
            grad_v1 = _sample_at(idx, val, ctx.s2[0], e1)

        if ctx.needs_input_grad[4]:
            # Swapping the coordinate rows transposes A.
            e1_t = torch.stack([e1[1], e1[0]], dim=0)
            s1_t = torch.Size([ctx.s1[1], ctx.s1[0]])
            idx, val = mm(e1_t, v1, s1_t, *grad)
            # A^T @ G has as many columns as B.
            grad_v2 = _sample_at(idx, val, ctx.s2[1], e2)

        return None, grad_v1, None, None, grad_v2, None
rusty1s's avatar
rusty1s committed
34
35


rusty1s's avatar
typo  
rusty1s committed
36
# Functional alias: spspmm(e1, v1, s1, e2, v2, s2) is SpSpMM.apply(...) and
# returns the (index, value) pair produced by SpSpMM.forward.
spspmm = SpSpMM.apply
rusty1s's avatar
rusty1s committed
37
38


rusty1s's avatar
rusty1s committed
39
def mm(e1, v1, s1, e2, v2, s2):
    """Multiply two COO operands, choosing the backend from ``v1``'s device.

    Dispatches to ``mm_cuda`` when ``v1.is_cuda`` is true and to ``mm_cpu``
    otherwise, forwarding all six arguments unchanged and returning whatever
    the chosen backend returns.
    """
    backend = mm_cuda if v1.is_cuda else mm_cpu
    return backend(e1, v1, s1, e2, v2, s2)
rusty1s's avatar
rusty1s committed
44
45


rusty1s's avatar
rusty1s committed
46
def mm_cuda(e1, v1, s1, e2, v2, s2):
    """Multiply two COO operands through the ``matmul_cuda`` extension.

    Each ``(index, value, size)`` triple is wrapped in a ``SparseTensor`` and
    the pair is handed to ``matmul_cuda.spspmm``; its result is returned
    as-is.
    """
    lhs = SparseTensor(e1, v1, s1)
    rhs = SparseTensor(e2, v2, s2)
    return matmul_cuda.spspmm(lhs, rhs)
rusty1s's avatar
rusty1s committed
50
51


rusty1s's avatar
rusty1s committed
52
53
54
55
56
def mm_cpu(e1, v1, s1, e2, v2, s2):
    """Multiply two COO operands on the CPU via SciPy.

    Both operands are converted with ``to_csr``, multiplied, and the product
    is converted back to COO.  Returns ``(index, value)`` where ``index``
    stacks the product's row and column coordinates (cast to long) along
    dim 0 and ``value`` wraps the product's data array.
    """
    lhs = to_csr(e1, v1, s1)
    rhs = to_csr(e2, v2, s2)
    product = lhs.dot(rhs).tocoo()

    rows = from_numpy(product.row).long()
    cols = from_numpy(product.col).long()
    index = torch.stack([rows, cols], dim=0)
    return index, from_numpy(product.data)
rusty1s's avatar
rusty1s committed
57
58


rusty1s's avatar
rusty1s committed
59
60
61
62
def to_csr(index, value, size):
    """Build a SciPy CSR matrix from a COO ``(index, value, size)`` triple.

    ``index[0]`` / ``index[1]`` supply the row / column coordinates,
    ``value`` the entries, and ``size[0]`` x ``size[1]`` the matrix shape.
    Both tensors are detached before being viewed as NumPy arrays.
    """
    coords = index.detach().numpy()
    data = value.detach().numpy()
    n_rows, n_cols = size[0], size[1]
    coo = coo_matrix((data, (coords[0], coords[1])), (n_rows, n_cols))
    return coo.tocsr()