spspmm.py 3.26 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
import torch
rusty1s's avatar
rusty1s committed
2
from torch import from_numpy
rusty1s's avatar
rusty1s committed
3
import scipy.sparse
rusty1s's avatar
rusty1s committed
4
from torch_sparse import transpose
rusty1s's avatar
rusty1s committed
5

rusty1s's avatar
rusty1s committed
6
if torch.cuda.is_available():
7
    import torch_sparse.spspmm_cuda
rusty1s's avatar
rusty1s committed
8
9


rusty1s's avatar
rusty1s committed
10
def spspmm(indexA, valueA, indexB, valueB, m, k, n):
    """Multiply two sparse matrices given in (index, value) form.

    Computes ``C = A @ B`` where ``A`` is ``m x k`` and ``B`` is ``k x n``.
    Both input sparse matrices need to be coalesced.

    Args:
        indexA (:class:`LongTensor`): The index tensor of the left matrix.
        valueA (:class:`Tensor`): The value tensor of the left matrix.
        indexB (:class:`LongTensor`): The index tensor of the right matrix.
        valueB (:class:`Tensor`): The value tensor of the right matrix.
        m (int): Number of rows of the left matrix.
        k (int): Number of columns of the left matrix, which is also the
            number of rows of the right matrix.
        n (int): Number of columns of the right matrix.

    Gradients are propagated to ``valueA`` and ``valueB`` only. The
    returned index tensor is detached from the autograd graph.

    :rtype: (:class:`LongTensor`, :class:`Tensor`)
    """
    indexC, valueC = SpSpMM.apply(indexA, valueA, indexB, valueB, m, k, n)
    # Index tensors carry no gradient, so cut them off from the graph.
    indexC = indexC.detach()
    return indexC, valueC
rusty1s's avatar
rusty1s committed
28

rusty1s's avatar
docs  
rusty1s committed
29

rusty1s's avatar
rusty1s committed
30
class SpSpMM(torch.autograd.Function):
    """Autograd function for ``C = A @ B`` with two sparse operands.

    Only the value tensors of ``A`` and ``B`` receive gradients. Index
    tensors and the integer sizes ``m``, ``k`` and ``n`` do not.
    """

    @staticmethod
    def forward(ctx, indexA, valueA, indexB, valueB, m, k, n):
        # The dense sizes are needed in backward to build the transposes.
        ctx.m = m
        ctx.k = k
        ctx.n = n
        indexC, valueC = mm(indexA, valueA, indexB, valueB, m, k, n)
        ctx.save_for_backward(indexA, valueA, indexB, valueB, indexC)
        return indexC, valueC

    @staticmethod
    def backward(ctx, grad_indexC, grad_valueC):
        """Propagate ``grad_valueC`` to ``valueA`` and ``valueB``.

        With ``C = A @ B``, ``dL/dA = dL/dC @ B^T`` and
        ``dL/dB = A^T @ dL/dC``. Each product is computed sparsely and then
        restricted to the sparsity pattern of the corresponding input via
        ``lift``. ``grad_indexC`` is ignored.
        """
        indexA, valueA, indexB, valueB, indexC = ctx.saved_tensors
        m, k, n = ctx.m, ctx.k, ctx.n
        wants_A = ctx.needs_input_grad[1]
        wants_B = ctx.needs_input_grad[3]

        grad_valueA = None
        if wants_A:
            # (m x n) @ (n x k) -> (m x k), matching the shape of A.
            indexBt, valueBt = transpose(indexB, valueB, k, n)
            prod_index, prod_value = mm(indexC, grad_valueC, indexBt,
                                        valueBt, m, n, k)
            grad_valueA = lift(prod_index, prod_value, indexA, k)

        grad_valueB = None
        if wants_B:
            # (k x m) @ (m x n) -> (k x n), matching the shape of B.
            indexAt, valueAt = transpose(indexA, valueA, m, k)
            prod_index, prod_value = mm(indexAt, valueAt, indexC,
                                        grad_valueC, k, m, n)
            grad_valueB = lift(prod_index, prod_value, indexB, n)

        # One entry per forward input: indexA, valueA, indexB, valueB, m, k, n.
        return None, grad_valueA, None, grad_valueB, None, None, None
