spspmm.py 3.01 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
import torch
rusty1s's avatar
rusty1s committed
2
from torch import from_numpy
rusty1s's avatar
rusty1s committed
3
import scipy.sparse
rusty1s's avatar
rusty1s committed
4
from torch_sparse import transpose
rusty1s's avatar
rusty1s committed
5

rusty1s's avatar
rusty1s committed
6
if torch.cuda.is_available():
rusty1s's avatar
rusty1s committed
7
    import spspmm_cuda
rusty1s's avatar
rusty1s committed
8
9


rusty1s's avatar
typo  
rusty1s committed
10
class SpSpMM(torch.autograd.Function):
    """Matrix product of two sparse tensors. Both input sparse matrices need to
    be coalesced.

    Args:
        indexA (:class:`LongTensor`): The index tensor of first sparse matrix.
        valueA (:class:`Tensor`): The value tensor of first sparse matrix.
        indexB (:class:`LongTensor`): The index tensor of second sparse matrix.
        valueB (:class:`Tensor`): The value tensor of second sparse matrix.
        m (int): The first dimension of first sparse matrix.
        k (int): The second dimension of first sparse matrix and first
            dimension of second sparse matrix.
        n (int): The second dimension of second sparse matrix.

    :rtype: (:class:`LongTensor`, :class:`Tensor`)

    Gradients flow only to ``valueA`` and ``valueB``; the index tensors and
    the integer dimensions always receive ``None``.
    """

    @staticmethod
    def forward(ctx, indexA, valueA, indexB, valueB, m, k, n):
        indexC, valueC = mm(indexA, valueA, indexB, valueB, m, k, n)
        ctx.m, ctx.k, ctx.n = m, k, n
        # indexC (the sparsity pattern of C) is needed in backward to
        # interpret grad_valueC as a sparse matrix.
        ctx.save_for_backward(indexA, valueA, indexB, valueB, indexC)
        return indexC, valueC

    @staticmethod
    def backward(ctx, grad_indexC, grad_valueC):
        m, k, n = ctx.m, ctx.k, ctx.n
        # ``saved_tensors`` is the supported accessor; ``saved_variables`` is
        # a deprecated alias.
        indexA, valueA, indexB, valueB, indexC = ctx.saved_tensors

        grad_valueA = grad_valueB = None

        if ctx.needs_input_grad[1]:
            # dL/dA = dL/dC @ B^T, then gathered at A's nonzero positions so
            # the gradient lines up with valueA.
            indexB_T, valueB_T = transpose(indexB, valueB, k, n)
            grad_indexA, grad_valueA = mm(indexC, grad_valueC, indexB_T,
                                          valueB_T, m, n, k)
            grad_valueA = lift(grad_indexA, grad_valueA, indexA, k)

        if ctx.needs_input_grad[3]:
            # dL/dB = A^T @ dL/dC, then gathered at B's nonzero positions so
            # the gradient lines up with valueB.
            indexA_T, valueA_T = transpose(indexA, valueA, m, k)
            grad_indexB, grad_valueB = mm(indexA_T, valueA_T, indexC,
                                          grad_valueC, k, m, n)
            grad_valueB = lift(grad_indexB, grad_valueB, indexB, n)

        return None, grad_valueA, None, grad_valueB, None, None, None
rusty1s's avatar
rusty1s committed
54
55


rusty1s's avatar
rusty1s committed
56
# Functional entry point: ``spspmm(indexA, valueA, indexB, valueB, m, k, n)``
# returns ``(indexC, valueC)``; backward provides gradients for ``valueA`` and
# ``valueB`` only.
spspmm = SpSpMM.apply
rusty1s's avatar
rusty1s committed
57
58


rusty1s's avatar
rusty1s committed
59
60
def mm(indexA, valueA, indexB, valueB, m, k, n):
    """Multiply sparse ``A`` (``m x k``) by sparse ``B`` (``k x n``).

    Uses the ``spspmm_cuda`` extension when ``indexA`` lives on CUDA and
    SciPy otherwise. Returns ``(indexC, valueC)`` describing the ``m x n``
    product. On the SciPy path the result is brought into canonical order
    (rows sorted, columns sorted within each row, no duplicates), i.e. it is
    coalesced and can be fed straight back into another call.
    """
    assert valueA.dtype == valueB.dtype

    if indexA.is_cuda:
        return spspmm_cuda.spspmm(indexA, valueA, indexB, valueB, m, k, n)

    A = to_scipy(indexA, valueA, m, k)
    B = to_scipy(indexB, valueB, k, n)
    C = A.tocsr().dot(B.tocsr())
    # SciPy's sparse matmul does not guarantee sorted column indices within a
    # row; sum_duplicates() sorts them and merges duplicates in place (no-op
    # when C is already canonical).
    C.sum_duplicates()
    indexC, valueC = from_scipy(C.tocoo())

    return indexC, valueC
rusty1s's avatar
rusty1s committed
70
71


rusty1s's avatar
rusty1s committed
72
73
74
def to_scipy(index, value, m, n):
    """Build an ``m x n`` SciPy COO matrix from a ``(2, nnz)`` index tensor
    and its matching value tensor, both detached from the autograd graph."""
    row, col = index.detach()
    data = value.detach()
    shape = (m, n)
    return scipy.sparse.coo_matrix((data, (row, col)), shape)
rusty1s's avatar
rusty1s committed
75
76


rusty1s's avatar
rusty1s committed
77
def from_scipy(A):
    """Convert a SciPy COO matrix into an ``(index, value)`` pair.

    ``index`` is a ``(2, nnz)`` LongTensor of row/column coordinates;
    ``value`` is a tensor view over ``A.data`` (memory is shared).
    """
    coords = [from_numpy(A.row), from_numpy(A.col)]
    index = torch.stack(coords, dim=0).to(torch.long)
    return index, from_numpy(A.data)
rusty1s's avatar
rusty1s committed
81
82


rusty1s's avatar
rusty1s committed
83
def lift(indexA, valueA, indexB, n):  # pragma: no cover
    """Gather the values of sparse matrix A at the positions in ``indexB``.

    Both index tensors hold ``(row, col)`` pairs of a matrix with ``n``
    columns. Positions in ``indexB`` that are absent from ``indexA`` yield
    ``0``. Returns one value per column of ``indexB``, with the dtype and
    device of ``valueA``.

    A dense scratch buffer indexed by the linear position ``row * n + col``
    is used, so memory grows with the largest linear index involved.
    NOTE(review): if ``indexA`` contains duplicate positions, which value
    wins is unspecified — callers are assumed to pass coalesced indices.
    """
    flatA = indexA[0] * n + indexA[1]
    flatB = indexB[0] * n + indexB[1]

    if flatB.numel() == 0:
        return valueA.new_zeros(0)

    # The buffer must cover BOTH index sets: the source (e.g. a gradient
    # product) may have nonzeros beyond the target's largest position.
    size = int(flatB.max()) + 1
    if flatA.numel() > 0:
        size = max(size, int(flatA.max()) + 1)

    buffer = valueA.new_zeros(size)
    buffer[flatA] = valueA
    return buffer[flatB]