sampler.py 422 Bytes
Newer Older
rusty1s's avatar
rusty1s committed
1
import torch
rusty1s's avatar
rusty1s committed
2
3


rusty1s's avatar
rusty1s committed
4
5
@torch.jit.script
def neighbor_sampler(start: torch.Tensor, rowptr: torch.Tensor, size: float):
    """Forward to the ``torch_cluster.neighbor_sampler`` custom op.

    ``size`` is interpreted in one of two ways before being forwarded:

    * ``size <= 1``: passed as the float ``factor`` (must be > 0), with
      ``count`` left at -1.
    * ``size > 1``: truncated with ``int()`` and passed as ``count``, with
      ``factor`` left at -1.

    NOTE(review): -1 presumably tells the op which mode is unused, and
    ``rowptr`` presumably is a CSR row-pointer vector -- confirm against the
    C++ implementation of the op.

    Args:
        start: Tensor forwarded to the op; must not be a CUDA tensor.
        rowptr: Tensor forwarded to the op unchanged.
        size: Fraction (``<= 1``) or sample count (``> 1``); see above.

    Returns:
        Whatever ``torch.ops.torch_cluster.neighbor_sampler`` returns.
    """
    # CUDA inputs are rejected before reaching the op.
    assert not start.is_cuda, \
        "neighbor_sampler: `start` must not be a CUDA tensor"

    # -1 is the placeholder for whichever of the two modes is not selected.
    factor: float = -1.
    count: int = -1
    if size <= 1:
        factor = size
        # Rejects zero and negative fractions (size <= 1 admits them).
        assert factor > 0, \
            "neighbor_sampler: fractional `size` must be greater than 0"
    else:
        count = int(size)

    return torch.ops.torch_cluster.neighbor_sampler(start, rowptr, count,
                                                    factor)