Commit d8bec425 authored by rusty1s's avatar rusty1s
Browse files

sample method

parent 6af9980d
from typing import Optional
import torch
from torch_sparse.tensor import SparseTensor
def sample(src: SparseTensor, num_neighbors: int,
subset: Optional[torch.Tensor] = None) -> torch.Tensor:
rowptr, col, _ = src.csr()
rowcount = src.storage.rowcount()
if subset is not None:
rowcount = rowcount[subset]
rowptr = rowptr[subset]
rand = torch.rand((rowcount.size(0), num_neighbors), device=col.device)
rand = rand.mul_(rowcount.to(rand.dtype)).to(torch.long).add_(rowptr)
return col[rand]
SparseTensor.sample = sample
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment