Commit fefd2cbe authored by rusty1s's avatar rusty1s
Browse files

add flow to knn call

parent 52214143
...@@ -43,8 +43,14 @@ def test_knn_graph(dtype, device): ...@@ -43,8 +43,14 @@ def test_knn_graph(dtype, device):
[+1, -1], [+1, -1],
], dtype, device) ], dtype, device)
row, col = knn_graph(x, k=2) row, col = knn_graph(x, k=2, flow='target_to_source')
col = col.view(-1, 2).sort(dim=-1)[0].view(-1) col = col.view(-1, 2).sort(dim=-1)[0].view(-1)
assert row.tolist() == [0, 0, 1, 1, 2, 2, 3, 3] assert row.tolist() == [0, 0, 1, 1, 2, 2, 3, 3]
assert col.tolist() == [1, 3, 0, 2, 1, 3, 0, 2] assert col.tolist() == [1, 3, 0, 2, 1, 3, 0, 2]
row, col = knn_graph(x, k=2, flow='source_to_target')
row = row.view(-1, 2).sort(dim=-1)[0].view(-1)
assert row.tolist() == [1, 3, 0, 2, 1, 3, 0, 2]
assert col.tolist() == [0, 0, 1, 1, 2, 2, 3, 3]
...@@ -79,7 +79,7 @@ def knn(x, y, k, batch_x=None, batch_y=None): ...@@ -79,7 +79,7 @@ def knn(x, y, k, batch_x=None, batch_y=None):
return torch.stack([row, col], dim=0) return torch.stack([row, col], dim=0)
def knn_graph(x, k, batch=None, loop=False): def knn_graph(x, k, batch=None, loop=False, flow='source_to_target'):
r"""Computes graph edges to the nearest :obj:`k` points. r"""Computes graph edges to the nearest :obj:`k` points.
Args: Args:
...@@ -91,6 +91,9 @@ def knn_graph(x, k, batch=None, loop=False): ...@@ -91,6 +91,9 @@ def knn_graph(x, k, batch=None, loop=False):
node to a specific example. (default: :obj:`None`) node to a specific example. (default: :obj:`None`)
loop (bool, optional): If :obj:`True`, the graph will contain loop (bool, optional): If :obj:`True`, the graph will contain
self-loops. (default: :obj:`False`) self-loops. (default: :obj:`False`)
flow (string, optional): The flow direction when using in combination
with message passing (:obj:`"source_to_target"` or
:obj:`"target_to_source"`). (default: :obj:`"source_to_target"`)
:rtype: :class:`LongTensor` :rtype: :class:`LongTensor`
...@@ -106,10 +109,10 @@ def knn_graph(x, k, batch=None, loop=False): ...@@ -106,10 +109,10 @@ def knn_graph(x, k, batch=None, loop=False):
>>> edge_index = knn_graph(x, k=2, batch=batch, loop=False) >>> edge_index = knn_graph(x, k=2, batch=batch, loop=False)
""" """
edge_index = knn(x, x, k if loop else k + 1, batch, batch) assert flow in ['source_to_target', 'target_to_source']
row, col = knn(x, x, k if loop else k + 1, batch, batch)
row, col = (col, row) if flow == 'source_to_target' else (row, col)
if not loop: if not loop:
row, col = edge_index
mask = row != col mask = row != col
row, col = row[mask], col[mask] row, col = row[mask], col[mask]
edge_index = torch.stack([row, col], dim=0) return torch.stack([row, col], dim=0)
return edge_index
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment