gather.py 1.5 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
import torch

if torch.cuda.is_available():
rusty1s's avatar
rusty1s committed
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
    from torch_scatter import gather_cuda, segment_cuda


class GatherCOO(torch.autograd.Function):
    """Autograd wrapper around the CUDA ``gather_coo`` kernel.

    The forward pass delegates to ``gather_cuda.gather_coo``.  The backward
    pass routes the incoming gradient through ``segment_cuda.segment_coo``
    with the ``'add'`` reduction into a buffer shaped like ``src``.  Only
    ``src`` receives a gradient; ``index`` and ``out`` get ``None``.
    """

    @staticmethod
    def forward(ctx, src, index, out):
        # A caller-supplied ``out`` is handed to the kernel as its output
        # tensor, so autograd must be told it is modified in place.
        if out is not None:
            ctx.mark_dirty(out)
        # Backward needs the original shape of ``src`` plus the index tensor.
        ctx.src_size = list(src.size())
        ctx.save_for_backward(index)

        return gather_cuda.gather_coo(src, index, out)

    @staticmethod
    def backward(ctx, grad_out):
        # Unpack exactly one saved tensor (the index) and the stored shape.
        saved_index, = ctx.saved_tensors
        shape = ctx.src_size

        if not ctx.needs_input_grad[0]:
            return None, None, None

        # Zero-filled output buffer for segment_coo; presumably the kernel
        # accumulates into it, which is why it must start at zero -- TODO
        # confirm against the segment_coo implementation.
        target = grad_out.new_zeros(shape)
        grad_src, _unused = segment_cuda.segment_coo(
            grad_out, saved_index, target, 'add')
        return grad_src, None, None


class GatherCSR(torch.autograd.Function):
    """Autograd wrapper around the CUDA ``gather_csr`` kernel.

    ``indptr`` is passed unchanged to ``gather_cuda.gather_csr`` in the
    forward pass and to ``segment_cuda.segment_csr`` (``'add'`` reduction)
    in the backward pass.  Only ``src`` receives a gradient; ``indptr`` and
    ``out`` get ``None``.
    """

    @staticmethod
    def forward(ctx, src, indptr, out):
        # The kernel is given a caller-supplied ``out`` as its output tensor;
        # flag that in-place modification for autograd.
        if out is not None:
            ctx.mark_dirty(out)
        # Keep the input shape and the pointer tensor for ``backward``.
        ctx.src_size = list(src.size())
        ctx.save_for_backward(indptr)

        return gather_cuda.gather_csr(src, indptr, out)

    @staticmethod
    def backward(ctx, grad_out):
        # Unpack exactly one saved tensor (indptr) and the stored shape.
        saved_indptr, = ctx.saved_tensors
        shape = ctx.src_size

        if not ctx.needs_input_grad[0]:
            return None, None, None

        # NOTE(review): this buffer is uninitialised (``new_empty``), unlike
        # the ``new_zeros`` buffer used by the COO variant.  That is only
        # correct if segment_csr writes every output element -- confirm.
        target = grad_out.new_empty(shape)
        grad_src, _unused = segment_cuda.segment_csr(
            grad_out, saved_indptr, target, 'add')
        return grad_src, None, None
rusty1s's avatar
rusty1s committed
49
50
51


def gather_coo(src, index, out=None):
    """Run the ``gather_coo`` CUDA kernel on ``src``/``index`` with autograd.

    Thin entry point over ``GatherCOO.apply``.  When ``out`` is given it is
    passed to the kernel and marked dirty for autograd.  The CUDA extension
    is only imported when ``torch.cuda.is_available()`` is true.
    """
    gathered = GatherCOO.apply(src, index, out)
    return gathered
rusty1s's avatar
rusty1s committed
53
54
55


def gather_csr(src, indptr, out=None):
    """Run the ``gather_csr`` CUDA kernel on ``src``/``indptr`` with autograd.

    Thin entry point over ``GatherCSR.apply``.  When ``out`` is given it is
    passed to the kernel and marked dirty for autograd.  The CUDA extension
    is only imported when ``torch.cuda.is_available()`` is true.
    """
    gathered = GatherCSR.apply(src, indptr, out)
    return gathered