rusty1s's avatar
rusty1s committed
58
59


rusty1s's avatar
rusty1s committed
60
61
def mm(indexA, valueA, indexB, valueB, m, k, n):
    """Compute the sparse product ``(m x k) @ (k x n)`` in (index, value) form.

    CUDA inputs are dispatched to ``torch_sparse.spspmm_cuda``. All other
    inputs go through SciPy.

    Returns:
        (LongTensor, Tensor): the index and value tensors of the product.
        On the SciPy path the entries are ordered by (row, col), i.e. the
        result is coalesced, so it satisfies ``spspmm``'s input
        precondition and can be chained.
    """
    assert valueA.dtype == valueB.dtype

    if indexA.is_cuda:
        return torch_sparse.spspmm_cuda.spspmm(indexA, valueA, indexB, valueB,
                                               m, k, n)

    A = to_scipy(indexA, valueA, m, k)
    B = to_scipy(indexB, valueB, k, n)
    C = A.tocsr().dot(B.tocsr())
    # SciPy's CSR x CSR kernel leaves the column indices inside each row
    # unsorted. Sort them in place so the COO view below is ordered
    # row-major, i.e. coalesced.
    C.sort_indices()
    indexC, valueC = from_scipy(C.tocoo())

    return indexC, valueC
rusty1s's avatar
rusty1s committed
72
73


rusty1s's avatar
rusty1s committed
74
75
76
def to_scipy(index, value, m, n):
    """Build an ``m x n`` SciPy COO matrix from an (index, value) pair.

    ``index[0]`` holds the row coordinates and ``index[1]`` the column
    coordinates. Both tensors are detached first, so tensors that require
    grad can be handed to SciPy.
    """
    rows, cols = index.detach()
    entries = value.detach()
    return scipy.sparse.coo_matrix((entries, (rows, cols)), shape=(m, n))
rusty1s's avatar
rusty1s committed
77
78


rusty1s's avatar
rusty1s committed
79
def from_scipy(A):
    """Convert a SciPy COO matrix into an ``(index, value)`` tensor pair.

    Returns:
        (LongTensor, Tensor): ``index`` stacks the row and column coordinates
        along dim 0 and is cast to ``torch.long``. ``value`` is built with
        ``from_numpy`` and so shares memory with ``A.data``.
    """
    coords = [from_numpy(A.row), from_numpy(A.col)]
    index = torch.stack(coords, dim=0).to(torch.long)
    value = from_numpy(A.data)
    return index, value
rusty1s's avatar
rusty1s committed
83
84


rusty1s's avatar
rusty1s committed
85
def lift(indexA, valueA, indexB, n):  # pragma: no cover
    """Read sparse matrix A's values at the coordinates listed in ``indexB``.

    For each coordinate in ``indexB`` the result holds A's value at that
    coordinate, or zero if A has no entry there. Entries of A that are
    outside ``indexB`` are dropped. The backward pass uses this to restrict
    a sparse gradient to an input's sparsity pattern.

    Args:
        indexA (LongTensor): coordinates of A (``indexA[0]`` rows,
            ``indexA[1]`` columns).
        valueA (Tensor): values of A, aligned with ``indexA``.
        indexB (LongTensor): target coordinates, laid out like ``indexA``.
        n (int): number of columns, used to linearize (row, col) pairs.

    Returns:
        Tensor: one value per coordinate in ``indexB``, with the same dtype
        and device as ``valueA``.
    """
    # Linearize (row, col) -> row * n + col so that each coordinate is a
    # single integer key.
    idxA = indexA[0] * n + indexA[1]
    idxB = indexB[0] * n + indexB[1]

    # Compact both key sets into [0, #unique keys) instead of allocating a
    # dense buffer of size max(key) + 1, which can be up to m * n elements.
    # This also handles empty inputs, where ``.max()`` would raise.
    keys, inverse = torch.unique(torch.cat([idxA, idxB]), return_inverse=True)
    numA = idxA.numel()
    slotA, slotB = inverse[:numA], inverse[numA:]

    buf = valueA.new_zeros(keys.numel())
    buf[slotA] = valueA
    return buf[slotB]