spspmm.py 3.68 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
import torch
2
from torch_sparse import transpose, to_scipy, from_scipy
rusty1s's avatar
rusty1s committed
3
4

import torch_sparse.spspmm_cpu
rusty1s's avatar
rusty1s committed
5

rusty1s's avatar
rusty1s committed
6
if torch.cuda.is_available():
7
    import torch_sparse.spspmm_cuda
rusty1s's avatar
rusty1s committed
8
9


rusty1s's avatar
rusty1s committed
10
def spspmm(indexA, valueA, indexB, valueB, m, k, n):
    """Multiply two sparse matrices given as (index, value) pairs.

    Both operands must already be coalesced.

    Args:
        indexA (:class:`LongTensor`): Index tensor of the left operand.
        valueA (:class:`Tensor`): Value tensor of the left operand.
        indexB (:class:`LongTensor`): Index tensor of the right operand.
        valueB (:class:`Tensor`): Value tensor of the right operand.
        m (int): Number of rows of the dense form of the left operand.
        k (int): Number of columns of the dense form of the left operand,
            which is also the number of rows of the right operand.
        n (int): Number of columns of the dense form of the right operand.

    :rtype: (:class:`LongTensor`, :class:`Tensor`)
    """
    result = SpSpMM.apply(indexA, valueA, indexB, valueB, m, k, n)
    indexC, valueC = result
    # The index tensor is handed back detached; the values are returned
    # as produced by the autograd function.
    return indexC.detach(), valueC
rusty1s's avatar
rusty1s committed
28

rusty1s's avatar
docs  
rusty1s committed
29

rusty1s's avatar
rusty1s committed
30
class SpSpMM(torch.autograd.Function):
    """Autograd function wrapping the sparse-sparse product ``C = A @ B``.

    Only the value tensors (``valueA``, ``valueB``) receive gradients; the
    index tensors and the integer sizes always get ``None``.
    """

    @staticmethod
    def forward(ctx, indexA, valueA, indexB, valueB, m, k, n):
        """Compute the product with :func:`mm` and stash inputs for backward.

        Returns:
            The ``(indexC, valueC)`` pair produced by :func:`mm`.
        """
        indexC, valueC = mm(indexA, valueA, indexB, valueB, m, k, n)
        ctx.m, ctx.k, ctx.n = m, k, n
        ctx.save_for_backward(indexA, valueA, indexB, valueB, indexC)
        return indexC, valueC

    @staticmethod
    def backward(ctx, grad_indexC, grad_valueC):
        """Return gradients for ``valueA`` and ``valueB`` where requested.

        Args:
            grad_indexC: Unused; the index output carries no gradient.
            grad_valueC: Gradient with respect to the output values.

        Returns:
            A 7-tuple matching the inputs of :meth:`forward`; only the
            entries for ``valueA`` (position 1) and ``valueB`` (position 3)
            may be non-``None``.
        """
        m, k, n = ctx.m, ctx.k, ctx.n
        indexA, valueA, indexB, valueB, indexC = ctx.saved_tensors

        grad_valueA = grad_valueB = None
        need_grad_A = ctx.needs_input_grad[1]
        need_grad_B = ctx.needs_input_grad[3]

        if not grad_valueC.is_cuda:
            # Clone whenever *either* gradient is requested. The original
            # condition tested index 1 twice, so the clone was skipped when
            # only valueB required a gradient.
            # NOTE(review): the clone presumably hands the CPU kernel its
            # own buffer -- confirm against the extension source.
            if need_grad_A or need_grad_B:
                grad_valueC = grad_valueC.clone()

            if need_grad_A:
                grad_valueA = torch_sparse.spspmm_cpu.spspmm_bw(
                    indexA, indexC.detach(), grad_valueC, indexB.detach(),
                    valueB, m, k)

            if need_grad_B:
                # Feed the kernel the transposes of A and of grad_C.
                # Separate names keep the saved tensors untouched.
                indexA_T, valueA_T = transpose(indexA, valueA, m, k)
                indexC_T, grad_valueC_T = transpose(indexC, grad_valueC,
                                                    m, n)
                grad_valueB = torch_sparse.spspmm_cpu.spspmm_bw(
                    indexB, indexA_T.detach(), valueA_T, indexC_T.detach(),
                    grad_valueC_T, k, n)
        else:
            if need_grad_A:
                grad_valueA = torch_sparse.spspmm_cuda.spspmm_bw(
                    indexA, indexC.detach(), grad_valueC.clone(),
                    indexB.detach(), valueB, m, k)

            if need_grad_B:
                # grad_B = A^T @ grad_C, then read off at B's positions.
                indexA_T, valueA_T = transpose(indexA, valueA, m, k)
                grad_indexB, grad_valueB = mm(indexA_T, valueA_T, indexC,
                                              grad_valueC, k, m, n)
                grad_valueB = lift(grad_indexB, grad_valueB, indexB, n)

        return None, grad_valueA, None, grad_valueB, None, None, None
rusty1s's avatar
rusty1s committed
74
75


rusty1s's avatar
rusty1s committed
76
77
def mm(indexA, valueA, indexB, valueB, m, k, n):
    """Sparse-sparse matrix product without autograd support.

    Dispatches to the CUDA extension when ``indexA`` lives on a CUDA
    device, and to SciPy otherwise.

    Args:
        indexA, valueA: Index/value pair of the left operand (``m x k``).
        indexB, valueB: Index/value pair of the right operand (``k x n``).
        m, k, n (int): Dense dimensions of the operands.

    Returns:
        The ``(indexC, valueC)`` pair describing the product.
    """
    assert valueA.dtype == valueB.dtype

    if not indexA.is_cuda:
        lhs = to_scipy(indexA, valueA, m, k)
        rhs = to_scipy(indexB, valueB, k, n)
        product = lhs.dot(rhs)
        # Round-trip COO -> CSR -> COO to force coalesce.
        product = product.tocoo().tocsr().tocoo()
        out_index, out_value = from_scipy(product)
        return out_index, out_value

    return torch_sparse.spspmm_cuda.spspmm(
        indexA, valueA, indexB, valueB, m, k, n)
rusty1s's avatar
rusty1s committed
88
89


rusty1s's avatar
rusty1s committed
90
def lift(indexA, valueA, indexB, n):  # pragma: no cover
    """Gather the values of sparse matrix A at the positions listed in B.

    Each ``(row, col)`` pair is linearised as ``row * n + col``. The values
    of A are scattered into a dense zero buffer at their linear positions
    and the buffer is then read at B's linear positions, so positions of B
    that are absent from A yield zero.

    Args:
        indexA (:class:`LongTensor`): ``2 x nnzA`` index tensor of A.
        valueA (:class:`Tensor`): Values of A, one per column of ``indexA``.
        indexB (:class:`LongTensor`): ``2 x nnzB`` index tensor of the
            target sparsity pattern.
        n (int): Multiplier used to linearise ``(row, col)`` pairs.

    Returns:
        :class:`Tensor`: One value per column of ``indexB``.
    """
    idxA = indexA[0] * n + indexA[1]
    idxB = indexB[0] * n + indexB[1]

    # ``max()`` raises on empty tensors. With nothing to scatter or nothing
    # to gather the result is all zeros of the target length.
    if idxA.numel() == 0 or idxB.numel() == 0:
        return valueA.new_zeros(idxB.numel())

    # The buffer must cover the largest linear position from either side.
    max_value = max(idxA.max().item(), idxB.max().item()) + 1
    valueB = valueA.new_zeros(max_value)

    valueB[idxA] = valueA
    valueB = valueB[idxB]

    return valueB