Commit e3c3b133 authored by rusty1s's avatar rusty1s
Browse files

flow arg for radius

parent 4047c05d
......@@ -28,7 +28,7 @@ if CUDA_HOME is not None:
['cuda/rw.cpp', 'cuda/rw_kernel.cu']),
]
__version__ = '1.4.1'
__version__ = '1.4.2'
url = 'https://github.com/rusty1s/pytorch_cluster'
install_requires = ['scipy']
......
......@@ -45,12 +45,10 @@ def test_knn_graph(dtype, device):
row, col = knn_graph(x, k=2, flow='target_to_source')
col = col.view(-1, 2).sort(dim=-1)[0].view(-1)
assert row.tolist() == [0, 0, 1, 1, 2, 2, 3, 3]
assert col.tolist() == [1, 3, 0, 2, 1, 3, 0, 2]
row, col = knn_graph(x, k=2, flow='source_to_target')
row = row.view(-1, 2).sort(dim=-1)[0].view(-1)
assert row.tolist() == [1, 3, 0, 2, 1, 3, 0, 2]
assert col.tolist() == [0, 0, 1, 1, 2, 2, 3, 3]
......@@ -47,6 +47,12 @@ def test_radius_graph(dtype, device):
[+1, -1],
], dtype, device)
out = radius_graph(x, r=2)
assert coalesce(out).tolist() == [[0, 0, 1, 1, 2, 2, 3, 3],
[1, 3, 0, 2, 1, 3, 0, 2]]
row, col = radius_graph(x, r=2, flow='target_to_source')
col = col.view(-1, 2).sort(dim=-1)[0].view(-1)
assert row.tolist() == [0, 0, 1, 1, 2, 2, 3, 3]
assert col.tolist() == [1, 3, 0, 2, 1, 3, 0, 2]
row, col = radius_graph(x, r=2, flow='source_to_target')
row = row.view(-1, 2).sort(dim=-1)[0].view(-1)
assert row.tolist() == [1, 3, 0, 2, 1, 3, 0, 2]
assert col.tolist() == [0, 0, 1, 1, 2, 2, 3, 3]
......@@ -7,7 +7,7 @@ from .radius import radius, radius_graph
from .sampler import neighbor_sampler
from .rw import random_walk
__version__ = '1.4.1'
__version__ = '1.4.2'
__all__ = [
'graclus_cluster',
......
......@@ -73,7 +73,12 @@ def radius(x, y, r, batch_x=None, batch_y=None, max_num_neighbors=32):
return torch.stack([row[mask], col[mask]], dim=0)
def radius_graph(x, r, batch=None, loop=False, max_num_neighbors=32):
def radius_graph(x,
r,
batch=None,
loop=False,
max_num_neighbors=32,
flow='source_to_target'):
r"""Computes graph edges to all points within a given distance.
Args:
......@@ -87,6 +92,9 @@ def radius_graph(x, r, batch=None, loop=False, max_num_neighbors=32):
self-loops. (default: :obj:`False`)
max_num_neighbors (int, optional): The maximum number of neighbors to
return for each element in :obj:`y`. (default: :obj:`32`)
flow (string, optional): The flow direction when using in combination
with message passing (:obj:`"source_to_target"` or
:obj:`"target_to_source"`). (default: :obj:`"source_to_target"`)
:rtype: :class:`LongTensor`
......@@ -102,11 +110,10 @@ def radius_graph(x, r, batch=None, loop=False, max_num_neighbors=32):
>>> edge_index = radius_graph(x, r=1.5, batch=batch, loop=False)
"""
edge_index = radius(x, x, r, batch, batch, max_num_neighbors + 1)
row, col = edge_index
assert flow in ['source_to_target', 'target_to_source']
row, col = radius(x, x, r, batch, batch, max_num_neighbors + 1)
row, col = (col, row) if flow == 'source_to_target' else (row, col)
if not loop:
row, col = edge_index
mask = row != col
row, col = row[mask], col[mask]
edge_index = torch.stack([row, col], dim=0)
return edge_index
return torch.stack([row, col], dim=0)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment