segment.py 2.47 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
import torch

if torch.cuda.is_available():
rusty1s's avatar
rusty1s committed
4
    from torch_scatter import segment_cuda, gather_cuda
rusty1s's avatar
rusty1s committed
5
6


rusty1s's avatar
rusty1s committed
7
8
9
class SegmentCSR(torch.autograd.Function):
    """Autograd wrapper around the CUDA ``segment_csr`` reduction.

    ``indptr`` holds CSR-style segment boundaries. Segments are reduced along
    dimension ``indptr.dim() - 1`` of ``src``; this is the dimension the
    ``min``/``max`` backward scatters gradients into.
    """

    @staticmethod
    def forward(ctx, src, indptr, out, reduce):
        """Reduce each segment of ``src`` delimited by ``indptr``.

        Args:
            ctx: autograd context.
            src: input tensor to reduce.
            indptr: segment boundary pointers.
            out: optional output tensor; when given it is written in place
                and marked dirty for autograd.
            reduce: one of ``'any'``, ``'add'``, ``'mean'``, ``'min'``,
                ``'max'``.

        Returns:
            ``out`` alone, or ``(out, arg_out)`` when the kernel also returns
            arg indices. The backward pass consumes ``arg_out`` for
            ``'min'``/``'max'``.

        Raises:
            ValueError: if ``reduce`` is not a supported reduction.
        """
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if reduce not in ('any', 'add', 'mean', 'min', 'max'):
            raise ValueError(
                "Unsupported reduce {!r}; expected one of 'any', 'add', "
                "'mean', 'min', 'max'".format(reduce))

        if out is not None:
            ctx.mark_dirty(out)
        ctx.reduce = reduce
        ctx.src_size = list(src.size())

        out, arg_out = segment_cuda.segment_csr(src, indptr, out, reduce)
        ctx.save_for_backward(indptr, arg_out)
        return out if arg_out is None else (out, arg_out)

    @staticmethod
    def backward(ctx, grad_out, *args):
        """Propagate ``grad_out`` back to ``src``.

        Only ``src`` receives a gradient; ``indptr``, ``out`` and ``reduce``
        get ``None``.
        """
        indptr, arg_out = ctx.saved_tensors
        # Work on a copy: the min/max branch enlarges the size. Mutating
        # ctx.src_size would corrupt any later backward through the same
        # graph (e.g. with retain_graph=True).
        src_size = list(ctx.src_size)
        dim = indptr.dim() - 1  # segment (reduction) dimension of src

        grad_src = None
        if ctx.needs_input_grad[0]:
            if ctx.reduce in ('any', 'add'):
                grad_src = gather_cuda.gather_csr(grad_out, indptr,
                                                  grad_out.new_empty(src_size))
            elif ctx.reduce == 'mean':
                grad_src = gather_cuda.gather_csr(grad_out, indptr,
                                                  grad_out.new_empty(src_size))
                indptr1 = indptr.narrow(-1, 0, indptr.size(-1) - 1)
                indptr2 = indptr.narrow(-1, 1, indptr.size(-1) - 1)
                count = (indptr2 - indptr1).to(grad_src.dtype)
                count = gather_cuda.gather_csr(
                    count, indptr, count.new_empty(src_size[:indptr.dim()]))
                # count only spans the leading indptr.dim() dims. Append
                # singleton dims so it broadcasts over trailing feature dims
                # instead of being aligned against the last dimension.
                for _ in range(grad_src.dim() - count.dim()):
                    count = count.unsqueeze(-1)
                grad_src.div_(count)
            elif ctx.reduce in ('min', 'max'):
                # One extra slot along the segment dim absorbs scatter writes
                # whose arg index equals the original size (presumably used
                # by the kernel for empty segments). It is dropped afterwards.
                src_size[dim] += 1
                grad_src = grad_out.new_zeros(src_size).scatter_(
                    dim, arg_out, grad_out)
                grad_src = grad_src.narrow(dim, 0, src_size[dim] - 1)

        return grad_src, None, None, None


rusty1s's avatar
rusty1s committed
49
def segment_coo(src, index, out=None, dim_size=None, reduce='add'):
    """Reduce ``src`` into segments labelled by ``index`` (COO format).

    Segments are formed along dimension ``index.dim() - 1`` of ``src``.

    Args:
        src: input tensor to reduce.
        index: segment id of each entry along the segment dimension.
        out: optional preallocated output tensor. When ``None``, a
            zero-filled tensor is allocated.
        dim_size: size of the output along the segment dimension when ``out``
            is allocated here. Defaults to ``index.max() + 1``, or ``0`` for
            an empty ``index``.
        reduce: one of ``'any'``, ``'add'``, ``'mean'``, ``'min'``, ``'max'``.

    Returns:
        ``out`` alone, or ``(out, arg_out)`` when the kernel also returns
        arg indices.

    Raises:
        ValueError: if ``reduce`` is not a supported reduction.
    """
    # Explicit check instead of ``assert`` so it survives ``python -O``.
    if reduce not in ('any', 'add', 'mean', 'min', 'max'):
        raise ValueError(
            "Unsupported reduce {!r}; expected one of 'any', 'add', "
            "'mean', 'min', 'max'".format(reduce))

    if out is None:
        if dim_size is None:
            # index.max() raises on an empty tensor; an empty index simply
            # produces an empty output along the segment dimension.
            dim_size = index.max().item() + 1 if index.numel() > 0 else 0
        size = list(src.size())
        size[index.dim() - 1] = dim_size
        # TODO: the initial fill should depend on ``reduce``; zeros may not
        # be a suitable starting value for 'min'/'max'.
        out = src.new_zeros(size)

    out, arg_out = segment_cuda.segment_coo(src, index, out, reduce)
    return out if arg_out is None else (out, arg_out)


def segment_csr(src, indptr, out=None, reduce='add'):
    """Differentiable segment reduction over CSR-style boundaries.

    Public entry point that forwards all arguments unchanged to the
    ``SegmentCSR`` autograd function.

    Args:
        src: input tensor to reduce.
        indptr: segment boundary pointers.
        out: optional output tensor to write into.
        reduce: reduction name, validated by ``SegmentCSR.forward``.

    Returns:
        Whatever ``SegmentCSR.apply`` returns: the reduced tensor, or a
        ``(out, arg_out)`` pair.
    """
    # autograd.Function.apply takes positional arguments only.
    call_args = (src, indptr, out, reduce)
    return SegmentCSR.apply(*call_args)