spspmm.py 3.33 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
import torch
rusty1s's avatar
rusty1s committed
2
from torch import from_numpy
rusty1s's avatar
rusty1s committed
3
import numpy as np
rusty1s's avatar
rusty1s committed
4
import scipy.sparse
rusty1s's avatar
rusty1s committed
5
from torch_sparse import transpose
rusty1s's avatar
rusty1s committed
6

rusty1s's avatar
rusty1s committed
7
if torch.cuda.is_available():
8
    import torch_sparse.spspmm_cuda
rusty1s's avatar
rusty1s committed
9
10


rusty1s's avatar
rusty1s committed
11
def spspmm(indexA, valueA, indexB, valueB, m, k, n):
    """Matrix product of two sparse tensors. Both input sparse matrices need to
    be coalesced.

    Computes ``C = A @ B`` where ``A`` has shape ``(m, k)`` and ``B`` has
    shape ``(k, n)``. Gradients are propagated to ``valueA`` and ``valueB``
    only; the returned index tensor is detached from the autograd graph.

    Args:
        indexA (:class:`LongTensor`): The index tensor of first sparse matrix.
        valueA (:class:`Tensor`): The value tensor of first sparse matrix.
        indexB (:class:`LongTensor`): The index tensor of second sparse matrix.
        valueB (:class:`Tensor`): The value tensor of second sparse matrix.
        m (int): The first dimension of first sparse matrix.
        k (int): The second dimension of first sparse matrix and first
            dimension of second sparse matrix.
        n (int): The second dimension of second sparse matrix.

    :rtype: (:class:`LongTensor`, :class:`Tensor`)
    """
    indexC, valueC = SpSpMM.apply(indexA, valueA, indexB, valueB, m, k, n)
    # The sparsity pattern is integral and never needs a gradient.
    indexC = indexC.detach()
    return indexC, valueC
rusty1s's avatar
rusty1s committed
29

rusty1s's avatar
docs  
rusty1s committed
30

rusty1s's avatar
rusty1s committed
31
class SpSpMM(torch.autograd.Function):
    """Autograd function for the sparse-sparse product ``C = A @ B``.

    ``A`` is ``(m, k)`` and ``B`` is ``(k, n)``, each given as an
    ``(index, value)`` pair. Gradients are returned only for the two value
    tensors; index tensors and dimensions receive ``None``.
    """

    @staticmethod
    def forward(ctx, indexA, valueA, indexB, valueB, m, k, n):
        # Dimensions are plain ints, so they go on ctx directly instead of
        # through save_for_backward.
        ctx.m, ctx.k, ctx.n = m, k, n
        indexC, valueC = mm(indexA, valueA, indexB, valueB, m, k, n)
        ctx.save_for_backward(indexA, valueA, indexB, valueB, indexC)
        return indexC, valueC

    @staticmethod
    def backward(ctx, grad_indexC, grad_valueC):
        indexA, valueA, indexB, valueB, indexC = ctx.saved_tensors
        m, k, n = ctx.m, ctx.k, ctx.n
        wants_grad_A = ctx.needs_input_grad[1]
        wants_grad_B = ctx.needs_input_grad[3]

        grad_valueA = None
        if wants_grad_A:
            # dL/dA = dL/dC @ B^T, then restricted to A's sparsity pattern.
            index_Bt, value_Bt = transpose(indexB, valueB, k, n)
            index_dA, value_dA = mm(indexC, grad_valueC, index_Bt, value_Bt,
                                    m, n, k)
            grad_valueA = lift(index_dA, value_dA, indexA, k)

        grad_valueB = None
        if wants_grad_B:
            # dL/dB = A^T @ dL/dC, then restricted to B's sparsity pattern.
            index_At, value_At = transpose(indexA, valueA, m, k)
            index_dB, value_dB = mm(index_At, value_At, indexC, grad_valueC,
                                    k, m, n)
            grad_valueB = lift(index_dB, value_dB, indexB, n)

        # One gradient slot per forward input (excluding ctx).
        return None, grad_valueA, None, grad_valueB, None, None, None
rusty1s's avatar
rusty1s committed
59
60


