"sgl-kernel/pyproject_rocm.toml" did not exist on "c553e1604c4e662801dbe0e222c0c0c293afbcb7"
spspmm.py 4.03 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
import torch
2
from torch_sparse import transpose, to_scipy, from_scipy, coalesce
rusty1s's avatar
rusty1s committed
3
4

import torch_sparse.spspmm_cpu
rusty1s's avatar
rusty1s committed
5

rusty1s's avatar
rusty1s committed
6
if torch.cuda.is_available():
7
    import torch_sparse.spspmm_cuda
rusty1s's avatar
rusty1s committed
8
9


10
def spspmm(indexA, valueA, indexB, valueB, m, k, n, coalesced=False):
    """Matrix product of two sparse tensors. Both input sparse matrices need to
    be coalesced (use the :obj:`coalesced` attribute to force).

    Args:
        indexA (:class:`LongTensor`): The index tensor of first sparse matrix.
        valueA (:class:`Tensor`): The value tensor of first sparse matrix.
        indexB (:class:`LongTensor`): The index tensor of second sparse matrix.
        valueB (:class:`Tensor`): The value tensor of second sparse matrix.
        m (int): The first dimension of first corresponding dense matrix.
        k (int): The second dimension of first corresponding dense matrix and
            first dimension of second corresponding dense matrix.
        n (int): The second dimension of second corresponding dense matrix.
        coalesced (bool, optional): If set to :obj:`True`, will coalesce both
            input sparse matrices before multiplying. The flag is only
            honoured when :obj:`indexA` lives on a CUDA device; CPU inputs
            are passed through unchanged. (default: :obj:`False`)

    :rtype: (:class:`LongTensor`, :class:`Tensor`)
    """
    # Explicit coalescing is only performed for CUDA inputs, and only on
    # request.
    if indexA.is_cuda and coalesced:
        indexA, valueA = coalesce(indexA, valueA, m, k)
        indexB, valueB = coalesce(indexB, valueB, k, n)

    index, value = SpSpMM.apply(indexA, valueA, indexB, valueB, m, k, n)
    # The index tensor carries no gradient, so hand it back detached.
    return index.detach(), value
rusty1s's avatar
rusty1s committed
34

rusty1s's avatar
docs  
rusty1s committed
35

rusty1s's avatar
rusty1s committed
36
class SpSpMM(torch.autograd.Function):
    """Autograd function for the sparse-sparse matrix product ``C = A @ B``.

    Inputs and outputs are given as ``(index, value)`` pairs. Gradients are
    produced for :obj:`valueA` and :obj:`valueB` only; every other input
    receives :obj:`None`.
    """

    @staticmethod
    def forward(ctx, indexA, valueA, indexB, valueB, m, k, n):
        """Compute the product via :func:`mm` and stash what backward needs.

        Returns:
            The ``(indexC, valueC)`` pair of the product.
        """
        indexC, valueC = mm(indexA, valueA, indexB, valueB, m, k, n)
        ctx.m, ctx.k, ctx.n = m, k, n
        ctx.save_for_backward(indexA, valueA, indexB, valueB, indexC)
        return indexC, valueC

    @staticmethod
    def backward(ctx, grad_indexC, grad_valueC):
        """Compute the gradients with respect to :obj:`valueA` and
        :obj:`valueB`.

        Args:
            grad_indexC: Gradient slot of the index output; unused.
            grad_valueC: Gradient with respect to the values of ``C``.

        Returns:
            A 7-tuple matching the inputs of :meth:`forward`, with entries at
            positions 1 (:obj:`valueA`) and 3 (:obj:`valueB`) filled in when
            the respective input requires a gradient.
        """
        m, k, n = ctx.m, ctx.k, ctx.n
        indexA, valueA, indexB, valueB, indexC = ctx.saved_tensors

        # Positions 1 and 3 of the forward inputs are valueA and valueB.
        need_grad_A = ctx.needs_input_grad[1]
        need_grad_B = ctx.needs_input_grad[3]

        grad_valueA = grad_valueB = None

        if not grad_valueC.is_cuda:
            # Clone once if either gradient is requested. The original
            # condition tested index 1 twice, so the clone was skipped when
            # only valueB required a gradient.
            # NOTE(review): the clone presumably hands the extension its own
            # tensor - confirm against spspmm_cpu.
            if need_grad_A or need_grad_B:
                grad_valueC = grad_valueC.clone()

            if need_grad_A:
                grad_valueA = torch_sparse.spspmm_cpu.spspmm_bw(
                    indexA, indexC.detach(), grad_valueC, indexB.detach(),
                    valueB, m, k)

            if need_grad_B:
                # The same backward kernel is reused on the transposed
                # operands (A^T and grad_C^T) to obtain the valueB gradient.
                indexA, valueA = transpose(indexA, valueA, m, k)
                indexC, grad_valueC = transpose(indexC, grad_valueC, m, n)
                grad_valueB = torch_sparse.spspmm_cpu.spspmm_bw(
                    indexB, indexA.detach(), valueA, indexC.detach(),
                    grad_valueC, k, n)
        else:
            if need_grad_A:
                grad_valueA = torch_sparse.spspmm_cuda.spspmm_bw(
                    indexA, indexC.detach(), grad_valueC.clone(),
                    indexB.detach(), valueB, m, k)

            if need_grad_B:
                # grad_B = A^T @ grad_C, then read off at B's positions.
                indexA_T, valueA_T = transpose(indexA, valueA, m, k)
                grad_indexB, grad_valueB = mm(indexA_T, valueA_T, indexC,
                                              grad_valueC, k, m, n)
                grad_valueB = lift(grad_indexB, grad_valueB, indexB, n)

        return None, grad_valueA, None, grad_valueB, None, None, None
