gather.py 1.71 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
import torch

rusty1s's avatar
rusty1s committed
3
4
from torch_scatter import segment_cpu, gather_cpu

rusty1s's avatar
rusty1s committed
5
if torch.cuda.is_available():
    # The CUDA backends are only imported when torch reports CUDA as
    # available, so `gather_cuda` / `segment_cuda` exist only in that case.
    from torch_scatter import gather_cuda, segment_cuda


def gat(is_cuda):
    """Select the gather backend for a device.

    Args:
        is_cuda: Truthy to select ``gather_cuda``, falsy for ``gather_cpu``.

    Returns:
        The backend object whose ``gather_coo`` / ``gather_csr`` get called.
    """
    if is_cuda:
        # Only resolvable when the guarded CUDA import above has run.
        return gather_cuda
    return gather_cpu


def seg(is_cuda):
    """Select the segment backend for a device.

    Args:
        is_cuda: Truthy to select ``segment_cuda``, falsy for ``segment_cpu``.

    Returns:
        The backend object whose ``segment_coo`` / ``segment_csr`` get called.
    """
    if is_cuda:
        # Only resolvable when the guarded CUDA import above has run.
        return segment_cuda
    return segment_cpu

class GatherCOO(torch.autograd.Function):
    """Autograd function around the backend ``gather_coo`` kernel.

    Forward dispatches to the CPU or CUDA backend based on ``src.is_cuda``.
    Backward reduces ``grad_out`` onto ``src``'s shape with
    ``segment_coo(..., 'add')`` over the same ``index`` (presumably summing
    the gradient entries that were gathered from the same source position).
    No gradient is returned for ``index`` or ``out``.
    """

    @staticmethod
    def forward(ctx, src, index, out):
        # Keep what backward needs: src's shape to size the gradient buffer
        # and the index to route gradients back.
        ctx.src_size = list(src.size())
        if out is not None:
            # A caller-supplied `out` is handed to the kernel as destination;
            # mark_dirty tells autograd it is modified in place.
            ctx.mark_dirty(out)
        ctx.save_for_backward(index)

        backend = gat(src.is_cuda)
        return backend.gather_coo(src, index, out)

    @staticmethod
    def backward(ctx, grad_out):
        index, = ctx.saved_tensors
        src_size = ctx.src_size

        if not ctx.needs_input_grad[0]:
            return None, None, None

        # Zero-initialised buffer: segment_coo with 'add' presumably
        # accumulates into it, so it must not hold garbage.
        buf = grad_out.new_zeros(src_size)
        grad_src, _unused = seg(grad_out.is_cuda).segment_coo(
            grad_out, index, buf, 'add')
        return grad_src, None, None


class GatherCSR(torch.autograd.Function):
    """Autograd function around the backend ``gather_csr`` kernel.

    Forward dispatches to the CPU or CUDA backend based on ``src.is_cuda``.
    Backward reduces ``grad_out`` onto ``src``'s shape with
    ``segment_csr(..., 'add')`` over the same ``indptr`` (presumably summing
    the gradient entries of each CSR segment). No gradient is returned for
    ``indptr`` or ``out``.
    """

    @staticmethod
    def forward(ctx, src, indptr, out):
        # Keep what backward needs: src's shape to size the gradient buffer
        # and the pointer array to route gradients back.
        ctx.src_size = list(src.size())
        if out is not None:
            # A caller-supplied `out` is handed to the kernel as destination;
            # mark_dirty tells autograd it is modified in place.
            ctx.mark_dirty(out)
        ctx.save_for_backward(indptr)

        backend = gat(src.is_cuda)
        return backend.gather_csr(src, indptr, out)

    @staticmethod
    def backward(ctx, grad_out):
        indptr, = ctx.saved_tensors
        src_size = ctx.src_size

        if not ctx.needs_input_grad[0]:
            return None, None, None

        # NOTE(review): uninitialised buffer, unlike GatherCOO's new_zeros.
        # Presumably segment_csr writes every output row itself; confirm
        # against the kernel before relying on it.
        buf = grad_out.new_empty(src_size)
        grad_src, _unused = seg(grad_out.is_cuda).segment_csr(
            grad_out, indptr, buf, 'add')
        return grad_src, None, None


def gather_coo(src, index, out=None):
    """Differentiable ``gather_coo`` on CPU or CUDA tensors.

    Args:
        src: Tensor to gather from; gradients flow back into it.
        index: Index tensor forwarded to the backend ``gather_coo``.
        out: Optional destination tensor. When given, it is marked dirty
            for autograd and passed to the kernel.

    Returns:
        Whatever ``GatherCOO.apply`` produces for these arguments.
    """
    apply = GatherCOO.apply
    return apply(src, index, out)


def gather_csr(src, indptr, out=None):
    """Differentiable ``gather_csr`` on CPU or CUDA tensors.

    Args:
        src: Tensor to gather from; gradients flow back into it.
        indptr: CSR-style pointer tensor forwarded to the backend
            ``gather_csr``.
        out: Optional destination tensor. When given, it is marked dirty
            for autograd and passed to the kernel.

    Returns:
        Whatever ``GatherCSR.apply`` produces for these arguments.
    """
    apply = GatherCSR.apply
    return apply(src, indptr, out)