rusty1s's avatar
rusty1s committed
61
62
def mm(indexA, valueA, indexB, valueB, m, k, n):
    """Multiply sparse matrix A of shape (m, k) by sparse matrix B of shape (k, n).

    CUDA inputs are dispatched to the ``torch_sparse.spspmm_cuda`` extension.
    CPU inputs are converted to SciPy and multiplied in CSR format.

    Args:
        indexA (LongTensor): index tensor of A.
        valueA (Tensor): value tensor of A.
        indexB (LongTensor): index tensor of B.
        valueB (Tensor): value tensor of B; must have the same dtype as valueA.
        m (int): number of rows of A.
        k (int): number of columns of A / rows of B.
        n (int): number of columns of B.

    Returns:
        (LongTensor, Tensor): index and value tensors of the product. On the
        CPU path, entries are ordered row by row with sorted column indices.
    """
    assert valueA.dtype == valueB.dtype, \
        'valueA and valueB must have the same dtype'

    if indexA.is_cuda:
        return torch_sparse.spspmm_cuda.spspmm(indexA, valueA, indexB, valueB,
                                               m, k, n)

    A = to_scipy(indexA, valueA, m, k).tocsr()
    B = to_scipy(indexB, valueB, k, n).tocsr()
    C = A.dot(B)
    # SciPy's CSR x CSR product does not guarantee sorted column indices
    # within a row. Sort them in place so the COO result is row-major and
    # coalesced, which spspmm requires of its inputs (e.g. when products are
    # chained).
    C.sort_indices()
    indexC, valueC = from_scipy(C.tocoo())

    return indexC, valueC
rusty1s's avatar
rusty1s committed
73
74


rusty1s's avatar
rusty1s committed
75
76
77
def to_scipy(index, value, m, n):
    """Convert a sparse ``(index, value)`` pair into a SciPy COO matrix.

    Args:
        index (LongTensor): index tensor whose two rows hold the row and
            column indices of the non-zero entries.
        value (Tensor): value tensor of the non-zero entries.
        m (int): number of rows of the matrix.
        n (int): number of columns of the matrix.

    :rtype: :class:`scipy.sparse.coo_matrix`
    """
    # Convert explicitly: NumPy cannot view tensors that require grad or live
    # on a CUDA device, so detach and move to host memory first. For CPU
    # tensors, .cpu() is a no-op.
    row, col = index.detach().cpu().numpy()
    data = value.detach().cpu().numpy()
    return scipy.sparse.coo_matrix((data, (row, col)), (m, n))
rusty1s's avatar
rusty1s committed
78
79


rusty1s's avatar
rusty1s committed
80
def from_scipy(A):
    """Convert a SciPy COO matrix into a ``(index, value)`` tensor pair.

    Row and column indices are cast to int64 and stacked into a ``2 x nnz``
    index tensor. The value tensor is built from ``A.data`` via
    ``from_numpy``, so it keeps the array's dtype.
    """
    rows = from_numpy(A.row.astype(np.int64))
    cols = from_numpy(A.col.astype(np.int64))
    values = from_numpy(A.data)
    return torch.stack([rows, cols], dim=0), values
rusty1s's avatar
rusty1s committed
85
86


rusty1s's avatar
rusty1s committed
87
def lift(indexA, valueA, indexB, n):  # pragma: no cover
    """Gather the values of sparse matrix A at the positions listed by indexB.

    Each ``(row, col)`` pair is flattened to ``row * n + col``. Positions in
    ``indexB`` that do not occur in ``indexA`` receive zero. Entries of A whose
    position is not in ``indexB`` are dropped.

    Args:
        indexA (LongTensor): index tensor of the source matrix (row indices
            in ``indexA[0]``, column indices in ``indexA[1]``).
        valueA (Tensor): values of the source matrix, one per entry (assumed
            1-D).
        indexB (LongTensor): index tensor of the target sparsity pattern.
        n (int): number of columns used to flatten positions.

    Returns:
        Tensor: one value per column of ``indexB``, with the dtype and device
        of ``valueA``.
    """
    # NOTE(review): assumes indexA holds no duplicate positions; duplicate
    # scatter targets would make the assignment below order-dependent.
    idxA = indexA[0] * n + indexA[1]
    idxB = indexB[0] * n + indexB[1]

    # Dense scratch buffer addressed by flattened position. `.max()` raises
    # on empty tensors, so empty index sets do not contribute to its size.
    size = 0
    for idx in (idxA, idxB):
        if idx.numel() > 0:
            size = max(size, idx.max().item() + 1)

    buffer = valueA.new_zeros(size)
    buffer[idxA] = valueA
    return buffer[idxB]