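# Autograd-enabled sparse-dense matrix multiplication (SpMM) with 'sum'/'add',
# 'mean', 'min' and 'max' reductions, dispatching to the CPU or CUDA extension
# kernels of torch_sparse.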
import torch

from torch_sparse import spmm_cpu
from torch_scatter import scatter_add

try:
    from torch_sparse import spmm_cuda
except ImportError:
    spmm_cuda = None


def spmm(is_cuda):
    # Select the CUDA or CPU extension module that provides the SpMM kernels.
    return spmm_cuda if is_cuda else spmm_cpu


class SPMM(torch.autograd.Function):
    @staticmethod
    def forward(ctx, index, rowcount, rowptr, colptr, csr2csc, value, mat,
                reduce):
        # Run the CSR-based sparse-dense matrix multiplication kernel.
        # For 'min'/'max', `arg_out` holds the index of the winning edge per
        # output entry, which is needed to route gradients in the backward.
        out, arg_out = spmm(mat.is_cuda).spmm(rowptr, index[1], value, mat,
                                              reduce)

        ctx.reduce = reduce
        ctx.save_for_backward(index, rowcount, rowptr, colptr, csr2csc, value,
                              mat, arg_out)

        if reduce in ['min', 'max']:
            return out, arg_out
        else:
            return out

    @staticmethod
    def backward(ctx, grad_out, *args):
        data = ctx.saved_tensors
        index, rowcount, rowptr, colptr, csr2csc, value, mat, arg_out = data

        # For 'min'/'max', an arg index equal to the number of edges marks
        # output rows without any neighbor; clamp those indices to a valid
        # position and remember them so their contributions can be zeroed out.
        invalid_arg_mask = arg_out_ind = None
        if ctx.reduce in ['min', 'max'] and (ctx.needs_input_grad[5]
                                             or ctx.needs_input_grad[6]):
            invalid_arg_mask = arg_out == index.size(1)
            arg_out_ind = arg_out.masked_fill(invalid_arg_mask, -1)

        # Gradient w.r.t. the (optional) non-zero values of the sparse matrix.
        grad_value = None
        if ctx.needs_input_grad[5]:
            if ctx.reduce in ['sum', 'add', 'mean']:
                grad_value = spmm(grad_out.is_cuda).spmm_val_bw(
                    index, rowptr, mat, grad_out, ctx.reduce)

            elif ctx.reduce in ['min', 'max']:
                # Only the winning entries receive a gradient: gather the
                # matched dense rows, multiply by the incoming gradient and
                # scatter back onto the edges (the extra slot at the end
                # collects contributions of empty rows and is dropped).
                col = index[1][arg_out_ind.flatten()].view_as(arg_out)
                out = mat.gather(-2, col).mul_(grad_out)
                out.masked_fill_(invalid_arg_mask, 0)
                grad_value = scatter_add(out.flatten(), arg_out.flatten(),
                                         dim=0, dim_size=value.numel() + 1)
                grad_value = grad_value[:-1]

        # Gradient w.r.t. the dense matrix: multiply the incoming gradient by
        # the transposed sparse matrix, using the precomputed CSR->CSC
        # permutation (`csr2csc`) and column pointers (`colptr`).
        grad_mat = None
        if ctx.needs_input_grad[6]:
            if ctx.reduce in ['sum', 'add']:
                row = index[0][csr2csc]
                value = value[csr2csc] if value is not None else value
                grad_mat, _ = spmm(grad_out.is_cuda).spmm(
                    colptr, row, value, grad_out, 'sum')

            elif ctx.reduce == 'mean':
                # Normalize by the (clamped) number of neighbors per row
                # before applying the transposed multiplication.
                count = rowcount[index[0]].to(mat.dtype).clamp_(min=1)
                value = count.pow_(-1) if value is None else value / count
                row = index[0][csr2csc]
                value = value[csr2csc] if value is not None else value
                grad_mat, _ = spmm(grad_out.is_cuda).spmm(
                    colptr, row, value, grad_out, 'sum')

            elif ctx.reduce in ['min', 'max']:
                # Route the gradient of each output entry back to the dense
                # row that produced its min/max value, scaled by the edge
                # value if one is given.
                if value is not None:
                    value = value[arg_out_ind.flatten()].view_as(arg_out)
                    value = value.mul_(grad_out)
                else:
                    value = grad_out
                value.masked_fill_(invalid_arg_mask, 0)
                col = index[1][arg_out_ind.flatten()].view_as(arg_out)
                grad_mat = scatter_add(value, col, dim=-2,
                                       dim_size=mat.size(-2))

        return None, None, None, None, None, grad_value, grad_mat, None


def matmul(src, other, reduce='sum'):
    # `src` is a two-dimensional sparse matrix (a `SparseTensor`-like object
    # exposing `.coo()` and `.storage`); `other` is either a dense tensor or a
    # sparse matrix of the same class.
    assert src.dim() == 2 and src.size(-1) == other.size(-2)

    if torch.is_tensor(other):
        assert reduce in ['sum', 'add', 'mean', 'min', 'max']
        (index, value), rowptr = src.coo(), src.storage.rowptr

        # Only gather the auxiliary layout data that the backward pass of the
        # chosen reduction actually needs.
        rowcount = None
        if other.requires_grad and reduce == 'mean':
            rowcount = src.storage.rowcount

        csr2csc = colptr = None
        if other.requires_grad and reduce in ['sum', 'add', 'mean']:
            csr2csc, colptr = src.storage.csr2csc, src.storage.colptr

        return SPMM.apply(index, rowcount, rowptr, colptr, csr2csc, value,
                          other, reduce)

    elif isinstance(other, src.__class__):
        assert reduce in ['sum', 'add']
        raise NotImplementedError

    raise ValueError
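
# Usage sketch (illustrative only; assumes `src` is a `SparseTensor`-like
# object from torch_sparse exposing `.size()`, `.coo()` and the `.storage`
# attributes accessed above):
#
#     dense = torch.randn(src.size(-1), 16)
#     out = matmul(src, dense, reduce='mean')
#     out, argmax = matmul(src, dense, reduce='max')  # also returns arg indices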