spspmm.py 3.71 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
import torch
rusty1s's avatar
rusty1s committed
2
3
4
from torch_sparse import transpose_matrix, to_scipy, from_scipy

import torch_sparse.spspmm_cpu
rusty1s's avatar
rusty1s committed
5

rusty1s's avatar
rusty1s committed
6
if torch.cuda.is_available():
7
    import torch_sparse.spspmm_cuda
rusty1s's avatar
rusty1s committed
8
9


rusty1s's avatar
rusty1s committed
10
def spspmm(indexA, valueA, indexB, valueB, m, k, n):
    """Matrix product of two sparse tensors. Both input sparse matrices need
    to be coalesced.

    Args:
        indexA (:class:`LongTensor`): The index tensor of first sparse matrix.
        valueA (:class:`Tensor`): The value tensor of first sparse matrix.
        indexB (:class:`LongTensor`): The index tensor of second sparse matrix.
        valueB (:class:`Tensor`): The value tensor of second sparse matrix.
        m (int): The first dimension of first sparse matrix.
        k (int): The second dimension of first sparse matrix and first
            dimension of second sparse matrix.
        n (int): The second dimension of second sparse matrix.

    :rtype: (:class:`LongTensor`, :class:`Tensor`)
    """
    # Delegate to the autograd Function; only the values carry gradients,
    # so the index tensor is detached before being handed back to the caller.
    out_index, out_value = SpSpMM.apply(indexA, valueA, indexB, valueB, m, k,
                                        n)
    return out_index.detach(), out_value
rusty1s's avatar
rusty1s committed
28

rusty1s's avatar
docs  
rusty1s committed
29

rusty1s's avatar
rusty1s committed
30
class SpSpMM(torch.autograd.Function):
    """Autograd-enabled sparse-sparse matrix multiplication C = A @ B.

    Gradients are computed for the value tensors only (inputs 1 and 3);
    the index tensors and the integer shape arguments receive ``None``.
    """

    @staticmethod
    def forward(ctx, indexA, valueA, indexB, valueB, m, k, n):
        # Compute the product and stash everything the backward pass needs.
        indexC, valueC = mm(indexA, valueA, indexB, valueB, m, k, n)
        ctx.m, ctx.k, ctx.n = m, k, n
        ctx.save_for_backward(indexA, valueA, indexB, valueB, indexC)
        return indexC, valueC

    @staticmethod
    def backward(ctx, grad_indexC, grad_valueC):
        m, k = ctx.m, ctx.k
        n = ctx.n
        indexA, valueA, indexB, valueB, indexC = ctx.saved_tensors

        grad_valueA = grad_valueB = None

        if not grad_valueC.is_cuda:
            # BUG FIX: this condition previously read
            # `needs_input_grad[1] or needs_input_grad[1]` (the same index
            # twice). `grad_valueC` must also be cloned when only the
            # gradient w.r.t. valueB (input index 3) is requested, because
            # the transpose below rewrites it.
            if ctx.needs_input_grad[1] or ctx.needs_input_grad[3]:
                grad_valueC = grad_valueC.clone()

            if ctx.needs_input_grad[1]:
                grad_valueA = torch_sparse.spspmm_cpu.spspmm_bw(
                    indexA, indexC.detach(), grad_valueC, indexB.detach(),
                    valueB, m, k)

            if ctx.needs_input_grad[3]:
                # grad_B is evaluated on the transposed system.
                indexA, valueA = transpose_matrix(indexA, valueA, m, k)
                indexC, grad_valueC = transpose_matrix(indexC, grad_valueC, m,
                                                       n)
                grad_valueB = torch_sparse.spspmm_cpu.spspmm_bw(
                    indexB, indexA.detach(), valueA, indexC.detach(),
                    grad_valueC, k, n)
        else:
            if ctx.needs_input_grad[1]:
                grad_valueA = torch_sparse.spspmm_cuda.spspmm_bw(
                    indexA, indexC.detach(), grad_valueC.clone(),
                    indexB.detach(), valueB, m, k)

            if ctx.needs_input_grad[3]:
                # grad_B = A^T @ grad_C, restricted to the sparsity pattern
                # of B via lift().
                indexA_T, valueA_T = transpose_matrix(indexA, valueA, m, k)
                grad_indexB, grad_valueB = mm(indexA_T, valueA_T, indexC,
                                              grad_valueC, k, m, n)
                grad_valueB = lift(grad_indexB, grad_valueB, indexB, n)

        # One slot per forward() argument:
        # indexA, valueA, indexB, valueB, m, k, n.
        return None, grad_valueA, None, grad_valueB, None, None, None
rusty1s's avatar
rusty1s committed
75
76


rusty1s's avatar
rusty1s committed
77
78
def mm(indexA, valueA, indexB, valueB, m, k, n):
    """Non-differentiable sparse-sparse product C = A @ B.

    Dispatches to the CUDA extension for GPU tensors and round-trips
    through SciPy on the CPU.
    """
    assert valueA.dtype == valueB.dtype

    if not indexA.is_cuda:
        # CPU path: multiply via SciPy's sparse matrices.
        A = to_scipy(indexA, valueA, m, k)
        B = to_scipy(indexB, valueB, k, n)
        # The coo -> csr -> coo round-trip forces a coalesced result.
        product = A.dot(B).tocoo().tocsr().tocoo()
        indexC, valueC = from_scipy(product)
        return indexC, valueC

    return torch_sparse.spspmm_cuda.spspmm(indexA, valueA, indexB, valueB, m,
                                           k, n)
rusty1s's avatar
rusty1s committed
89
90


rusty1s's avatar
rusty1s committed
91
def lift(indexA, valueA, indexB, n):  # pragma: no cover
    """Gather the values of A at the sparse coordinates of B.

    Positions of B that have no matching entry in A yield zero. ``n`` is
    the number of columns used to linearize (row, col) coordinates.
    """
    # Flatten 2-D coordinates into unique linear indices.
    flatA = indexA[0] * n + indexA[1]
    flatB = indexB[0] * n + indexB[1]

    # Dense scratch buffer big enough for every linear index that occurs.
    size = max(flatA.max().item(), flatB.max().item()) + 1
    dense = valueA.new_zeros(size)

    # Scatter A's values, then gather them back at B's positions.
    dense[flatA] = valueA
    return dense[flatB]