segment.py 4.59 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
import torch

rusty1s's avatar
rusty1s committed
3
from torch_scatter.helpers import min_value, max_value
rusty1s's avatar
rusty1s committed
4

rusty1s's avatar
rusty1s committed
5
if torch.cuda.is_available():
rusty1s's avatar
rusty1s committed
6
    from torch_scatter import segment_cuda, gather_cuda
rusty1s's avatar
rusty1s committed
7
8


rusty1s's avatar
rusty1s committed
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
class SegmentCOO(torch.autograd.Function):
    """Autograd function for a segment reduction driven by an ``index`` tensor.

    The forward pass delegates to ``segment_cuda.segment_coo``. The backward
    pass uses ``gather_cuda.gather_coo`` for ``'any'``, ``'add'`` and
    ``'mean'``, and ``Tensor.scatter_`` for ``'min'`` and ``'max'``.
    """

    @staticmethod
    def forward(ctx, src, index, out, dim_size, reduce):
        """Reduce ``src`` according to ``index``.

        Args:
            ctx: Autograd context; receives ``reduce`` and ``src_size`` as
                attributes and ``(index, arg_out)`` as saved tensors.
            src: Tensor to reduce.
            index: Index tensor; the reduction runs along dimension
                ``index.dim() - 1`` of ``src``.
            out: Optional destination tensor. When given it is marked dirty
                and handed to the kernel unchanged.
            dim_size: Size of the reduced dimension of a freshly allocated
                output. Defaults to ``index.max() + 1``. Ignored when ``out``
                is given.
            reduce: One of ``'any'``, ``'add'``, ``'mean'``, ``'min'``,
                ``'max'``.

        Returns:
            ``(out, arg_out)`` for ``'min'``/``'max'``, otherwise ``out``.
        """
        assert reduce in ['any', 'add', 'mean', 'min', 'max']
        if out is not None:
            ctx.mark_dirty(out)
        ctx.reduce = reduce
        ctx.src_size = list(src.size())

        fill_value = 0
        if out is None:
            dim_size = index.max().item() + 1 if dim_size is None else dim_size
            # Output has the shape of `src` except along the reduced dimension.
            size = list(src.size())
            size[index.dim() - 1] = dim_size

            # Start 'min'/'max' outputs from the opposite extreme returned by
            # the helpers (presumably the dtype's largest / smallest value).
            if reduce == 'min':
                fill_value = max_value(src.dtype)
            elif reduce == 'max':
                fill_value = min_value(src.dtype)

            out = src.new_full(size, fill_value)

        out, arg_out = segment_cuda.segment_coo(src, index, out, reduce)

        # Only reachable for a freshly allocated 'min'/'max' output: entries
        # still equal to the initial fill value are reset to zero.
        if fill_value != 0:
            out.masked_fill_(out == fill_value, 0)

        ctx.save_for_backward(index, arg_out)

        if reduce == 'min' or reduce == 'max':
            return out, arg_out
        else:
            return out

    @staticmethod
    def backward(ctx, grad_out, *args):
        """Compute the gradient with respect to ``src``.

        Returns:
            A 5-tuple matching the inputs of ``forward``; only the first entry
            (the gradient for ``src``) can be non-``None``.
        """
        index, arg_out = ctx.saved_tensors
        # Work on a copy: `ctx.src_size` must stay untouched so that running
        # backward more than once (e.g. with `retain_graph=True`) keeps
        # producing gradients of the original `src` shape.
        src_size = list(ctx.src_size)
        dim = index.dim() - 1

        grad_src = None
        if ctx.needs_input_grad[0]:
            if ctx.reduce in ('any', 'add'):
                grad_src = gather_cuda.gather_coo(grad_out, index,
                                                  grad_out.new_empty(src_size))
            elif ctx.reduce == 'mean':
                grad_src = gather_cuda.gather_coo(grad_out, index,
                                                  grad_out.new_empty(src_size))
                # For 'mean' the second kernel output is used as the divisor
                # (presumably the per-segment element count).
                count = arg_out
                count = gather_cuda.gather_coo(
                    count, index, count.new_empty(src_size[:index.dim()]))
                # Append trailing singleton dimensions so `count` broadcasts
                # against `grad_src`.
                for _ in range(grad_out.dim() - index.dim()):
                    count = count.unsqueeze(-1)
                grad_src.div_(count)
            elif ctx.reduce in ('min', 'max'):
                # Scatter into a buffer that is one slot longer along `dim`
                # and drop that slot afterwards (presumably `arg_out` points
                # at it for segments without a selected element).
                padded_size = list(src_size)
                padded_size[dim] += 1
                grad_src = grad_out.new_zeros(padded_size).scatter_(
                    dim, arg_out, grad_out)
                grad_src = grad_src.narrow(dim, 0, src_size[dim])
        return grad_src, None, None, None, None
rusty1s's avatar
rusty1s committed
68
69


