import torch

if torch.cuda.is_available():
    from torch_scatter import segment_cuda


class SegmentCSR(torch.autograd.Function):
    """Autograd wrapper around the CUDA ``segment_csr`` kernel.

    Reduces ``src`` over segments described by the CSR pointer tensor
    ``indptr``, writing into ``out`` in place when one is supplied.

    NOTE(review): ``backward`` below is a placeholder -- it returns
    ``src`` itself as the gradient rather than scattering ``grad_out``
    back over the segments.  Confirm before relying on gradients.
    """

    @staticmethod
    def forward(ctx, src, indptr, out, reduce):
        # Only these reduction modes are understood by the kernel.
        assert reduce in ['any', 'add', 'mean', 'min', 'max']

        if out is not None:
            # The kernel writes `out` in place; autograd must be told.
            ctx.mark_dirty(out)
        ctx.reduce = reduce
        ctx.save_for_backward(src, indptr)

        # `arg_out` carries argmin/argmax indices; it is None for
        # reductions that have no companion index tensor.
        out, arg_out = segment_cuda.segment_csr(src, indptr, out, reduce)
        return out if arg_out is None else (out, arg_out)

    @staticmethod
    def backward(ctx, grad_out, *args):
        # Extra `*args` absorbs the gradient of `arg_out` when forward
        # returned a (out, arg_out) pair.
        src, indptr = ctx.saved_tensors

        grad_src = None
        if ctx.needs_input_grad[0]:
            # TODO(review): placeholder gradient -- not mathematically
            # correct; should depend on ctx.reduce and grad_out.
            grad_src = src

        # One return slot per forward input: (src, indptr, out, reduce).
        return grad_src, None, None, None


def segment_coo(src, index, out=None, dim_size=None, reduce='add'):
    """Reduce ``src`` over segments given by the COO ``index`` tensor.

    Args:
        src: Source tensor to reduce.
        index: Segment indices along the reduced dimension.
            NOTE(review): presumably expected to be sorted (COO segment
            layout) -- confirm against the CUDA kernel's contract.
        out: Optional pre-allocated output tensor, written in place.
        dim_size: Size of the reduced output dimension; inferred from
            ``index.max() + 1`` when omitted.
        reduce: One of ``'any'``, ``'add'``, ``'mean'``, ``'min'``, ``'max'``.

    Returns:
        The reduced tensor, or a ``(out, arg_out)`` pair when the kernel
        produces companion argmin/argmax indices.
    """
    assert reduce in ['any', 'add', 'mean', 'min', 'max']
    if out is None:  # TODO: MOVE TO CPP
        # Infer the output extent from the largest index when not given.
        dim_size = index.max().item() + 1 if dim_size is None else dim_size
        size = list(src.size())
        # The reduced dimension is the last dimension of `index`.
        size[index.dim() - 1] = dim_size
        # TODO: DEPENDS ON REDUCE -- zero-init is only correct for
        # additive-style reductions, not 'min'/'max'.
        out = src.new_zeros(size)

    # `arg_out` is None for reductions without companion indices.
    out, arg_out = segment_cuda.segment_coo(src, index, out, reduce)
    return out if arg_out is None else (out, arg_out)


def segment_csr(src, indptr, out=None, reduce='add'):
    """Functional entry point: reduce ``src`` over CSR segments.

    Thin wrapper that dispatches to :class:`SegmentCSR`; see that class
    for the supported ``reduce`` modes and the in-place ``out`` contract.
    """
    return SegmentCSR.apply(src, indptr, out, reduce)