gather.py 2.1 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
import torch

rusty1s's avatar
rusty1s committed
3
4
5
6
7
8
9
10
11

class GatherCOO(torch.autograd.Function):
    """Autograd wrapper around the compiled ``gather_coo`` kernel.

    The forward pass dispatches to ``torch.ops.torch_scatter_cuda`` or
    ``torch.ops.torch_scatter_cpu`` depending on whether ``src`` is a CUDA
    tensor. The backward pass feeds the incoming gradient through the
    ``segment_coo`` kernel with a ``'sum'`` reduction, producing a gradient
    with the same size as ``src``.
    """

    @staticmethod
    def forward(ctx, src, index, out):
        """Run the ``gather_coo`` kernel.

        Args:
            ctx: Autograd context.
            src: Tensor to gather from. Its size is recorded so the backward
                pass can allocate a gradient of matching size.
            index: Index tensor handed to the kernel; saved for backward.
            out: Optional pre-allocated output tensor, or ``None``.

        Returns:
            Whatever the ``gather_coo`` kernel returns.
        """
        if out is not None:
            # ``out`` is declared as modified in place by this op
            # (presumably the kernel writes into it — confirm in the kernel).
            ctx.mark_dirty(out)
        ctx.src_size = list(src.size())
        ctx.save_for_backward(index)

        ops = (torch.ops.torch_scatter_cuda
               if src.is_cuda else torch.ops.torch_scatter_cpu)
        return ops.gather_coo(src, index, out)

    @staticmethod
    def backward(ctx, grad_out):
        """Propagate ``grad_out`` back to ``src``.

        Only ``src`` can receive a gradient; ``index`` and ``out`` always get
        ``None``. When ``src`` does not require a gradient, ``None`` is
        returned for it as well.
        """
        (index, ), src_size = ctx.saved_tensors, ctx.src_size

        grad_src = None
        if ctx.needs_input_grad[0]:
            ops = (torch.ops.torch_scatter_cuda
                   if grad_out.is_cuda else torch.ops.torch_scatter_cpu)
            # The result buffer starts zero-filled; presumably segment_coo
            # accumulates into it, so uninitialized memory would leak into
            # the gradient — confirm against the kernel.
            grad_src, _ = ops.segment_coo(
                grad_out, index, grad_out.new_zeros(src_size), 'sum')

        return grad_src, None, None


class GatherCSR(torch.autograd.Function):
    """Autograd wrapper around the compiled ``gather_csr`` kernel.

    The forward pass dispatches to ``torch.ops.torch_scatter_cuda`` or
    ``torch.ops.torch_scatter_cpu`` depending on whether ``src`` is a CUDA
    tensor. The backward pass feeds the incoming gradient through the
    ``segment_csr`` kernel with a ``'sum'`` reduction, producing a gradient
    with the same size as ``src``.
    """

    @staticmethod
    def forward(ctx, src, indptr, out):
        """Run the ``gather_csr`` kernel.

        Args:
            ctx: Autograd context.
            src: Tensor to gather from. Its size is recorded so the backward
                pass can allocate a gradient of matching size.
            indptr: Index-pointer tensor handed to the kernel; saved for
                backward.
            out: Optional pre-allocated output tensor, or ``None``.

        Returns:
            Whatever the ``gather_csr`` kernel returns.
        """
        if out is not None:
            # ``out`` is declared as modified in place by this op
            # (presumably the kernel writes into it — confirm in the kernel).
            ctx.mark_dirty(out)
        ctx.src_size = list(src.size())
        ctx.save_for_backward(indptr)

        ops = (torch.ops.torch_scatter_cuda
               if src.is_cuda else torch.ops.torch_scatter_cpu)
        return ops.gather_csr(src, indptr, out)

    @staticmethod
    def backward(ctx, grad_out):
        """Propagate ``grad_out`` back to ``src``.

        Only ``src`` can receive a gradient; ``indptr`` and ``out`` always
        get ``None``. When ``src`` does not require a gradient, ``None`` is
        returned for it as well.
        """
        (indptr, ), src_size = ctx.saved_tensors, ctx.src_size

        grad_src = None
        if ctx.needs_input_grad[0]:
            ops = (torch.ops.torch_scatter_cuda
                   if grad_out.is_cuda else torch.ops.torch_scatter_cpu)
            # The result buffer is left uninitialized (``new_empty``); this
            # relies on segment_csr overwriting every output element —
            # NOTE(review): confirm against the kernel.
            grad_src, _ = ops.segment_csr(
                grad_out, indptr, grad_out.new_empty(src_size), 'sum')

        return grad_src, None, None


def gather_coo(src, index, out=None):
    """Differentiable ``gather_coo`` of ``src`` driven by a COO ``index``.

    Thin wrapper around ``GatherCOO.apply``.

    Args:
        src: Tensor to gather from.
        index: Index tensor forwarded to the kernel (presumably one segment
            id per output position — confirm against the kernel).
        out: Optional pre-allocated output tensor; ``None`` by default.

    Returns:
        The tensor produced by ``GatherCOO``.
    """
    return GatherCOO.apply(src, index, out)


def gather_csr(src, indptr, out=None):
    """Differentiable ``gather_csr`` of ``src`` driven by a CSR ``indptr``.

    Thin wrapper around ``GatherCSR.apply``.

    Args:
        src: Tensor to gather from.
        indptr: Index-pointer tensor forwarded to the kernel (presumably
            segment boundaries in CSR form — confirm against the kernel).
        out: Optional pre-allocated output tensor; ``None`` by default.

    Returns:
        The tensor produced by ``GatherCSR``.
    """
    return GatherCSR.apply(src, indptr, out)