rusty1s's avatar
rusty1s committed
80
81


rusty1s's avatar
rusty1s committed
82
83
def mm(indexA, valueA, indexB, valueB, m, k, n):
    """Multiply two sparse matrices given as ``(index, value)`` pairs.

    CUDA inputs are dispatched to the ``spspmm_cuda`` extension; everything
    else goes through SciPy.

    Args:
        indexA, valueA: Index and value tensors of the ``m x k`` left operand.
        indexB, valueB: Index and value tensors of the ``k x n`` right operand.
        m, k, n (int): Dimensions of the operands.

    Returns:
        The ``(indexC, valueC)`` pair of the ``m x n`` product.

    Raises:
        AssertionError: If the two value tensors have different dtypes.
    """
    assert valueA.dtype == valueB.dtype, (
        f'value dtypes must match, got {valueA.dtype} and {valueB.dtype}')

    if indexA.is_cuda:
        return torch_sparse.spspmm_cuda.spspmm(indexA, valueA, indexB, valueB,
                                               m, k, n)

    lhs = to_scipy(indexA, valueA, m, k)
    rhs = to_scipy(indexB, valueB, k, n)
    # Round-tripping through CSR forces the result to be coalesced.
    product = lhs.dot(rhs).tocoo().tocsr().tocoo()
    indexC, valueC = from_scipy(product)
    return indexC, valueC
rusty1s's avatar
rusty1s committed
94
95


rusty1s's avatar
rusty1s committed
96
def lift(indexA, valueA, indexB, n):  # pragma: no cover
    """Read the values of sparse matrix ``A`` at the positions of ``indexB``.

    Both index tensors are flattened to ``row * n + col``. The values of
    ``A`` are scattered into a zero-initialised dense buffer and gathered
    back at the flattened positions of ``indexB``; positions of ``indexB``
    that are absent from ``indexA`` therefore yield zero.

    Args:
        indexA: Index tensor (rows in ``indexA[0]``, columns in
            ``indexA[1]``) of the source matrix.
        valueA: Values belonging to ``indexA``.
        indexB: Index tensor of the positions to read.
        n (int): Multiplier used to flatten ``(row, col)`` pairs.

    Returns:
        A tensor with one value per column of ``indexB``.
    """
    flat_src = indexA[0] * n + indexA[1]
    flat_dst = indexB[0] * n + indexB[1]

    # Nothing to read: ``max()`` on an empty tensor would raise.
    if flat_dst.numel() == 0:
        return valueA.new_zeros(0)

    upper = flat_dst.max().item()
    if flat_src.numel() > 0:
        upper = max(upper, flat_src.max().item())

    # Scatter A's values into a dense buffer, then gather at B's positions.
    dense = valueA.new_zeros(upper + 1)
    dense[flat_src] = valueA
    return dense[flat_dst]