import torch
import scipy.sparse
from torch_sparse import spmm_cpu
from torch_scatter import scatter_add

try:
    from torch_sparse import spmm_cuda
except ImportError:
    spmm_cuda = None

try:
    from torch_sparse import spspmm_cuda
except ImportError:
    spspmm_cuda = None


def spmm(is_cuda):
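    # Select the compiled CUDA or CPU kernel bindings based on `is_cuda`.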
    return spmm_cuda if is_cuda else spmm_cpu


class SPMM(torch.autograd.Function):
    @staticmethod
    def forward(ctx, row, rowptr, col, value, mat, rowcount, colptr, csr2csc,
                reduce):
        out, arg_out = spmm(mat.is_cuda).spmm(rowptr, col, value, mat, reduce)
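        # Roughly, for reduce='sum' this computes (a pure-PyTorch sketch of
        # the kernel, assuming `value` is not None; `row` is the COO row
        # index of each non-zero):
        #
        #   out = scatter_add(value.view(-1, 1) * mat[col], row, dim=0,
        #                     dim_size=rowptr.numel() - 1)
        #
        # `arg_out` is only meaningful for 'min'/'max', where it records
        # which non-zero attained each output entry.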

        ctx.reduce = reduce
        ctx.save_for_backward(row, rowptr, col, value, mat, rowcount, colptr,
                              csr2csc, arg_out)

        if reduce in ['min', 'max']:
            ctx.mark_non_differentiable(arg_out)
            return out, arg_out
        else:
            return out

    @staticmethod
    def backward(ctx, grad_out, *args):
        (row, rowptr, col, value, mat, rowcount, colptr, csr2csc,
         arg_out) = ctx.saved_tensors

        invalid_arg_mask = arg_out_ind = None
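        # For 'min'/'max', `arg_out` holds, per output entry, the position of
        # the non-zero that attained the extremum; entries coming from empty
        # rows are flagged with the out-of-range index `col.size(0)`.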
        if ctx.reduce in ['min', 'max'] and (ctx.needs_input_grad[3]
                                             or ctx.needs_input_grad[4]):
            invalid_arg_mask = arg_out == col.size(0)
            arg_out_ind = arg_out.masked_fill(invalid_arg_mask, -1)

        grad_value = None
        if ctx.needs_input_grad[3]:
            if ctx.reduce in ['sum', 'add', 'mean']:
                grad_value = spmm(grad_out.is_cuda).spmm_val_bw(
                    row, rowptr, col, mat, grad_out, ctx.reduce)

            elif ctx.reduce in ['min', 'max']:
                col_tmp = col[arg_out_ind.flatten()].view_as(arg_out)
                out = mat.gather(-2, col_tmp).mul_(grad_out)
                out.masked_fill_(invalid_arg_mask, 0)
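                # Scatter each output gradient onto the non-zero that
                # produced it; slot `value.numel()` (== `col.size(0)`)
                # absorbs the invalid entries from empty rows and is dropped
                # again below.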
                grad_value = scatter_add(out.flatten(), arg_out.flatten(),
                                         dim=0, dim_size=value.numel() + 1)
                grad_value = grad_value[:-1]

        grad_mat = None
        if ctx.needs_input_grad[4]:
            if ctx.reduce in ['sum', 'add']:
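                # grad_mat = src^T @ grad_out: the CSC arrays of the sparse
                # matrix (`colptr`, `row[csr2csc]`) are exactly the CSR
                # arrays of its transpose, so the same 'sum' kernel applies.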
                value = value[csr2csc] if value is not None else value
                grad_mat, _ = spmm(grad_out.is_cuda).spmm(
                    colptr, row[csr2csc], value, grad_out, 'sum')

            elif ctx.reduce == 'mean':
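                # Fold the 1/rowcount normalization of 'mean' into the
                # values, then reuse the 'sum' backward path.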
                count = rowcount[row].to(mat.dtype).clamp_(min=1)
                value = count.pow_(-1) if value is None else value / count
                row = row[csr2csc]
                value = value[csr2csc] if value is not None else value
                grad_mat, _ = spmm(grad_out.is_cuda).spmm(
                    colptr, row, value, grad_out, 'sum')

            elif ctx.reduce in ['min', 'max']:
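                # Route each output gradient to the row of `mat` selected by
                # the argmin/argmax; entries from empty rows are masked out.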
                if value is not None:
                    value = value[arg_out_ind.flatten()].view_as(arg_out)
                    value = value.mul_(grad_out)
                else:
                    # Clone so the in-place `masked_fill_` below does not
                    # modify `grad_out` itself.
                    value = grad_out.clone()
                value.masked_fill_(invalid_arg_mask, 0)
                col_tmp = col[arg_out_ind.flatten()].view_as(arg_out)
                grad_mat = scatter_add(value, col_tmp, dim=-2,
                                       dim_size=mat.size(-2))

        return None, None, None, grad_value, grad_mat, None, None, None, None


class SPSPMM(torch.autograd.Function):
    @staticmethod
    def forward(ctx, rowptrA, colA, valueA, rowptrB, colB, valueB, M, N, K):
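        # Operands arrive in CSR form: A is (M, N) given by (rowptrA, colA,
        # valueA), B is (N, K) given by (rowptrB, colB, valueB).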
        if rowptrA.is_cuda:
            rowptrC, colC, valueC = spspmm_cuda.spspmm(rowptrA, colA, valueA,
                                                       rowptrB, colB, valueB,
                                                       M, N, K)
        else:
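            # CPU fallback: build `scipy.sparse` CSR matrices, multiply, and
            # convert the result back to PyTorch tensors.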
            dtype = None
            if valueA is not None:
                dtype = valueA.dtype
            if valueB is not None:
                dtype = valueB.dtype

            if valueA is None:
                valueA = torch.ones(colA.numel(), dtype=dtype)
            A = scipy.sparse.csr_matrix((valueA, colA, rowptrA), (M, N))

            if valueB is None:
                valueB = torch.ones(colB.numel(), dtype=dtype)
            B = scipy.sparse.csr_matrix((valueB, colB, rowptrB), (N, K))

            C = A @ B

            rowptrC = torch.from_numpy(C.indptr).to(torch.int64)
            colC = torch.from_numpy(C.indices).to(torch.int64)
            valueC = torch.from_numpy(C.data)
            valueC = valueC.to(dtype) if dtype is not None else None

        ctx.mark_non_differentiable(rowptrC, colC)

        # We cannot return `NoneType` in torch.autograd :(
        if valueC is None:
            return rowptrC, colC
        else:
            return rowptrC, colC, valueC

    @staticmethod
    def backward(ctx, grad_rowptrC, grad_colC, *args):
        grad_valueA = None
        if ctx.needs_input_grad[2]:
            raise NotImplementedError

        grad_valueB = None
        if ctx.needs_input_grad[5]:
            raise NotImplementedError

        return (None, None, grad_valueA, None, None, grad_valueB, None, None,
                None)


def matmul(src, other, reduce='sum'):
    assert src.dim() == 2 and src.size(-1) == other.size(-2)

    # Sparse-Dense Matrix Multiplication.
    if torch.is_tensor(other):
        assert reduce in ['sum', 'add', 'mean', 'min', 'max']
        rowptr, col, value = src.csr()

        row = None
        if reduce in ['sum', 'add', 'mean'] and (src.requires_grad
                                                 or other.requires_grad):
            row = src.storage.row
        # `rowcount` is only needed to undo the 1/rowcount normalization in
        # the 'mean' backward pass.
        rowcount = None
        if other.requires_grad and reduce == 'mean':
            rowcount = src.storage.rowcount

        # `csr2csc` and `colptr` describe the transpose of `src`, needed to
        # backpropagate into `other`.
        csr2csc = colptr = None
        if other.requires_grad and reduce in ['sum', 'add', 'mean']:
            csr2csc, colptr = src.storage.csr2csc, src.storage.colptr

        return SPMM.apply(row, rowptr, col, value, other, rowcount, colptr,
                          csr2csc, reduce)

    # Sparse-Sparse Matrix Multiplication.
    elif isinstance(other, src.__class__):
        assert reduce in ['sum', 'add']
        assert src.dim() == 2 and other.dim() == 2
        data = SPSPMM.apply(*src.csr(), *other.csr(), src.size(0), src.size(1),
                            other.size(1))
        (rowptr, col), value = data[:2], data[2] if len(data) == 3 else None
        sparse_size = torch.Size([src.size(0), other.size(1)])
        return src.__class__(rowptr=rowptr, col=col, value=value,
                             sparse_size=sparse_size, is_sorted=True)

    raise ValueError('`other` must be a dense `torch.Tensor` or a '
                     '`SparseTensor`.')
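

# Example usage (a sketch; assumes the `SparseTensor` class shipped with this
# package, which provides `.csr()`, `.storage`, `from_dense()` and `t()`):
#
#   from torch_sparse import SparseTensor
#
#   src = SparseTensor.from_dense(torch.randn(4, 6))
#   dense = matmul(src, torch.randn(6, 8), reduce='mean')  # Sparse @ Dense
#   sparse = matmul(src, src.t())                          # Sparse @ Sparse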