gather.py 270 Bytes
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
5
6
7
8
9
10
11
12
import torch

if torch.cuda.is_available():
    from torch_scatter import gather_cuda


def gather_coo(src, index, out=None):
    return gather_cuda.gather_coo(src, index, out)


def gather_csr(src, indptr, out=None):
    return gather_cuda.gather_csr(src, indptr, out)