# sample.py
# (Extracted from a repository blame view; last committed by rusty1s.)
from typing import Optional

import torch
from torch_sparse.tensor import SparseTensor


def sample(src: SparseTensor, num_neighbors: int,
           subset: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Sample ``num_neighbors`` column indices per row, with replacement.

    For every selected row, ``num_neighbors`` positions are drawn
    independently and uniformly from that row's slice of the CSR ``col``
    array, and the column indices stored at those positions are returned.

    Args:
        src: Sparse matrix to sample from.
        num_neighbors: Number of samples drawn for each row.
        subset: Optional tensor used to index the per-row start offsets and
            per-row counts, i.e. the rows to sample from. If ``None``,
            every row of ``src`` is sampled.

    Returns:
        A tensor of shape ``(num_selected_rows, num_neighbors)`` holding the
        sampled column indices.

    Note:
        Rows without any stored entry are not treated specially: their
        random offset is always ``0``, so the looked-up position lies
        outside the row's (empty) slice of ``col``.
    """
    rowptr, col, _ = src.csr()
    rowcount = src.storage.rowcount()

    if subset is not None:
        rowcount = rowcount[subset]
        rowptr = rowptr[subset]
    else:
        # The CSR row pointer carries one trailing end offset in addition to
        # the per-row start offsets. Drop it so that the start offsets line
        # up one-to-one with `rowcount`; otherwise the in-place add below
        # cannot broadcast.
        rowptr = rowptr[:-1]

    # Uniform floats in [0, 1) scaled by the row length and truncated give
    # an integer offset in [0, rowcount) within each row.
    rand = torch.rand((rowcount.size(0), num_neighbors), device=col.device)
    rand.mul_(rowcount.to(rand.dtype).view(-1, 1))
    rand = rand.to(torch.long)

    # Shift the within-row offsets by each row's start to obtain absolute
    # positions into `col`.
    rand.add_(rowptr.view(-1, 1))

    return col[rand]


# Attach `sample` to SparseTensor so it can also be called as a method:
# `tensor.sample(num_neighbors, subset)`.
SparseTensor.sample = sample