spspmm.py

import torch
from torch_sparse import transpose, to_scipy, from_scipy

import torch_sparse.spspmm_cpu

if torch.cuda.is_available():
    import torch_sparse.spspmm_cuda


def spspmm(indexA, valueA, indexB, valueB, m, k, n):
    """Matrix product of two sparse tensors. Both input sparse matrices need to
    be coalesced.

    Args:
        indexA (:class:`LongTensor`): The index tensor of the first sparse
            matrix.
        valueA (:class:`Tensor`): The value tensor of the first sparse matrix.
        indexB (:class:`LongTensor`): The index tensor of the second sparse
            matrix.
        valueB (:class:`Tensor`): The value tensor of the second sparse
            matrix.
        m (int): The first dimension of the first sparse matrix.
        k (int): The second dimension of the first sparse matrix and the
            first dimension of the second sparse matrix.
        n (int): The second dimension of the second sparse matrix.

    :rtype: (:class:`LongTensor`, :class:`Tensor`)
    """
    index, value = SpSpMM.apply(indexA, valueA, indexB, valueB, m, k, n)
    return index.detach(), value
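
# A minimal usage sketch of the function above, with illustrative values:
# multiply a coalesced (3 x 3) COO matrix A by a coalesced (3 x 2) matrix B.
#
#   indexA = torch.tensor([[0, 0, 1, 2, 2], [1, 2, 0, 0, 1]])
#   valueA = torch.tensor([1., 2., 3., 4., 5.])
#   indexB = torch.tensor([[0, 2], [1, 0]])
#   valueB = torch.tensor([2., 4.])
#   indexC, valueC = spspmm(indexA, valueA, indexB, valueB, 3, 3, 2)
#   # indexC, valueC hold the coalesced COO representation of A @ B.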


class SpSpMM(torch.autograd.Function):
    @staticmethod
    def forward(ctx, indexA, valueA, indexB, valueB, m, k, n):
        indexC, valueC = mm(indexA, valueA, indexB, valueB, m, k, n)
        ctx.m, ctx.k, ctx.n = m, k, n
        ctx.save_for_backward(indexA, valueA, indexB, valueB, indexC)
        return indexC, valueC

    @staticmethod
    def backward(ctx, grad_indexC, grad_valueC):
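        # For C = A @ B, the chain rule gives dA = dC @ B^T and dB = A^T @ dC.
        # Only the stored non-zero entries of A and B receive gradients, so
        # both products are evaluated solely at the sparsity patterns given
        # by indexA and indexB.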
        m, k, n = ctx.m, ctx.k, ctx.n
        indexA, valueA, indexB, valueB, indexC = ctx.saved_tensors

        grad_valueA = grad_valueB = None

        if not grad_valueC.is_cuda:
            if ctx.needs_input_grad[1] or ctx.needs_input_grad[3]:
                grad_valueC = grad_valueC.clone()

            if ctx.needs_input_grad[1]:
                grad_valueA = torch_sparse.spspmm_cpu.spspmm_bw(
                    indexA, indexC.detach(), grad_valueC, indexB.detach(),
                    valueB, m, k)

            if ctx.needs_input_grad[3]:
                indexA, valueA = transpose(indexA, valueA, m, k)
                indexC, grad_valueC = transpose(indexC, grad_valueC, m, n)
                grad_valueB = torch_sparse.spspmm_cpu.spspmm_bw(
                    indexB, indexA.detach(), valueA, indexC.detach(),
                    grad_valueC, k, n)
        else:
            if ctx.needs_input_grad[1]:
                grad_valueA = torch_sparse.spspmm_cuda.spspmm_bw(
                    indexA, indexC.detach(), grad_valueC.clone(),
                    indexB.detach(), valueB, m, k)

            if ctx.needs_input_grad[3]:
                indexA_T, valueA_T = transpose(indexA, valueA, m, k)
                grad_indexB, grad_valueB = mm(indexA_T, valueA_T, indexC,
                                              grad_valueC, k, m, n)
                grad_valueB = lift(grad_indexB, grad_valueB, indexB, n)

        return None, grad_valueA, None, grad_valueB, None, None, None


def mm(indexA, valueA, indexB, valueB, m, k, n):
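    # Helper that dispatches to the CUDA kernel when the inputs live on the
    # GPU and otherwise falls back to a scipy-based sparse matrix product.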
    assert valueA.dtype == valueB.dtype

    if indexA.is_cuda:
        return torch_sparse.spspmm_cuda.spspmm(indexA, valueA, indexB, valueB,
                                               m, k, n)

    A = to_scipy(indexA, valueA, m, k)
    B = to_scipy(indexB, valueB, k, n)
    C = A.dot(B).tocoo().tocsr().tocoo()  # Force coalesce.
    indexC, valueC = from_scipy(C)
    return indexC, valueC


def lift(indexA, valueA, indexB, n):  # pragma: no cover
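    # Re-indexes the values of sparse matrix A onto the sparsity pattern of
    # sparse matrix B: entries are matched via flattened row * n + col
    # indices, and positions present in B but missing in A become zero.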
    idxA = indexA[0] * n + indexA[1]
    idxB = indexB[0] * n + indexB[1]

    max_value = max(idxA.max().item(), idxB.max().item()) + 1
    valueB = valueA.new_zeros(max_value)

    valueB[idxA] = valueA
    valueB = valueB[idxB]

    return valueB
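
# Illustrative sketch of the scatter/gather in ``lift``, with hypothetical
# inputs and n = 3:
#   indexA = [[0, 1], [2, 0]], valueA = [7., 9.]  ->  idxA = [2, 3]
#   indexB = [[0, 1, 1], [2, 0, 2]]               ->  idxB = [2, 3, 5]
# The returned value tensor is [7., 9., 0.], i.e. A's entries expressed on
# B's sparsity pattern.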