# sampler.py (author: rusty1s)
import torch


def neighbor_sampler(start: torch.Tensor, rowptr: torch.Tensor, size: float):
    """Sample neighbors of the ``start`` nodes via the ``torch_cluster`` op.

    ``size`` is interpreted in one of two ways:

    * ``0 < size <= 1``: forwarded as ``factor`` (a sampling fraction);
      ``count`` stays at ``-1``.
    * ``size > 1``: truncated with ``int(size)`` and forwarded as ``count``
      (a fixed number of neighbors); ``factor`` stays at ``-1.``.

    Args:
        start: Seed node indices. Must be a CPU tensor.
        rowptr: Passed straight through to the native op (presumably the
            CSR row-pointer array of the adjacency -- confirm against the
            C++ kernel).
        size: Sampling fraction in ``(0, 1]`` or neighbor count ``> 1``.

    Returns:
        Whatever ``torch.ops.torch_cluster.neighbor_sampler`` returns.

    Raises:
        ValueError: If ``start`` is a CUDA tensor or ``size <= 0``.
    """
    # Explicit checks instead of ``assert`` so validation still runs under
    # ``python -O`` and bad input never reaches the native op.
    if start.is_cuda:
        raise ValueError('neighbor_sampler only supports CPU tensors')

    # -1 is the placeholder for whichever of count/factor is not selected
    # (presumably ignored by the native op -- confirm in the C++ kernel).
    factor: float = -1.
    count: int = -1
    if size <= 1:
        if size <= 0:
            raise ValueError(f'size must be positive, got {size}')
        factor = size
    else:
        count = int(size)

    return torch.ops.torch_cluster.neighbor_sampler(start, rowptr, count,
                                                    factor)