spspmm.py 3.15 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
import torch
rusty1s's avatar
rusty1s committed
2
from torch import from_numpy
rusty1s's avatar
rusty1s committed
3
import scipy.sparse
rusty1s's avatar
rusty1s committed
4
from torch_sparse import transpose
rusty1s's avatar
rusty1s committed
5

rusty1s's avatar
rusty1s committed
6
if torch.cuda.is_available():
rusty1s's avatar
rusty1s committed
7
    import spspmm_cuda
rusty1s's avatar
rusty1s committed
8
9


rusty1s's avatar
rusty1s committed
10
def spspmm(indexA, valueA, indexB, valueB, m, k, n):
    """Multiply two sparse matrices given as index/value tensor pairs.

    The inputs are expected to be coalesced. The call is routed through
    :class:`SpSpMM`, so the result takes part in autograd.

    Args:
        indexA (:class:`LongTensor`): Index tensor of the left sparse matrix.
        valueA (:class:`Tensor`): Value tensor of the left sparse matrix.
        indexB (:class:`LongTensor`): Index tensor of the right sparse matrix.
        valueB (:class:`Tensor`): Value tensor of the right sparse matrix.
        m (int): Number of rows of the left sparse matrix.
        k (int): Number of columns of the left sparse matrix, which is also
            the number of rows of the right sparse matrix.
        n (int): Number of columns of the right sparse matrix.

    :rtype: (:class:`LongTensor`, :class:`Tensor`)
    """
    product = SpSpMM.apply(indexA, valueA, indexB, valueB, m, k, n)
    return product

rusty1s's avatar
docs  
rusty1s committed
28

rusty1s's avatar
rusty1s committed
29
class SpSpMM(torch.autograd.Function):
    """Autograd function computing ``C = A @ B`` for sparse ``A`` and ``B``.

    Matrices are passed as (index, value) tensor pairs together with their
    dimensions ``m``, ``k`` and ``n``. Gradients are produced only for the two
    value tensors; every other input receives ``None``.
    """

    @staticmethod
    def forward(ctx, indexA, valueA, indexB, valueB, m, k, n):
        """Compute the product and stash what the backward pass needs."""
        product = mm(indexA, valueA, indexB, valueB, m, k, n)
        indexC, valueC = product

        # Plain Python sizes live on ctx; tensors go through save_for_backward.
        ctx.m = m
        ctx.k = k
        ctx.n = n
        ctx.save_for_backward(indexA, valueA, indexB, valueB, indexC)

        return indexC, valueC

    @staticmethod
    def backward(ctx, grad_indexC, grad_valueC):
        """Return gradients w.r.t. ``valueA`` and ``valueB`` when requested."""
        indexA, valueA, indexB, valueB, indexC = ctx.saved_tensors
        m, k, n = ctx.m, ctx.k, ctx.n

        wants_grad_A = ctx.needs_input_grad[1]
        wants_grad_B = ctx.needs_input_grad[3]

        grad_valueA = None
        grad_valueB = None

        if wants_grad_A:
            # grad_A = grad_C @ B^T, read back at the positions in indexA.
            indexB_T, valueB_T = transpose(indexB, valueB, k, n)
            grad_indexA, dense_grad_A = mm(indexC, grad_valueC, indexB_T,
                                           valueB_T, m, n, k)
            grad_valueA = lift(grad_indexA, dense_grad_A, indexA, k)

        if wants_grad_B:
            # grad_B = A^T @ grad_C, read back at the positions in indexB.
            indexA_T, valueA_T = transpose(indexA, valueA, m, k)
            grad_indexB, dense_grad_B = mm(indexA_T, valueA_T, indexC,
                                           grad_valueC, k, m, n)
            grad_valueB = lift(grad_indexB, dense_grad_B, indexB, n)

        # One slot per forward input: (indexA, valueA, indexB, valueB, m, k, n).
        return None, grad_valueA, None, grad_valueB, None, None, None
rusty1s's avatar
rusty1s committed
57
58


rusty1s's avatar
rusty1s committed
59
60
def mm(indexA, valueA, indexB, valueB, m, k, n):
    """Sparse-sparse matrix product without autograd support.

    Uses the ``spspmm_cuda`` extension when ``indexA`` lives on a CUDA
    device. Otherwise the operands are converted to SciPy matrices,
    multiplied in CSR form and converted back to an (index, value) pair.

    Args:
        indexA, valueA: Index and value tensors of the left matrix (m x k).
        indexB, valueB: Index and value tensors of the right matrix (k x n).
        m, k, n (int): Matrix dimensions.

    Returns:
        Tuple ``(indexC, valueC)`` describing the product.
    """
    assert valueA.dtype == valueB.dtype

    if not indexA.is_cuda:
        lhs = to_scipy(indexA, valueA, m, k).tocsr()
        rhs = to_scipy(indexB, valueB, k, n).tocsr()
        indexC, valueC = from_scipy(lhs.dot(rhs).tocoo())
        return indexC, valueC

    return spspmm_cuda.spspmm(indexA, valueA, indexB, valueB, m, k, n)
rusty1s's avatar
rusty1s committed
70
71


rusty1s's avatar
rusty1s committed
72
73
74
def to_scipy(index, value, m, n):
    """Build a SciPy COO matrix of shape ``(m, n)`` from index/value tensors.

    Both tensors are detached first, so the resulting matrix carries no
    autograd history. ``index`` is unpacked along its first dimension into a
    row tensor and a column tensor.
    """
    row, col = index.detach()
    data = value.detach()
    return scipy.sparse.coo_matrix((data, (row, col)), shape=(m, n))
rusty1s's avatar
rusty1s committed
75
76


rusty1s's avatar
rusty1s committed
77
def from_scipy(A):
    """Convert a SciPy COO matrix into an (index, value) tensor pair.

    Returns:
        ``index``: ``LongTensor`` with the row indices stacked above the
        column indices; ``value``: tensor wrapping ``A.data``.
    """
    coords = [from_numpy(A.row), from_numpy(A.col)]
    index = torch.stack(coords, dim=0).long()
    value = from_numpy(A.data)
    return index, value
rusty1s's avatar
rusty1s committed
81
82


rusty1s's avatar
rusty1s committed
83
def lift(indexA, valueA, indexB, n):  # pragma: no cover
    """Read the values of sparse matrix A at the positions listed in indexB.

    Each (row, col) pair is flattened to the linear index ``row * n + col``.
    The values of A are scattered into a dense buffer at those linear
    indices, and the buffer is then gathered at the linear indices of B.
    Positions of B that are absent from A yield zero.

    Args:
        indexA: Index tensor of A; row ``0`` holds row indices and row ``1``
            holds column indices.
        valueA: Values of A, one per column of ``indexA``.
        indexB: Index tensor of the positions to read, laid out like
            ``indexA``.
        n (int): Multiplier used to flatten (row, col) pairs; callers in
            this file pass the number of matrix columns.

    Returns:
        Tensor with one entry per position in ``indexB``, with the dtype and
        device of ``valueA``.
    """
    idxA = indexA[0] * n + indexA[1]
    idxB = indexB[0] * n + indexB[1]

    # ``Tensor.max()`` raises on empty tensors. With nothing to scatter or
    # nothing to gather, the result is all zeros (possibly of length zero).
    if idxA.numel() == 0 or idxB.numel() == 0:
        return valueA.new_zeros(idxB.numel())

    # NOTE: the buffer is dense, with length equal to the largest linear
    # index + 1.
    max_value = max(idxA.max().item(), idxB.max().item()) + 1
    dense = valueA.new_zeros(max_value)

    dense[idxA] = valueA
    return dense[idxB]