Unverified Commit 723039b8 authored by Padarn Wilson's avatar Padarn Wilson Committed by GitHub
Browse files

Add tests for neighbor sampling (#210)



* fixing neighbour sampling when repeated nodes

* fix problem with indexing

* reset

* reset

* reset

* add tests
Co-authored-by: default avatarrusty1s <matthias.fey@tu-dortmund.de>
parent b38d0b5e
import torch
from torch_sparse import SparseTensor
neighbor_sample = torch.ops.torch_sparse.neighbor_sample
def test_neighbor_sample():
adj = SparseTensor.from_edge_index(torch.tensor([[0], [1]]))
colptr, row, _ = adj.csc()
# Sampling in a non-directed way should not sample in wrong direction:
out = neighbor_sample(colptr, row, torch.tensor([0]), [1], False, False)
assert out[0].tolist() == [0]
assert out[1].tolist() == []
assert out[2].tolist() == []
# Sampling should work:
out = neighbor_sample(colptr, row, torch.tensor([1]), [1], False, False)
assert out[0].tolist() == [1, 0]
assert out[1].tolist() == [1]
assert out[2].tolist() == [0]
# Sampling with more hops:
out = neighbor_sample(colptr, row, torch.tensor([1]), [1, 1], False, False)
assert out[0].tolist() == [1, 0]
assert out[1].tolist() == [1]
assert out[2].tolist() == [0]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment