spspmm.py 2.36 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
import torch
rusty1s's avatar
rusty1s committed
2
from torch import from_numpy
rusty1s's avatar
rusty1s committed
3
import scipy.sparse
rusty1s's avatar
rusty1s committed
4
from torch_sparse import transpose
rusty1s's avatar
rusty1s committed
5

rusty1s's avatar
rusty1s committed
6
if torch.cuda.is_available():
rusty1s's avatar
rusty1s committed
7
    import spspmm_cuda
rusty1s's avatar
rusty1s committed
8
9


rusty1s's avatar
typo  
rusty1s committed
10
class SpSpMM(torch.autograd.Function):
    """Sparse matrix product of two sparse matrices with autograd support.

    Matrices are passed as ``(index, value)`` pairs, where ``index[0]`` and
    ``index[1]`` hold the row and column of each stored entry and ``value``
    holds the entries themselves. ``A`` is ``m x k``, ``B`` is ``k x n`` and
    the product ``C`` is ``m x n``. Gradients are only produced for
    ``valueA`` and ``valueB``.
    """

    @staticmethod
    def forward(ctx, indexA, valueA, indexB, valueB, m, k, n):
        """Compute ``C = A @ B`` and stash what ``backward`` needs.

        Returns:
            The pair ``(indexC, valueC)`` describing the product.
        """
        indexC, valueC = mm(indexA, valueA, indexB, valueB, m, k, n)

        # Plain Python ints are kept on ``ctx`` directly; tensors must go
        # through ``save_for_backward``.
        ctx.m, ctx.k, ctx.n = m, k, n
        ctx.save_for_backward(indexA, valueA, indexB, valueB, indexC)
        return indexC, valueC

    @staticmethod
    def backward(ctx, grad_indexC, grad_valueC):
        """Propagate ``grad_valueC`` back to ``valueA`` and ``valueB``.

        Returns:
            One entry per ``forward`` argument; everything except the
            gradients for ``valueA`` and ``valueB`` is ``None``.
        """
        m, k, n = ctx.m, ctx.k, ctx.n
        # ``saved_tensors`` is the supported accessor; ``saved_variables`` is
        # deprecated.
        indexA, valueA, indexB, valueB, indexC = ctx.saved_tensors

        grad_valueA = grad_valueB = None

        if ctx.needs_input_grad[1]:
            # dA = dC @ B^T, then read off the entries at A's stored positions.
            indexB_T, valueB_T = transpose(indexB, valueB, k, n)
            grad_indexA, grad_valueA = mm(indexC, grad_valueC, indexB_T,
                                          valueB_T, m, n, k)
            grad_valueA = lift(grad_indexA, grad_valueA, indexA, k)

        if ctx.needs_input_grad[3]:
            # dB = A^T @ dC, then read off the entries at B's stored positions.
            indexA_T, valueA_T = transpose(indexA, valueA, m, k)
            grad_indexB, grad_valueB = mm(indexA_T, valueA_T, indexC,
                                          grad_valueC, k, m, n)
            grad_valueB = lift(grad_indexB, grad_valueB, indexB, n)

        return None, grad_valueA, None, grad_valueB, None, None, None
rusty1s's avatar
rusty1s committed
40
41


rusty1s's avatar
rusty1s committed
42
# Functional entry point: ``spspmm(indexA, valueA, indexB, valueB, m, k, n)``
# forwards to ``SpSpMM.apply`` so the call is recorded by autograd.
spspmm = SpSpMM.apply
rusty1s's avatar
rusty1s committed
43
44


rusty1s's avatar
rusty1s committed
45
46
def mm(indexA, valueA, indexB, valueB, m, k, n):
    """Multiply two sparse matrices given as ``(index, value)`` pairs.

    ``A`` is interpreted as an ``m x k`` matrix and ``B`` as a ``k x n``
    matrix. When ``indexA`` lives on a CUDA device the work is delegated to
    the ``spspmm_cuda`` extension; otherwise both operands are converted to
    SciPy matrices, multiplied in CSR format and converted back.

    Returns:
        The pair ``(indexC, valueC)`` describing the ``m x n`` product.

    Raises:
        AssertionError: If ``valueA`` and ``valueB`` have different dtypes.
    """
    assert valueA.dtype == valueB.dtype, (
        'valueA and valueB must have the same dtype, got {} and {}'.format(
            valueA.dtype, valueB.dtype))

    if indexA.is_cuda:
        return spspmm_cuda.spspmm(indexA, valueA, indexB, valueB, m, k, n)

    A = to_scipy(indexA, valueA, m, k)
    B = to_scipy(indexB, valueB, k, n)
    indexC, valueC = from_scipy(A.tocsr().dot(B.tocsr()).tocoo())

    return indexC, valueC
rusty1s's avatar
rusty1s committed
56
57


rusty1s's avatar
rusty1s committed
58
59
60
def to_scipy(index, value, m, n):
    """Build an ``m x n`` SciPy COO matrix from an ``(index, value)`` pair.

    ``index`` is unpacked into its row and column parts; both tensors are
    detached first so that tensors tracked by autograd can be handed over
    to SciPy.
    """
    rows, cols = index.detach()
    entries = value.detach()
    return scipy.sparse.coo_matrix((entries, (rows, cols)), shape=(m, n))
rusty1s's avatar
rusty1s committed
61
62


rusty1s's avatar
rusty1s committed
63
def from_scipy(A):
    """Convert a SciPy sparse matrix into an ``(index, value)`` tensor pair.

    Args:
        A: A SciPy sparse matrix in any format. It is converted to COO
            first, which is a no-op for matrices already in COO format.

    Returns:
        ``index``: a ``2 x nnz`` ``torch.long`` tensor whose first row holds
        the row indices and second row the column indices.
        ``value``: a tensor holding the stored entries of ``A``.
    """
    # Only the COO format exposes ``row``/``col``; normalise other formats.
    A = A.tocoo()

    row, col, value = from_numpy(A.row), from_numpy(A.col), from_numpy(A.data)
    index = torch.stack([row, col], dim=0).to(torch.long)
    return index, value
rusty1s's avatar
rusty1s committed
67
68
69
70
71
72
73
74
75
76
77


def lift(indexA, valueA, indexB, n):
    """Read the values of sparse matrix ``A`` at the positions listed in ``B``.

    Both index tensors are linearised as ``row * n + col``. The values of
    ``A`` are scattered into a dense scratch buffer and then gathered at the
    linearised positions of ``indexB``. Positions of ``indexB`` that are not
    present in ``indexA`` yield zero.

    Args:
        indexA: Row and column indices of the entries of ``A``.
        valueA: Values of ``A``, one per column of ``indexA``.
        indexB: Row and column indices of the positions to read.
        n: Multiplier applied to the row index when linearising.

    Returns:
        A tensor with one value per column of ``indexB``.
    """
    flatA = indexA[0] * n + indexA[1]
    flatB = indexB[0] * n + indexB[1]

    # Nothing to gather; ``max()`` is undefined on an empty tensor.
    if flatB.numel() == 0:
        return valueA.new_zeros(0)

    # The buffer has to cover the positions of *both* operands: ``A`` may
    # hold entries beyond the largest position requested by ``B``.
    size = flatB.max().item() + 1
    if flatA.numel() > 0:
        size = max(size, flatA.max().item() + 1)

    dense = valueA.new_zeros(size)
    dense[flatA] = valueA
    return dense[flatB]