Unverified Commit cc4696bd authored by Piotr Bielak's avatar Piotr Bielak Committed by GitHub
Browse files

Allow returning edge indices from random walk (#139)

This commit adds an optional argument in the `random_walk` function,
namely `return_edge_indices`. The default behaviour is not changed, but
if a user wants to directly use the edges visited by the random walker,
we can return the indices of those edges by setting
`return_edge_indices` to `True`. New cases are also added to the test
suite.
parent c77ed131
...@@ -6,7 +6,7 @@ from .utils import devices, tensor ...@@ -6,7 +6,7 @@ from .utils import devices, tensor
@pytest.mark.parametrize('device', devices) @pytest.mark.parametrize('device', devices)
def test_rw(device): def test_rw_large(device):
row = tensor([0, 1, 1, 1, 2, 2, 3, 3, 4, 4], torch.long, device) row = tensor([0, 1, 1, 1, 2, 2, 3, 3, 4, 4], torch.long, device)
col = tensor([1, 0, 2, 3, 1, 4, 1, 4, 2, 3], torch.long, device) col = tensor([1, 0, 2, 3, 1, 4, 1, 4, 2, 3], torch.long, device)
start = tensor([0, 1, 2, 3, 4], torch.long, device) start = tensor([0, 1, 2, 3, 4], torch.long, device)
...@@ -21,6 +21,9 @@ def test_rw(device): ...@@ -21,6 +21,9 @@ def test_rw(device):
assert out[n, i].item() in col[row == cur].tolist() assert out[n, i].item() in col[row == cur].tolist()
cur = out[n, i].item() cur = out[n, i].item()
@pytest.mark.parametrize('device', devices)
def test_rw_small(device):
row = tensor([0, 1], torch.long, device) row = tensor([0, 1], torch.long, device)
col = tensor([1, 0], torch.long, device) col = tensor([1, 0], torch.long, device)
start = tensor([0, 1, 2], torch.long, device) start = tensor([0, 1, 2], torch.long, device)
...@@ -28,3 +31,49 @@ def test_rw(device): ...@@ -28,3 +31,49 @@ def test_rw(device):
out = random_walk(row, col, start, walk_length, num_nodes=3) out = random_walk(row, col, start, walk_length, num_nodes=3)
assert out.tolist() == [[0, 1, 0, 1, 0], [1, 0, 1, 0, 1], [2, 2, 2, 2, 2]] assert out.tolist() == [[0, 1, 0, 1, 0], [1, 0, 1, 0, 1], [2, 2, 2, 2, 2]]
@pytest.mark.parametrize('device', devices)
def test_rw_large_with_edge_indices(device):
row = tensor([0, 1, 1, 1, 2, 2, 3, 3, 4, 4], torch.long, device)
col = tensor([1, 0, 2, 3, 1, 4, 1, 4, 2, 3], torch.long, device)
start = tensor([0, 1, 2, 3, 4], torch.long, device)
walk_length = 10
node_seq, edge_seq = random_walk(
row, col, start, walk_length,
return_edge_indices=True,
)
assert node_seq[:, 0].tolist() == start.tolist()
for n in range(start.size(0)):
cur = start[n].item()
for i in range(1, walk_length):
assert node_seq[n, i].item() in col[row == cur].tolist()
cur = node_seq[n, i].item()
assert (edge_seq != -1).all()
@pytest.mark.parametrize('device', devices)
def test_rw_small_with_edge_indices(device):
row = tensor([0, 1], torch.long, device)
col = tensor([1, 0], torch.long, device)
start = tensor([0, 1, 2], torch.long, device)
walk_length = 4
node_seq, edge_seq = random_walk(
row, col, start, walk_length,
num_nodes=3,
return_edge_indices=True,
)
assert node_seq.tolist() == [
[0, 1, 0, 1, 0],
[1, 0, 1, 0, 1],
[2, 2, 2, 2, 2],
]
assert edge_seq.tolist() == [
[0, 1, 0, 1],
[1, 0, 1, 0],
[-1, -1, -1, -1],
]
from typing import Optional from typing import Optional, Tuple, Union
import torch import torch
from torch import Tensor from torch import Tensor
@torch.jit.script @torch.jit.script
def random_walk(row: Tensor, col: Tensor, start: Tensor, walk_length: int, def random_walk(
p: float = 1, q: float = 1, coalesced: bool = True, row: Tensor,
num_nodes: Optional[int] = None) -> Tensor: col: Tensor,
start: Tensor,
walk_length: int,
p: float = 1,
q: float = 1,
coalesced: bool = True,
num_nodes: Optional[int] = None,
return_edge_indices: bool = False,
) -> Union[Tensor, Tuple[Tensor, Tensor]]:
"""Samples random walks of length :obj:`walk_length` from all node indices """Samples random walks of length :obj:`walk_length` from all node indices
in :obj:`start` in the graph given by :obj:`(row, col)` as described in the in :obj:`start` in the graph given by :obj:`(row, col)` as described in the
`"node2vec: Scalable Feature Learning for Networks" `"node2vec: Scalable Feature Learning for Networks"
...@@ -28,6 +36,9 @@ def random_walk(row: Tensor, col: Tensor, start: Tensor, walk_length: int, ...@@ -28,6 +36,9 @@ def random_walk(row: Tensor, col: Tensor, start: Tensor, walk_length: int,
the graph given by :obj:`(row, col)` according to :obj:`row`. the graph given by :obj:`(row, col)` according to :obj:`row`.
(default: :obj:`True`) (default: :obj:`True`)
num_nodes (int, optional): The number of nodes. (default: :obj:`None`) num_nodes (int, optional): The number of nodes. (default: :obj:`None`)
return_edge_indices (bool, optional): Whether to additionally return
the indices of edges traversed during the random walk.
(default: :obj:`False`)
:rtype: :class:`LongTensor` :rtype: :class:`LongTensor`
""" """
...@@ -43,5 +54,11 @@ def random_walk(row: Tensor, col: Tensor, start: Tensor, walk_length: int, ...@@ -43,5 +54,11 @@ def random_walk(row: Tensor, col: Tensor, start: Tensor, walk_length: int,
rowptr = row.new_zeros(num_nodes + 1) rowptr = row.new_zeros(num_nodes + 1)
torch.cumsum(deg, 0, out=rowptr[1:]) torch.cumsum(deg, 0, out=rowptr[1:])
return torch.ops.torch_cluster.random_walk(rowptr, col, start, walk_length, node_seq, edge_seq = torch.ops.torch_cluster.random_walk(
p, q)[0] rowptr, col, start, walk_length, p, q,
)
if return_edge_indices:
return node_seq, edge_seq
return node_seq
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment