# matmul.py
import torch
import scipy.sparse
from torch_scatter import scatter_add
from torch_sparse.utils import ext
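# `ext(cuda)` selects the compiled extension backend (CUDA if `cuda=True`,
# CPU otherwise) that provides the `spmm`, `spmm_val_bw` and `spspmm`
# kernels used below.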


class SPMM(torch.autograd.Function):
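    # Autograd wrapper around sparse-dense matrix multiplication in CSR
    # format: out[i] = reduce_j(value[i, j] * mat[j]) over the nonzero
    # entries of row i. For 'min'/'max', the kernel also returns the index
    # of the winning nonzero per output element (`arg_out`).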
    @staticmethod
    def forward(ctx, row, rowptr, col, value, mat, rowcount, colptr, csr2csc,
                reduce):
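        # `arg_out` is only meaningful for 'min'/'max': it records which
        # nonzero produced each output element and is reused in backward.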
        out, arg_out = ext(mat.is_cuda).spmm(rowptr, col, value, mat, reduce)

        ctx.reduce = reduce
        ctx.save_for_backward(row, rowptr, col, value, mat, rowcount, colptr,
                              csr2csc, arg_out)

        if reduce == 'min' or reduce == 'max':
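            # The argmin/argmax indices are integral and carry no gradient.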
            ctx.mark_non_differentiable(arg_out)
            return out, arg_out
        else:
            return out

    @staticmethod
    def backward(ctx, grad_out, *args):
        (row, rowptr, col, value, mat, rowcount, colptr, csr2csc,
         arg_out) = ctx.saved_tensors

        invalid_arg_mask = arg_out_ind = None
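        # Output elements of rows without any nonzero entry carry the
        # sentinel index `col.size(0)`. Mask those so they neither index
        # out of bounds nor receive gradient below.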
        if ctx.reduce in ['min', 'max'] and (ctx.needs_input_grad[3]
                                             or ctx.needs_input_grad[4]):
            invalid_arg_mask = arg_out == col.size(0)
            arg_out_ind = arg_out.masked_fill(invalid_arg_mask, -1)

        grad_value = None
        if ctx.needs_input_grad[3]:
            if ctx.reduce in ['sum', 'add', 'mean']:
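                # d out[i] / d value[i, j] reduces to the inner product
                # <grad_out[i], mat[j]> (scaled by 1 / rowcount[i] for
                # 'mean'), which the kernel computes directly.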
                grad_value = ext(grad_out.is_cuda).spmm_val_bw(
                    row, rowptr, col, mat, grad_out, ctx.reduce)

            elif ctx.reduce in ['min', 'max']:
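                # Only the winning nonzero of each output element carries
                # gradient: gather the selected rows of `mat`, weight them
                # by `grad_out`, and scatter back onto the value positions.
                # The extra slot at index `value.numel()` absorbs the
                # gradient of masked (empty-row) elements and is dropped.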
                col_tmp = col[arg_out_ind.flatten()].view_as(arg_out)
                out = mat.gather(-2, col_tmp).mul_(grad_out)
                out.masked_fill_(invalid_arg_mask, 0)
                grad_value = scatter_add(out.flatten(), arg_out.flatten(),
                                         dim=0, dim_size=value.numel() + 1)
                grad_value = grad_value[:-1]

        grad_mat = None
        if ctx.needs_input_grad[4]:
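            # The gradient w.r.t. `mat` is A^T @ grad_out. For 'sum'/'mean'
            # this reuses the forward kernel on the CSC view of the matrix,
            # obtained via the `csr2csc` permutation and `colptr`, instead
            # of materializing the transpose.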
            if ctx.reduce in ['sum', 'add']:
                value = value[csr2csc] if value is not None else value
                grad_mat, _ = ext(grad_out.is_cuda).spmm(
                    colptr, row[csr2csc], value, grad_out, 'sum')

            elif ctx.reduce == 'mean':
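                # 'mean' divides each output row by its number of nonzero
                # entries, so the transposed product uses values pre-scaled
                # by 1 / rowcount (clamped to avoid division by zero on
                # empty rows).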
                count = rowcount[row].to(mat.dtype).clamp_(min=1)
                value = count.pow_(-1) if value is None else value / count
                row = row[csr2csc]
                value = value[csr2csc] if value is not None else value
                grad_mat, _ = ext(grad_out.is_cuda).spmm(
                    colptr, row, value, grad_out, 'sum')

            elif ctx.reduce in ['min', 'max']:
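                # Only the selected (argmin/argmax) entries propagate
                # gradient to `mat`: weight `grad_out` by the winning
                # values and scatter-add into the matching rows of `mat`.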
                if value is not None:
                    value = value[arg_out_ind.flatten()].view_as(arg_out)
                    value = value.mul_(grad_out)
                else:
                    value = grad_out
                value.masked_fill_(invalid_arg_mask, 0)
                col_tmp = col[arg_out_ind.flatten()].view_as(arg_out)
                grad_mat = scatter_add(value, col_tmp, dim=-2,
                                       dim_size=mat.size(-2))

        return None, None, None, grad_value, grad_mat, None, None, None, None


class SPSPMM(torch.autograd.Function):
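    # Autograd wrapper around sparse-sparse matrix multiplication:
    # C = A @ B with A a CSR matrix of size [M, N] and B a CSR matrix of
    # size [N, K]. Only the forward pass is implemented; gradients w.r.t.
    # the values raise `NotImplementedError` in backward.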
    @staticmethod
    def forward(ctx, rowptrA, colA, valueA, rowptrB, colB, valueB, M, N, K):
        if rowptrA.is_cuda:
            rowptrC, colC, valueC = ext(True).spspmm(rowptrA, colA, valueA,
                                                     rowptrB, colB, valueB, M,
                                                     N, K)
        else:
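            # CPU fallback: multiply via `scipy.sparse`. Matrices without
            # explicit values are treated as binary (all-one values).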
            dtype = None
            if valueA is not None:
                dtype = valueA.dtype
            if valueB is not None:
                dtype = valueB.dtype

            if valueA is None:
                valueA = torch.ones(colA.numel(), dtype=dtype)
            A = scipy.sparse.csr_matrix((valueA, colA, rowptrA), (M, N))

            if valueB is None:
                valueB = torch.ones(colB.numel(), dtype=dtype)
            B = scipy.sparse.csr_matrix((valueB, colB, rowptrB), (N, K))

            C = A @ B

            rowptrC = torch.from_numpy(C.indptr).to(torch.int64)
            colC = torch.from_numpy(C.indices).to(torch.int64)
            valueC = torch.from_numpy(C.data)
            valueC = valueC.to(dtype) if dtype is not None else None

        ctx.mark_non_differentiable(rowptrC, colC)

        # We cannot return `NoneType` in torch.autograd :(
        if valueC is None:
            return rowptrC, colC
        else:
            return rowptrC, colC, valueC

    @staticmethod
    def backward(ctx, grad_rowptrC, grad_colC, *args):
        grad_valueA = None
        if ctx.needs_input_grad[2]:
            raise NotImplementedError

        grad_valueB = None
        if ctx.needs_input_grad[5]:
            raise NotImplementedError

        return (None, None, grad_valueA, None, None, grad_valueB, None, None,
                None)


def matmul(src, other, reduce='sum'):
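    # `src` is expected to be a `SparseTensor`. A dense `other` dispatches
    # to sparse-dense multiplication (SPMM), a sparse `other` of the same
    # class to sparse-sparse multiplication (SPSPMM). The backward-only
    # inputs (`row`, `rowcount`, `colptr`, `csr2csc`) are fetched lazily
    # and only when gradients actually require them.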
    assert src.dim() == 2 and src.size(-1) == other.size(-2)

    # Sparse-Dense Matrix Multiplication.
    if torch.is_tensor(other):
        assert reduce in ['sum', 'add', 'mean', 'min', 'max']
        rowptr, col, value = src.csr()

        row = None
        if reduce in ['sum', 'add', 'mean'] and (src.requires_grad
                                                 or other.requires_grad):
            row = src.storage.row

        rowcount = None
        if other.requires_grad and reduce in ['mean']:
            rowcount = src.storage.rowcount

        csr2csc = colptr = None
        if other.requires_grad and reduce in ['sum', 'add', 'mean']:
            csr2csc, colptr = src.storage.csr2csc, src.storage.colptr

        return SPMM.apply(row, rowptr, col, value, other, rowcount, colptr,
                          csr2csc, reduce)

    # Sparse-Sparse Matrix Multiplication.
    elif isinstance(other, src.__class__):
        assert reduce in ['sum', 'add']
        assert src.dim() == 2 and other.dim() == 2
        data = SPSPMM.apply(*src.csr(), *other.csr(), src.size(0), src.size(1),
                            other.size(1))
        (rowptr, col), value = data[:2], data[2] if len(data) == 3 else None
        sparse_size = torch.Size([src.size(0), other.size(1)])
        return src.__class__(rowptr=rowptr, col=col, value=value,
                             sparse_size=sparse_size, is_sorted=True)

    raise ValueError('`other` must be a dense torch.Tensor or a sparse '
                     'matrix of the same type as `src`')
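

# Usage sketch (illustrative, not part of the original module). It assumes
# the `SparseTensor` class shipped with this package, matching the
# `sparse_size=` constructor keyword used in `matmul` above; the exact
# import path and signature may differ between torch_sparse versions.
if __name__ == '__main__':
    from torch_sparse.tensor import SparseTensor

    row = torch.tensor([0, 0, 1, 2])
    col = torch.tensor([0, 1, 1, 2])
    value = torch.rand(4, requires_grad=True)
    src = SparseTensor(row=row, col=col, value=value,
                       sparse_size=torch.Size([3, 3]))
    other = torch.randn(3, 4, requires_grad=True)

    # Sparse-dense product: returns a dense tensor of shape [3, 4].
    out = matmul(src, other, reduce='mean')
    out.sum().backward()  # populates `value.grad` and `other.grad`

    # Sparse-sparse product: returns a new SparseTensor of size [3, 3]
    # (value gradients are not implemented, see `SPSPMM.backward`).
    out2 = matmul(src, src, reduce='sum')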