Commit 23709f94 authored by rusty1s's avatar rusty1s
Browse files

add test

parent 709f6837
import torch import torch
from torch_sparse import SparseTensor, sample_adj from torch_sparse import SparseTensor, sample, sample_adj
def test_sample():
row = torch.tensor([0, 0, 2, 2])
col = torch.tensor([1, 2, 0, 1])
adj = SparseTensor(row=row, col=col, sparse_sizes=(3, 3))
out = sample(adj, num_neighbors=1)
assert out.min() >= 0 and out.max() <= 2
def test_sample_adj(): def test_sample_adj():
......
...@@ -13,6 +13,8 @@ def sample(src: SparseTensor, num_neighbors: int, ...@@ -13,6 +13,8 @@ def sample(src: SparseTensor, num_neighbors: int,
if subset is not None: if subset is not None:
rowcount = rowcount[subset] rowcount = rowcount[subset]
rowptr = rowptr[subset] rowptr = rowptr[subset]
else:
rowptr = rowptr[:-1]
rand = torch.rand((rowcount.size(0), num_neighbors), device=col.device) rand = torch.rand((rowcount.size(0), num_neighbors), device=col.device)
rand.mul_(rowcount.to(rand.dtype).view(-1, 1)) rand.mul_(rowcount.to(rand.dtype).view(-1, 1))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment