import torch
import scipy.sparse
from torch_sparse import spmm_cpu
from torch_scatter import scatter_add

try:
    from torch_sparse import spmm_cuda
except ImportError:
    spmm_cuda = None

try:
    from torch_sparse import spspmm_cuda
except ImportError:
    spspmm_cuda = None


def spmm(is_cuda):
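    """Return the CUDA or CPU kernel bindings matching the input device."""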
    return spmm_cuda if is_cuda else spmm_cpu


class SPMM(torch.autograd.Function):
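    """Sparse-dense matrix multiplication with custom forward and backward
    kernels. The sparse matrix is given in CSR format (`rowptr`, `col`,
    `value`), and the products within each output row are reduced according
    to `reduce`. For 'min'/'max', the index of the winning non-zero entry
    (`arg_out`) is returned as well and reused in the backward pass."""
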
    @staticmethod
    def forward(ctx, row, rowptr, col, value, mat, rowcount, colptr, csr2csc,
                reduce):
        out, arg_out = spmm(mat.is_cuda).spmm(rowptr, col, value, mat, reduce)

        ctx.reduce = reduce
        ctx.save_for_backward(row, rowptr, col, value, mat, rowcount, colptr,
                              csr2csc, arg_out)

        if reduce in ['min', 'max']:
            ctx.mark_non_differentiable(arg_out)
            return out, arg_out
        else:
            return out

    @staticmethod
    def backward(ctx, grad_out, *args):
        (row, rowptr, col, value, mat, rowcount, colptr, csr2csc,
         arg_out) = ctx.saved_tensors

        invalid_arg_mask = arg_out_ind = None
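        # For 'min'/'max', `arg_out` entries equal to the number of non-zeros
        # correspond to output rows without any non-zero entries; mask them so
        # that they do not receive a gradient.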
        if ctx.reduce in ['min', 'max'] and (ctx.needs_input_grad[3]
                                             or ctx.needs_input_grad[4]):
            invalid_arg_mask = arg_out == col.size(0)
            arg_out_ind = arg_out.masked_fill(invalid_arg_mask, -1)

        grad_value = None
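        # Gradient w.r.t. the non-zero values: 'sum'/'add'/'mean' are handled
        # by a dedicated kernel, while 'min'/'max' route `grad_out` back to
        # the winning entries via `scatter_add`.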
        if ctx.needs_input_grad[3]:
            if ctx.reduce in ['sum', 'add', 'mean']:
                grad_value = spmm(grad_out.is_cuda).spmm_val_bw(
                    row, rowptr, col, mat, grad_out, ctx.reduce)

            elif ctx.reduce in ['min', 'max']:
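                # Gather the dense entries that won the reduction, weight
                # them by `grad_out`, and scatter the result onto the
                # corresponding non-zero entries.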
                col = col[arg_out_ind.flatten()].view_as(arg_out)
                out = mat.gather(-2, col).mul_(grad_out)
                out.masked_fill_(invalid_arg_mask, 0)
                grad_value = scatter_add(out.flatten(), arg_out.flatten(),
                                         dim=0, dim_size=value.numel() + 1)
                grad_value = grad_value[:-1]

        grad_mat = None
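        # Gradient w.r.t. the dense matrix: for 'sum'/'add'/'mean', multiply
        # the transposed sparse matrix with `grad_out`, using the CSC data
        # (`colptr`, `csr2csc`) as the CSR representation of the transpose;
        # for 'min'/'max', scatter `grad_out` to the rows that won.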
        if ctx.needs_input_grad[4]:
            if ctx.reduce in ['sum', 'add']:
                value = value[csr2csc] if value is not None else value
                grad_mat, _ = spmm(grad_out.is_cuda).spmm(
                    colptr, row[csr2csc], value, grad_out, 'sum')

            elif ctx.reduce == 'mean':
                count = rowcount[row].to(mat.dtype).clamp_(min=1)
                value = count.pow_(-1) if value is None else value / count
                row = row[csr2csc]
                value = value[csr2csc]
                grad_mat, _ = spmm(grad_out.is_cuda).spmm(
                    colptr, row, value, grad_out, 'sum')

            elif ctx.reduce in ['min', 'max']:
                if value is not None:
                    value = value[arg_out_ind.flatten()].view_as(arg_out)
                    value = value.mul_(grad_out)
                else:
                    value = grad_out
                value = value.masked_fill(invalid_arg_mask, 0)
                col = col[arg_out_ind.flatten()].view_as(arg_out)
                grad_mat = scatter_add(value, col, dim=-2,
                                       dim_size=mat.size(-2))

        return None, None, None, grad_value, grad_mat, None, None, None, None


class SPSPMM(torch.autograd.Function):
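    """Sparse-sparse matrix multiplication of two CSR matrices, computed by
    a custom CUDA kernel on the GPU and via `scipy.sparse` on the CPU.
    Gradients with respect to the non-zero values are not implemented yet."""
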
    @staticmethod
    def forward(ctx, rowptrA, colA, valueA, rowptrB, colB, valueB, M, N, K):
        if rowptrA.is_cuda:
            rowptrC, colC, valueC = spspmm_cuda.spspmm(rowptrA, colA, valueA,
                                                       rowptrB, colB, valueB,
                                                       M, N, K)
        else:
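            # CPU fallback: multiply via `scipy.sparse`, materializing
            # missing value vectors as ones.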
            dtype = None
            if valueA is not None:
                dtype = valueA.dtype
            if valueB is not None:
                dtype = valueB.dtype

            if valueA is None:
                valueA = torch.ones(colA.numel(), dtype=dtype)
            A = scipy.sparse.csr_matrix((valueA, colA, rowptrA), (M, N))

            if valueB is None:
                valueB = torch.ones(colB.numel(), dtype=dtype)
            B = scipy.sparse.csr_matrix((valueB, colB, rowptrB), (N, K))

            C = A @ B

            rowptrC = torch.from_numpy(C.indptr).to(torch.int64)
            colC = torch.from_numpy(C.indices).to(torch.int64)
            valueC = torch.from_numpy(C.data)
            valueC = valueC.to(dtype) if dtype is not None else valueC

        ctx.mark_non_differentiable(rowptrC, colC)

        # We cannot return `NoneType` in torch.autograd :(
        if valueC is None:
            return rowptrC, colC
        else:
            return rowptrC, colC, valueC

    @staticmethod
    def backward(ctx, grad_rowptrC, grad_colC, *args):
        grad_valueA = None
        if ctx.needs_input_grad[2]:
            raise NotImplementedError

        grad_valueB = None
        if ctx.needs_input_grad[5]:
            raise NotImplementedError

        return (None, None, grad_valueA, None, None, grad_valueB, None, None,
                None)


def matmul(src, other, reduce='sum'):
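    """Multiplies the sparse matrix `src` with the dense tensor or sparse
    matrix `other`. For a dense `other`, the products within each output row
    are aggregated according to `reduce` ('sum'/'add', 'mean', 'min' or
    'max'); sparse-sparse multiplication only supports 'sum'/'add'."""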
    assert src.dim() == 2 and src.size(-1) == other.size(-2)

    # Sparse-Dense Matrix Multiplication.
    if torch.is_tensor(other):
        assert reduce in ['sum', 'add', 'mean', 'min', 'max']
        rowptr, col, value = src.csr()

        row = None
        if reduce in ['sum', 'add', 'mean'] and (src.requires_grad
                                                 or other.requires_grad):
            row = src.storage.row

        rowcount = None
        if other.requires_grad and reduce in ['mean']:
            rowcount = src.storage.rowcount

        csr2csc = colptr = None
        if other.requires_grad and reduce in ['sum', 'add', 'mean']:
            csr2csc, colptr = src.storage.csr2csc, src.storage.colptr

        return SPMM.apply(row, rowptr, col, value, other, rowcount, colptr,
                          csr2csc, reduce)

    # Sparse-Sparse Matrix Multiplication.
    elif isinstance(other, src.__class__):
        assert reduce in ['sum', 'add']
        assert src.dim() == 2 and other.dim() == 2
        data = SPSPMM.apply(*src.csr(), *other.csr(), src.size(0), src.size(1),
                            other.size(1))
        (rowptr, col), value = data[:2], data[2] if len(data) == 3 else None
        sparse_size = torch.Size([src.size(0), other.size(1)])
        return src.__class__(rowptr=rowptr, col=col, value=value,
                             sparse_size=sparse_size, is_sorted=True)

    raise ValueError(
        '`other` needs to be a dense tensor or a sparse matrix '
        '(got {}).'.format(type(other)))
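

# Example usage (a minimal sketch; `SparseTensor.from_edge_index` and the
# attribute-style storage access assumed by `matmul` are assumptions about
# the surrounding library, not guaranteed by this file):
#
#   from torch_sparse import SparseTensor
#   index = torch.tensor([[0, 0, 1], [1, 2, 2]])
#   src = SparseTensor.from_edge_index(index, sparse_sizes=(2, 3))
#   other = torch.randn(3, 4)
#   out = matmul(src, other, reduce='mean')  # dense output of shape [2, 4]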