rusty1s's avatar
rusty1s committed
70
71
72
class SegmentCSR(torch.autograd.Function):
    """Autograd function for a segment reduction driven by an ``indptr`` tensor.

    The forward pass delegates to ``segment_cuda.segment_csr``. The backward
    pass uses ``gather_cuda.gather_csr`` for ``'any'``, ``'add'`` and
    ``'mean'``, and ``Tensor.scatter_`` for ``'min'`` and ``'max'``.
    """

    @staticmethod
    def forward(ctx, src, indptr, out, reduce):
        """Reduce ``src`` over the segments described by ``indptr``.

        Args:
            ctx: Autograd context; receives ``reduce`` and ``src_size`` as
                attributes and ``(indptr, arg_out)`` as saved tensors.
            src: Tensor to reduce.
            indptr: Segment boundary tensor; the reduction runs along
                dimension ``indptr.dim() - 1`` of ``src``.
            out: Optional destination tensor. When given it is marked dirty
                and handed to the kernel unchanged.
            reduce: One of ``'any'``, ``'add'``, ``'mean'``, ``'min'``,
                ``'max'``.

        Returns:
            ``out`` when the kernel returns no ``arg_out``, otherwise
            ``(out, arg_out)``.
        """
        assert reduce in ['any', 'add', 'mean', 'min', 'max']

        if out is not None:
            ctx.mark_dirty(out)
        ctx.reduce = reduce
        ctx.src_size = list(src.size())

        out, arg_out = segment_cuda.segment_csr(src, indptr, out, reduce)
        ctx.save_for_backward(indptr, arg_out)
        return out if arg_out is None else (out, arg_out)

    @staticmethod
    def backward(ctx, grad_out, *args):
        """Compute the gradient with respect to ``src``.

        Returns:
            A 4-tuple matching the inputs of ``forward``; only the first entry
            (the gradient for ``src``) can be non-``None``.
        """
        indptr, arg_out = ctx.saved_tensors
        # Work on a copy: `ctx.src_size` must stay untouched so that running
        # backward more than once (e.g. with `retain_graph=True`) keeps
        # producing gradients of the original `src` shape.
        src_size = list(ctx.src_size)
        dim = indptr.dim() - 1

        grad_src = None
        if ctx.needs_input_grad[0]:
            if ctx.reduce in ('any', 'add'):
                grad_src = gather_cuda.gather_csr(grad_out, indptr,
                                                  grad_out.new_empty(src_size))
            elif ctx.reduce == 'mean':
                grad_src = gather_cuda.gather_csr(grad_out, indptr,
                                                  grad_out.new_empty(src_size))
                # Segment lengths are the differences of consecutive
                # `indptr` entries along the last dimension.
                indptr1 = indptr.narrow(-1, 0, indptr.size(-1) - 1)
                indptr2 = indptr.narrow(-1, 1, indptr.size(-1) - 1)
                count = (indptr2 - indptr1).to(grad_src.dtype)
                count = gather_cuda.gather_csr(
                    count, indptr, count.new_empty(src_size[:indptr.dim()]))
                # Append trailing singleton dimensions so `count` broadcasts
                # against `grad_src`.
                for _ in range(grad_out.dim() - indptr.dim()):
                    count = count.unsqueeze(-1)
                grad_src.div_(count)
            elif ctx.reduce in ('min', 'max'):
                # Scatter into a buffer that is one slot longer along `dim`
                # and drop that slot afterwards (presumably `arg_out` points
                # at it for segments without a selected element).
                padded_size = list(src_size)
                padded_size[dim] += 1
                grad_src = grad_out.new_zeros(padded_size).scatter_(
                    dim, arg_out, grad_out)
                grad_src = grad_src.narrow(dim, 0, src_size[dim])

        return grad_src, None, None, None


rusty1s's avatar
rusty1s committed
114
def segment_coo(src, index, out=None, dim_size=None, reduce='add'):
    """Functional entry point that forwards to ``SegmentCOO.apply``.

    Args:
        src: Tensor to reduce.
        index: Index tensor passed through to ``SegmentCOO``.
        out: Optional destination tensor.
        dim_size: Optional size of the reduced output dimension.
        reduce: Reduction name; defaults to ``'add'``.

    Returns:
        Whatever ``SegmentCOO.apply`` returns for the given arguments.
    """
    apply_args = (src, index, out, dim_size, reduce)
    return SegmentCOO.apply(*apply_args)
rusty1s's avatar
rusty1s committed
116
117
118


def segment_csr(src, indptr, out=None, reduce='add'):
    """Functional entry point that forwards to ``SegmentCSR.apply``.

    Args:
        src: Tensor to reduce.
        indptr: Segment boundary tensor passed through to ``SegmentCSR``.
        out: Optional destination tensor.
        reduce: Reduction name; defaults to ``'add'``.

    Returns:
        Whatever ``SegmentCSR.apply`` returns for the given arguments.
    """
    apply_args = (src, indptr, out, reduce)
    return SegmentCSR.apply(*apply_args)