You need to sign in or sign up before continuing.
Unverified Commit cc4696bd authored by Piotr Bielak's avatar Piotr Bielak Committed by GitHub
Browse files

Allow returning edge indices from random walk (#139)

This commit adds an optional argument in the `random_walk` function,
namely `return_edge_indices`. The default behaviour is not changed, but
if a user wants to directly use the edges visited by the random walker,
we can return the indices of those edges by setting
`return_edge_indices` to `True`. New cases are also added to the test
suite.
parent c77ed131
...@@ -6,7 +6,7 @@ from .utils import devices, tensor ...@@ -6,7 +6,7 @@ from .utils import devices, tensor
@pytest.mark.parametrize('device', devices) @pytest.mark.parametrize('device', devices)
def test_rw(device): def test_rw_large(device):
row = tensor([0, 1, 1, 1, 2, 2, 3, 3, 4, 4], torch.long, device) row = tensor([0, 1, 1, 1, 2, 2, 3, 3, 4, 4], torch.long, device)
col = tensor([1, 0, 2, 3, 1, 4, 1, 4, 2, 3], torch.long, device) col = tensor([1, 0, 2, 3, 1, 4, 1, 4, 2, 3], torch.long, device)
start = tensor([0, 1, 2, 3, 4], torch.long, device) start = tensor([0, 1, 2, 3, 4], torch.long, device)
...@@ -21,6 +21,9 @@ def test_rw(device): ...@@ -21,6 +21,9 @@ def test_rw(device):
assert out[n, i].item() in col[row == cur].tolist() assert out[n, i].item() in col[row == cur].tolist()
cur = out[n, i].item() cur = out[n, i].item()
@pytest.mark.parametrize('device', devices)
def test_rw_small(device):
row = tensor([0, 1], torch.long, device) row = tensor([0, 1], torch.long, device)
col = tensor([1, 0], torch.long, device) col = tensor([1, 0], torch.long, device)
start = tensor([0, 1, 2], torch.long, device) start = tensor([0, 1, 2], torch.long, device)
...@@ -28,3 +31,49 @@ def test_rw(device): ...@@ -28,3 +31,49 @@ def test_rw(device):
out = random_walk(row, col, start, walk_length, num_nodes=3) out = random_walk(row, col, start, walk_length, num_nodes=3)
assert out.tolist() == [[0, 1, 0, 1, 0], [1, 0, 1, 0, 1], [2, 2, 2, 2, 2]] assert out.tolist() == [[0, 1, 0, 1, 0], [1, 0, 1, 0, 1], [2, 2, 2, 2, 2]]
@pytest.mark.parametrize('device', devices)
def test_rw_large_with_edge_indices(device):
row = tensor([0, 1, 1, 1, 2, 2, 3, 3, 4, 4], torch.long, device)
col = tensor([1, 0, 2, 3, 1, 4, 1, 4, 2, 3], torch.long, device)
start = tensor([0, 1, 2, 3, 4], torch.long, device)
walk_length = 10
node_seq, edge_seq = random_walk(
row, col, start, walk_length,
return_edge_indices=True,
)
assert node_seq[:, 0].tolist() == start.tolist()
for n in range(start.size(0)):
cur = start[n].item()
for i in range(1, walk_length):
assert node_seq[n, i].item() in col[row == cur].tolist()
cur = node_seq[n, i].item()
assert (edge_seq != -1).all()
@pytest.mark.parametrize('device', devices)
def test_rw_small_with_edge_indices(device):
row = tensor([0, 1], torch.long, device)
col = tensor([1, 0], torch.long, device)
start = tensor([0, 1, 2], torch.long, device)
walk_length = 4
node_seq, edge_seq = random_walk(
row, col, start, walk_length,
num_nodes=3,
return_edge_indices=True,
)
assert node_seq.tolist() == [
[0, 1, 0, 1, 0],
[1, 0, 1, 0, 1],
[2, 2, 2, 2, 2],
]
assert edge_seq.tolist() == [
[0, 1, 0, 1],
[1, 0, 1, 0],
[-1, -1, -1, -1],
]
from typing import Optional from typing import Optional, Tuple, Union
import torch import torch
from torch import Tensor from torch import Tensor
@torch.jit.script @torch.jit.script
def random_walk(row: Tensor, col: Tensor, start: Tensor, walk_length: int, def random_walk(
p: float = 1, q: float = 1, coalesced: bool = True, row: Tensor,
num_nodes: Optional[int] = None) -> Tensor: col: Tensor,
start: Tensor,
walk_length: int,
p: float = 1,
q: float = 1,
coalesced: bool = True,
num_nodes: Optional[int] = None,
return_edge_indices: bool = False,
) -> Union[Tensor, Tuple[Tensor, Tensor]]:
"""Samples random walks of length :obj:`walk_length` from all node indices """Samples random walks of length :obj:`walk_length` from all node indices
in :obj:`start` in the graph given by :obj:`(row, col)` as described in the in :obj:`start` in the graph given by :obj:`(row, col)` as described in the
`"node2vec: Scalable Feature Learning for Networks" `"node2vec: Scalable Feature Learning for Networks"
...@@ -28,6 +36,9 @@ def random_walk(row: Tensor, col: Tensor, start: Tensor, walk_length: int, ...@@ -28,6 +36,9 @@ def random_walk(row: Tensor, col: Tensor, start: Tensor, walk_length: int,
the graph given by :obj:`(row, col)` according to :obj:`row`. the graph given by :obj:`(row, col)` according to :obj:`row`.
(default: :obj:`True`) (default: :obj:`True`)
num_nodes (int, optional): The number of nodes. (default: :obj:`None`) num_nodes (int, optional): The number of nodes. (default: :obj:`None`)
return_edge_indices (bool, optional): Whether to additionally return
the indices of edges traversed during the random walk.
(default: :obj:`False`)
:rtype: :class:`LongTensor` :rtype: :class:`LongTensor`
""" """
...@@ -43,5 +54,11 @@ def random_walk(row: Tensor, col: Tensor, start: Tensor, walk_length: int, ...@@ -43,5 +54,11 @@ def random_walk(row: Tensor, col: Tensor, start: Tensor, walk_length: int,
rowptr = row.new_zeros(num_nodes + 1) rowptr = row.new_zeros(num_nodes + 1)
torch.cumsum(deg, 0, out=rowptr[1:]) torch.cumsum(deg, 0, out=rowptr[1:])
return torch.ops.torch_cluster.random_walk(rowptr, col, start, walk_length, node_seq, edge_seq = torch.ops.torch_cluster.random_walk(
p, q)[0] rowptr, col, start, walk_length, p, q,
)
if return_edge_indices:
return node_seq, edge_seq
return node_seq
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment