segment.py 2.8 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
import torch

rusty1s's avatar
rusty1s committed
3
from torch_scatter.helpers import min_value, max_value
rusty1s's avatar
rusty1s committed
4

rusty1s's avatar
rusty1s committed
5
if torch.cuda.is_available():
rusty1s's avatar
rusty1s committed
6
    from torch_scatter import segment_cuda, gather_cuda
rusty1s's avatar
rusty1s committed
7
8


rusty1s's avatar
rusty1s committed
9
10
11
class SegmentCSR(torch.autograd.Function):
    """Autograd wrapper around the CUDA ``segment_csr`` reduction.

    ``src`` is reduced along dimension ``indptr.dim() - 1``. ``indptr`` holds
    CSR-style boundaries along its last dimension: segment ``i`` spans
    ``indptr[..., i]`` up to (excluding) ``indptr[..., i + 1]``.
    """

    @staticmethod
    def forward(ctx, src, indptr, out, reduce):
        """Run the segment reduction and stash what ``backward`` needs.

        Args:
            ctx: Autograd context.
            src: Tensor to reduce.
            indptr: CSR-style segment boundaries.
            out: Optional output tensor handed to the kernel (marked dirty).
            reduce: One of ``'any'``, ``'add'``, ``'mean'``, ``'min'``,
                ``'max'``.

        Returns:
            ``out``, or ``(out, arg_out)`` when the kernel also returns arg
            indices.
        """
        assert reduce in ['any', 'add', 'mean', 'min', 'max']

        if out is not None:
            # The caller's tensor is handed to the kernel and may be modified
            # in place, so autograd must be told it is dirty.
            ctx.mark_dirty(out)
        ctx.reduce = reduce
        ctx.src_size = list(src.size())

        out, arg_out = segment_cuda.segment_csr(src, indptr, out, reduce)

        ctx.save_for_backward(indptr, arg_out)
        if arg_out is None:
            return out
        # arg_out holds positions (used as a scatter index in backward), not
        # values, so no gradient flows through it.
        ctx.mark_non_differentiable(arg_out)
        return out, arg_out

    @staticmethod
    def backward(ctx, grad_out, *args):
        """Propagate ``grad_out`` back to ``src``.

        ``*args`` absorbs the (unused) gradient for ``arg_out``. Only ``src``
        receives a gradient; ``indptr``, ``out`` and ``reduce`` get ``None``.
        """
        (indptr, arg_out), src_size = ctx.saved_tensors, ctx.src_size
        dim = indptr.dim() - 1  # dimension of src that was segmented

        grad_src = None
        if ctx.needs_input_grad[0]:
            if ctx.reduce == 'any' or ctx.reduce == 'add':
                # Expand the per-segment gradient back to src's shape.
                grad_src = gather_cuda.gather_csr(grad_out, indptr,
                                                  grad_out.new_empty(src_size))
            elif ctx.reduce == 'mean':
                grad_src = gather_cuda.gather_csr(grad_out, indptr,
                                                  grad_out.new_empty(src_size))
                # Segment lengths indptr[i + 1] - indptr[i], expanded to one
                # value per src entry along the segmented dimension.
                indptr1 = indptr.narrow(-1, 0, indptr.size(-1) - 1)
                indptr2 = indptr.narrow(-1, 1, indptr.size(-1) - 1)
                count = (indptr2 - indptr1).to(grad_src.dtype)
                count = gather_cuda.gather_csr(
                    count, indptr, count.new_empty(src_size[:indptr.dim()]))
                # count only covers src's leading indptr.dim() dims; append
                # singleton dims so it broadcasts over any trailing feature
                # dims instead of being aligned against them.
                for _ in range(grad_src.dim() - count.dim()):
                    count = count.unsqueeze(-1)
                grad_src.div_(count)
            elif ctx.reduce == 'min' or ctx.reduce == 'max':
                # Work on a copy: mutating ctx.src_size in place would corrupt
                # a second backward pass through the same graph.
                size = list(src_size)
                # One extra slot along `dim` absorbs arg indices equal to
                # src.size(dim) (presumably emitted for empty segments --
                # confirm against the kernel); narrow() drops it again.
                size[dim] += 1
                grad_src = grad_out.new_zeros(size).scatter_(dim, arg_out,
                                                             grad_out)
                grad_src = grad_src.narrow(dim, 0, size[dim] - 1)

        return grad_src, None, None, None


rusty1s's avatar
rusty1s committed
51
def segment_coo(src, index, out=None, dim_size=None, reduce='add'):
    """Reduce ``src`` over segments assigned by a COO-style ``index`` tensor.

    The reduction runs along dimension ``index.dim() - 1`` of ``src``;
    ``index`` maps each entry along that dimension to an output slot.

    Args:
        src: Input tensor.
        index: Integer tensor mapping entries of ``src`` to output slots.
        out: Optional output tensor. When given, ``dim_size`` is ignored and
            no sentinel fill / reset is performed.
        dim_size: Size of the output along the reduced dimension. Defaults
            to ``index.max() + 1``, or ``0`` when ``index`` is empty.
        reduce: One of ``'any'``, ``'add'``, ``'mean'``, ``'min'``, ``'max'``.

    Returns:
        ``out`` for ``'any'``/``'add'``/``'mean'``; ``(out, arg_out)`` for
        ``'min'``/``'max'``.
    """
    assert reduce in ['any', 'add', 'mean', 'min', 'max']

    fill_value = 0
    if out is None:
        if dim_size is None:
            # index.max() raises on an empty tensor; an empty index simply
            # yields an empty output dimension.
            dim_size = index.max().item() + 1 if index.numel() > 0 else 0
        size = list(src.size())
        size[index.dim() - 1] = dim_size

        # Seed min/max outputs with a sentinel from torch_scatter.helpers so
        # slots that received no input can be detected after the kernel runs.
        if reduce == 'min':
            fill_value = max_value(src.dtype)
        elif reduce == 'max':
            fill_value = min_value(src.dtype)

        out = src.new_full(size, fill_value)

    out, arg_out = segment_cuda.segment_coo(src, index, out, reduce)

    if fill_value != 0:
        # Slots still holding the sentinel are reported as 0. Caveat: a
        # genuine result equal to the sentinel value is zeroed as well.
        out.masked_fill_(out == fill_value, 0)

    if reduce == 'min' or reduce == 'max':
        return out, arg_out
    return out
rusty1s's avatar
rusty1s committed
76
77
78


def segment_csr(src, indptr, out=None, reduce='add'):
    """Reduce ``src`` over contiguous segments delimited by ``indptr``.

    Functional entry point for the ``SegmentCSR`` autograd function, so the
    result is differentiable with respect to ``src``.

    Returns:
        ``out`` alone, or ``(out, arg_out)`` when the kernel also returns arg
        indices (the case the ``'min'``/``'max'`` backward relies on).
    """
    apply_segment = SegmentCSR.apply
    result = apply_segment(src, indptr, out, reduce)
    return result