sample.py 1.24 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
from typing import Optional, Tuple
rusty1s's avatar
rusty1s committed
2
3
4
5
6
7
8
9
10
11
12
13
14
15

import torch
from torch_sparse.tensor import SparseTensor


def sample(src: SparseTensor, num_neighbors: int,
           subset: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Uniformly sample ``num_neighbors`` column indices per row, with replacement.

    Args:
        src: Sparse matrix whose CSR layout (``src.csr()``) supplies the
            candidate column indices of every row.
        num_neighbors: Number of samples drawn for each selected row.
        subset: Optional row selector applied to the rows of ``src``. It may be
            an integer index tensor (negative indices count from the last row)
            or a boolean row mask. All rows are used when ``None``.

    Returns:
        A tensor of shape ``(num_selected_rows, num_neighbors)`` holding
        entries of ``col`` drawn uniformly (with replacement) from each
        selected row.

    Note:
        Rows with no stored entries are not special-cased. Their drawn offset
        is always 0, so every sample reads ``col[rowptr[row]]``. That is the
        first entry of a following row, or an out-of-range position when no
        entries follow.
    """
    rowptr, col, _ = src.csr()
    rowcount = src.storage.rowcount()

    # Drop the trailing end pointer so that ``rowptr`` and ``rowcount`` both
    # hold exactly one entry per row and can be indexed identically by
    # ``subset``. Indexing the untrimmed ``rowptr`` would map a negative
    # index such as -1 to the end pointer (nnz) rather than to the start of
    # the last row, and would reject boolean row masks outright.
    rowptr = rowptr[:-1]
    if subset is not None:
        rowcount = rowcount[subset]
        rowptr = rowptr[subset]

    # Draw u ~ U[0, 1) and scale it by each row's entry count. Truncating the
    # product to long yields a uniform offset in [0, rowcount) for every
    # non-empty row.
    rand = torch.rand((rowcount.size(0), num_neighbors), device=col.device)
    rand.mul_(rowcount.to(rand.dtype).view(-1, 1))
    offset = rand.to(torch.long)
    # Turn per-row offsets into absolute positions within ``col``.
    offset.add_(rowptr.view(-1, 1))

    return col[offset]


rusty1s's avatar
rusty1s committed
27
28
29
def sample_adj(src: SparseTensor, subset: torch.Tensor, num_neighbors: int,
               replace: bool = False) -> Tuple[SparseTensor, torch.Tensor]:
    """Sample an adjacency block for the rows in ``subset``.

    The CSR layout of ``src`` is handed to the compiled
    ``torch.ops.torch_sparse.sample_adj`` kernel together with ``subset``,
    ``num_neighbors`` and ``replace``. The kernel's output is wrapped in a new
    ``SparseTensor``.

    Args:
        src: Sparse matrix to sample from.
        subset: Rows of ``src`` to sample. Its length becomes the number of
            rows of the returned matrix.
        num_neighbors: Per-row sample budget, forwarded to the kernel.
        replace: Whether the kernel samples with replacement.

    Returns:
        A tuple ``(adj, n_id)``.
        ``adj`` is a ``SparseTensor`` of size ``(len(subset), len(n_id))``,
        marked as sorted. If ``src`` carries values, they are gathered at the
        kernel's ``e_id`` positions.
        ``n_id`` is returned as produced by the kernel. Presumably it holds
        the original column ids behind ``adj``'s columns; confirm against the
        C++ op.
    """
    in_rowptr, in_col, in_value = src.csr()

    out_rowptr, out_col, n_id, e_id = torch.ops.torch_sparse.sample_adj(
        in_rowptr, in_col, subset, num_neighbors, replace)

    # e_id addresses entries of the input CSR, so it selects the matching
    # edge values (if any) for the sampled entries.
    out_value = None if in_value is None else in_value[e_id]

    num_rows = subset.size(0)
    num_cols = n_id.size(0)

    adj = SparseTensor(rowptr=out_rowptr, row=None, col=out_col,
                       value=out_value, sparse_sizes=(num_rows, num_cols),
                       is_sorted=True)
    return adj, n_id


rusty1s's avatar
rusty1s committed
45
# Expose both samplers as methods, e.g. ``src.sample(...)`` and
# ``src.sample_adj(...)``.
setattr(SparseTensor, 'sample', sample)
setattr(SparseTensor, 'sample_adj', sample_adj)