matmul.py 2.38 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
import torch
rusty1s's avatar
rusty1s committed
2
3
import scipy.sparse
from torch_sparse import transpose
rusty1s's avatar
rusty1s committed
4

rusty1s's avatar
rusty1s committed
5
6
7
8
9
10
11
12
13
14
15
16
17
if torch.cuda.is_available():
    import matmul_cuda


def spspmm(indexA, valueA, sizeA, indexB, valueB, sizeB):
    """Sparse-sparse matrix product ``C = A @ B`` in COO format.

    Args:
        indexA: ``[2, nnz_A]`` long tensor of (row, col) coordinates of A.
        valueA: ``[nnz_A]`` tensor of the nonzero values of A.
        sizeA: 2-element size of A.
        indexB, valueB, sizeB: same triple for B.

    Returns:
        ``(index, value, size)`` COO triple describing the product.
    """
    assert valueA.dtype == valueB.dtype
    assert len(sizeA) == len(sizeB) == 2
    assert sizeA[1] == sizeB[0]  # inner dimensions must agree

    out_index, out_value = SpSpMM.apply(indexA, valueA, sizeA, indexB, valueB,
                                        sizeB)
    out_size = torch.Size([sizeA[0], sizeB[1]])
    return out_index, out_value, out_size
rusty1s's avatar
rusty1s committed
18

rusty1s's avatar
rusty1s committed
19

rusty1s's avatar
typo  
rusty1s committed
20
class SpSpMM(torch.autograd.Function):
    """Autograd wrapper around sparse-sparse matrix multiplication.

    ``forward`` computes ``C = A @ B`` for two COO matrices; ``backward``
    propagates gradients to the value tensors only (indices and sizes get
    ``None`` gradients).
    """

    @staticmethod
    def forward(ctx, indexA, valueA, sizeA, indexB, valueB, sizeB):
        index, value = mm(indexA, valueA, sizeA, indexB, valueB, sizeB)
        # Sizes are plain Python objects, so stash them on ctx directly;
        # tensors go through save_for_backward.
        ctx.sizeA, ctx.sizeB = sizeA, sizeB
        ctx.save_for_backward(indexA, valueA, indexB, valueB, index)
        return index, value

    @staticmethod
    def backward(ctx, grad_index, grad_value):
        # `saved_tensors` replaces the long-deprecated (and since removed)
        # `saved_variables` attribute.
        indexA, valueA, indexB, valueB, index = ctx.saved_tensors
        grad_valueA = grad_valueB = None
        # Gradient w.r.t. the output C, expressed as a COO triple with C's
        # sparsity pattern.
        grad = (index, grad_value, torch.Size([ctx.sizeA[0], ctx.sizeB[1]]))

        # needs_input_grad[1] corresponds to valueA in forward's signature.
        if ctx.needs_input_grad[1]:
            # dL/dA = dL/dC @ B^T.
            # NOTE(review): assumes `transpose` returns a full
            # (index, value, size) triple so it can be splatted into mm —
            # confirm against torch_sparse.transpose.
            # NOTE(review): the returned values follow the sparsity pattern
            # of the product, which is assumed to line up with indexA —
            # verify.
            B_transposed = transpose(indexB, valueB, ctx.sizeB)
            _, grad_valueA = mm(*grad, *B_transposed)

        # needs_input_grad[4] corresponds to valueB.
        if ctx.needs_input_grad[4]:
            # dL/dB = A^T @ dL/dC.
            A_transposed = transpose(indexA, valueA, ctx.sizeA)
            _, grad_valueB = mm(*A_transposed, *grad)

        return None, grad_valueA, None, None, grad_valueB, None
rusty1s's avatar
rusty1s committed
45
46


rusty1s's avatar
rusty1s committed
47
48
49
50
51
def mm(indexA, valueA, sizeA, indexB, valueB, sizeB):
    """Dispatch a sparse-sparse matmul to the CUDA or CPU implementation,
    based on where the value tensors live."""
    impl = mm_cuda if valueA.is_cuda else mm_cpu
    return impl(indexA, valueA, sizeA, indexB, valueB, sizeB)
rusty1s's avatar
rusty1s committed
52
53


rusty1s's avatar
rusty1s committed
54
55
56
57
58
59
60
def mm_cuda(indexA, valueA, sizeA, indexB, valueB, sizeB):
    """Multiply two COO matrices on the GPU via the `matmul_cuda` extension.

    Returns the ``(index, value)`` pair of the sparse product.
    """
    sparseA = torch.sparse_coo_tensor(indexA, valueA, sizeA)
    sparseB = torch.sparse_coo_tensor(indexB, valueB, sizeB)
    # The extension consumes torch sparse tensors and hands back a COO pair.
    return matmul_cuda.spspmm(sparseA, sparseB)
rusty1s's avatar
rusty1s committed
61
62


rusty1s's avatar
rusty1s committed
63
64
65
def mm_cpu(indexA, valueA, sizeA, indexB, valueB, sizeB):
    """Multiply two COO matrices on the CPU using scipy's sparse kernels.

    Returns the ``(index, value)`` pair of the sparse product, with `value`
    cast to the dtype of `valueA`.
    """
    A = to_scipy(indexA, valueA, sizeA)
    B = to_scipy(indexB, valueB, sizeB)
    # CSR x CSR is scipy's fast multiplication path; convert the result back
    # to COO so row/col/data arrays can be read off directly.
    C = A.tocsr().dot(B.tocsr()).tocoo()

    row = torch.from_numpy(C.row).long()
    col = torch.from_numpy(C.col).long()
    index = torch.stack([row, col], dim=0)
    value = torch.from_numpy(C.data).type_as(valueA)
    return index, value
rusty1s's avatar
rusty1s committed
72
73


rusty1s's avatar
rusty1s committed
74
75
76
def to_scipy(index, value, size):
    """Convert an ``(index, value, size)`` COO triple into a scipy matrix.

    Tensors are detached before the numpy conversion so matrices built from
    autograd-tracked values do not raise. CPU tensors only.
    """
    row, col = index.detach().numpy()
    data = value.detach().numpy()
    return scipy.sparse.coo_matrix((data, (row, col)), tuple(size))