import torch
import scipy.sparse
from torch_scatter import scatter_add

def ext(is_cuda):
    # Dispatch helper: return the CUDA or CPU extension namespace. The
    # backward passes below call this as `ext(grad_out.is_cuda)`, mirroring
    # the explicit torch.ops branch in `SPMM.forward`.
    return torch.ops.torch_sparse_cuda if is_cuda else torch.ops.torch_sparse_cpu


class SPMM(torch.autograd.Function):
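    # Sparse (CSR) x dense matrix product with a configurable reduction
    # ('sum'/'add', 'mean', 'min', 'max') and handwritten gradients for the
    # nonzero values and the dense matrix.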
    @staticmethod
    def forward(ctx, row, rowptr, col, value, mat, rowcount, colptr, csr2csc,
                reduce):
        if mat.is_cuda:
            out, arg_out = torch.ops.torch_sparse_cuda.spmm(
                rowptr, col, value, mat, reduce)
        else:
            out, arg_out = torch.ops.torch_sparse_cpu.spmm(
                rowptr, col, value, mat, reduce)

        ctx.reduce = reduce
        ctx.save_for_backward(row, rowptr, col, value, mat, rowcount, colptr,
                              csr2csc, arg_out)

        if reduce == 'min' or reduce == 'max':
            ctx.mark_non_differentiable(arg_out)
            return out, arg_out
        else:
            return out

    @staticmethod
    def backward(ctx, grad_out, *args):
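        # `*args` swallows the incoming gradient for `arg_out`, which was
        # marked non-differentiable in `forward` for 'min'/'max'.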
        (row, rowptr, col, value, mat, rowcount, colptr, csr2csc,
         arg_out) = ctx.saved_tensors

        invalid_arg_mask = arg_out_ind = None
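
        # For 'min'/'max', `arg_out == col.size(0)` is the kernel's sentinel
        # for outputs whose row has no nonzeros; remap those to -1 here and
        # zero their gradients later.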
        if ctx.reduce in ['min', 'max'] and (ctx.needs_input_grad[3]
                                             or ctx.needs_input_grad[4]):
            invalid_arg_mask = arg_out == col.size(0)
            arg_out_ind = arg_out.masked_fill(invalid_arg_mask, -1)

        grad_value = None
        if ctx.needs_input_grad[3]:
            if ctx.reduce in ['sum', 'add', 'mean']:
                grad_value = ext(grad_out.is_cuda).spmm_val_bw(
                    row, rowptr, col, mat, grad_out, ctx.reduce)

            elif ctx.reduce in ['min', 'max']:
                col_tmp = col[arg_out_ind.flatten()].view_as(arg_out)
                out = mat.gather(-2, col_tmp).mul_(grad_out)
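                # Zero the contributions of empty rows, then scatter each
                # gradient back to the nonzero that produced the min/max;
                # `dim_size=value.numel() + 1` reserves a scratch slot for the
                # sentinel index, which is dropped again afterwards.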
                out.masked_fill_(invalid_arg_mask, 0)
                grad_value = scatter_add(out.flatten(), arg_out.flatten(),
                                         dim=0, dim_size=value.numel() + 1)
                grad_value = grad_value[:-1]

        grad_mat = None
        if ctx.needs_input_grad[4]:
            if ctx.reduce in ['sum', 'add']:
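                # d(A @ X)/dX = A^T @ grad_out: run the same SpMM kernel on
                # the transpose by permuting the CSR entries into CSC order.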
                value = value[csr2csc] if value is not None else value
                grad_mat, _ = ext(grad_out.is_cuda).spmm(
                    colptr, row[csr2csc], value, grad_out, 'sum')

            elif ctx.reduce == 'mean':
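                # For 'mean', fold the 1 / rowcount normalization into the
                # edge values before the transposed SpMM.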
                count = rowcount[row].to(mat.dtype).clamp_(min=1)
                value = count.pow_(-1) if value is None else value / count
                row = row[csr2csc]
                value = value[csr2csc] if value is not None else value
                grad_mat, _ = ext(grad_out.is_cuda).spmm(
                    colptr, row, value, grad_out, 'sum')

            elif ctx.reduce in ['min', 'max']:
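                # For 'min'/'max', every output element received exactly one
                # input element; scatter the (value-weighted) gradient back to
                # the corresponding row of `mat`.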
                if value is not None:
                    value = value[arg_out_ind.flatten()].view_as(arg_out)
                    value = value.mul_(grad_out)
                else:
                    # Work on a copy: `grad_out` must not be modified in-place
                    # by the `masked_fill_` below.
                    value = grad_out.clone()
                value.masked_fill_(invalid_arg_mask, 0)
                col_tmp = col[arg_out_ind.flatten()].view_as(arg_out)
                grad_mat = scatter_add(value, col_tmp, dim=-2,
                                       dim_size=mat.size(-2))

        return None, None, None, grad_value, grad_mat, None, None, None, None


class SPSPMM(torch.autograd.Function):
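    # Sparse (CSR) x sparse (CSR) matrix product. Only the forward pass is
    # implemented; gradients w.r.t. the values raise NotImplementedError.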
    @staticmethod
    def forward(ctx, rowptrA, colA, valueA, rowptrB, colB, valueB, M, N, K):
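        # On CUDA, use the extension's spspmm kernel; on CPU, fall back to
        # scipy's CSR @ CSR product.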
        if rowptrA.is_cuda:
            rowptrC, colC, valueC = ext(True).spspmm(rowptrA, colA, valueA,
                                                     rowptrB, colB, valueB, M,
                                                     N, K)
        else:
            dtype = None
            if valueA is not None:
                dtype = valueA.dtype
            if valueB is not None:
                dtype = valueB.dtype

            if valueA is None:
                valueA = torch.ones(colA.numel(), dtype=dtype)
            A = scipy.sparse.csr_matrix((valueA, colA, rowptrA), (M, N))

            if valueB is None:
                valueB = torch.ones(colB.numel(), dtype=dtype)
            B = scipy.sparse.csr_matrix((valueB, colB, rowptrB), (N, K))

            C = A @ B

            rowptrC = torch.from_numpy(C.indptr).to(torch.int64)
            colC = torch.from_numpy(C.indices).to(torch.int64)
            valueC = torch.from_numpy(C.data)
            valueC = valueC.to(dtype) if dtype is not None else None

        ctx.mark_non_differentiable(rowptrC, colC)

        # We cannot return `NoneType` in torch.autograd :(
        if valueC is None:
            return rowptrC, colC
        else:
            return rowptrC, colC, valueC

    @staticmethod
    def backward(ctx, grad_rowptrC, grad_colC, *args):
        grad_valueA = None
        if ctx.needs_input_grad[2]:
            raise NotImplementedError

        grad_valueB = None
        if ctx.needs_input_grad[5]:
            raise NotImplementedError

        return (None, None, grad_valueA, None, None, grad_valueB, None, None,
                None)


def matmul(src, other, reduce='sum'):
    assert src.dim() == 2 and src.size(-1) == other.size(-2)

    # Sparse-Dense Matrix Multiplication.
    if torch.is_tensor(other):
        assert reduce in ['sum', 'add', 'mean', 'min', 'max']
        rowptr, col, value = src.csr()
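        # Fetch only the CSR metadata the chosen `reduce` needs for the
        # backward pass; everything that is not needed stays `None`.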

        row = None
        if reduce in ['sum', 'add', 'mean'] and (src.requires_grad
                                                 or other.requires_grad):
            row = src.storage.row

        rowcount = None
        if other.requires_grad and reduce in ['mean']:
            rowcount = src.storage.rowcount

        csr2csc = colptr = None
        if other.requires_grad and reduce in ['sum', 'add', 'mean']:
            csr2csc, colptr = src.storage.csr2csc, src.storage.colptr

        return SPMM.apply(row, rowptr, col, value, other, rowcount, colptr,
                          csr2csc, reduce)

    # Sparse-Sparse Matrix Multiplication.
    elif isinstance(other, src.__class__):
        assert reduce in ['sum', 'add']
        assert src.dim() == 2 and other.dim() == 2
        data = SPSPMM.apply(*src.csr(), *other.csr(), src.size(0), src.size(1),
                            other.size(1))
        (rowptr, col), value = data[:2], data[2] if len(data) == 3 else None
        sparse_size = torch.Size([src.size(0), other.size(1)])
        return src.__class__(rowptr=rowptr, col=col, value=value,
                             sparse_size=sparse_size, is_sorted=True)

    raise ValueError('`other` must be a dense `torch.Tensor` or a sparse '
                     'matrix of the same type as `src`')
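

# Usage sketch (hypothetical, not part of this module): `src` is assumed to
# be a SparseTensor-like object exposing the `.csr()` and `.storage`
# interface used above, e.g. `torch_sparse.SparseTensor`:
#
#   src = SparseTensor.from_dense(torch.rand(4, 4))  # sparse, shape [4, 4]
#   other = torch.rand(4, 8, requires_grad=True)     # dense,  shape [4, 8]
#   out = matmul(src, other, reduce='mean')          # dense,  shape [4, 8]
#   out.sum().backward()                             # grads flow into `other`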