"src/libtorio/ffmpeg/pybind/pybind.cpp" did not exist on "59f067b78838ef49b8b8399496b2a745ad9b2b92"
Commit e3c3b133 authored by rusty1s's avatar rusty1s
Browse files

flow arg for radius

parent 4047c05d
...@@ -28,7 +28,7 @@ if CUDA_HOME is not None: ...@@ -28,7 +28,7 @@ if CUDA_HOME is not None:
['cuda/rw.cpp', 'cuda/rw_kernel.cu']), ['cuda/rw.cpp', 'cuda/rw_kernel.cu']),
] ]
__version__ = '1.4.1' __version__ = '1.4.2'
url = 'https://github.com/rusty1s/pytorch_cluster' url = 'https://github.com/rusty1s/pytorch_cluster'
install_requires = ['scipy'] install_requires = ['scipy']
......
...@@ -45,12 +45,10 @@ def test_knn_graph(dtype, device): ...@@ -45,12 +45,10 @@ def test_knn_graph(dtype, device):
row, col = knn_graph(x, k=2, flow='target_to_source') row, col = knn_graph(x, k=2, flow='target_to_source')
col = col.view(-1, 2).sort(dim=-1)[0].view(-1) col = col.view(-1, 2).sort(dim=-1)[0].view(-1)
assert row.tolist() == [0, 0, 1, 1, 2, 2, 3, 3] assert row.tolist() == [0, 0, 1, 1, 2, 2, 3, 3]
assert col.tolist() == [1, 3, 0, 2, 1, 3, 0, 2] assert col.tolist() == [1, 3, 0, 2, 1, 3, 0, 2]
row, col = knn_graph(x, k=2, flow='source_to_target') row, col = knn_graph(x, k=2, flow='source_to_target')
row = row.view(-1, 2).sort(dim=-1)[0].view(-1) row = row.view(-1, 2).sort(dim=-1)[0].view(-1)
assert row.tolist() == [1, 3, 0, 2, 1, 3, 0, 2] assert row.tolist() == [1, 3, 0, 2, 1, 3, 0, 2]
assert col.tolist() == [0, 0, 1, 1, 2, 2, 3, 3] assert col.tolist() == [0, 0, 1, 1, 2, 2, 3, 3]
...@@ -47,6 +47,12 @@ def test_radius_graph(dtype, device): ...@@ -47,6 +47,12 @@ def test_radius_graph(dtype, device):
[+1, -1], [+1, -1],
], dtype, device) ], dtype, device)
out = radius_graph(x, r=2) row, col = radius_graph(x, r=2, flow='target_to_source')
assert coalesce(out).tolist() == [[0, 0, 1, 1, 2, 2, 3, 3], col = col.view(-1, 2).sort(dim=-1)[0].view(-1)
[1, 3, 0, 2, 1, 3, 0, 2]] assert row.tolist() == [0, 0, 1, 1, 2, 2, 3, 3]
assert col.tolist() == [1, 3, 0, 2, 1, 3, 0, 2]
row, col = radius_graph(x, r=2, flow='source_to_target')
row = row.view(-1, 2).sort(dim=-1)[0].view(-1)
assert row.tolist() == [1, 3, 0, 2, 1, 3, 0, 2]
assert col.tolist() == [0, 0, 1, 1, 2, 2, 3, 3]
...@@ -7,7 +7,7 @@ from .radius import radius, radius_graph ...@@ -7,7 +7,7 @@ from .radius import radius, radius_graph
from .sampler import neighbor_sampler from .sampler import neighbor_sampler
from .rw import random_walk from .rw import random_walk
__version__ = '1.4.1' __version__ = '1.4.2'
__all__ = [ __all__ = [
'graclus_cluster', 'graclus_cluster',
......
...@@ -73,7 +73,12 @@ def radius(x, y, r, batch_x=None, batch_y=None, max_num_neighbors=32): ...@@ -73,7 +73,12 @@ def radius(x, y, r, batch_x=None, batch_y=None, max_num_neighbors=32):
return torch.stack([row[mask], col[mask]], dim=0) return torch.stack([row[mask], col[mask]], dim=0)
def radius_graph(x, r, batch=None, loop=False, max_num_neighbors=32): def radius_graph(x,
r,
batch=None,
loop=False,
max_num_neighbors=32,
flow='source_to_target'):
r"""Computes graph edges to all points within a given distance. r"""Computes graph edges to all points within a given distance.
Args: Args:
...@@ -87,6 +92,9 @@ def radius_graph(x, r, batch=None, loop=False, max_num_neighbors=32): ...@@ -87,6 +92,9 @@ def radius_graph(x, r, batch=None, loop=False, max_num_neighbors=32):
self-loops. (default: :obj:`False`) self-loops. (default: :obj:`False`)
max_num_neighbors (int, optional): The maximum number of neighbors to max_num_neighbors (int, optional): The maximum number of neighbors to
return for each element in :obj:`y`. (default: :obj:`32`) return for each element in :obj:`y`. (default: :obj:`32`)
flow (string, optional): The flow direction when using in combination
with message passing (:obj:`"source_to_target"` or
:obj:`"target_to_source"`). (default: :obj:`"source_to_target"`)
:rtype: :class:`LongTensor` :rtype: :class:`LongTensor`
...@@ -102,11 +110,10 @@ def radius_graph(x, r, batch=None, loop=False, max_num_neighbors=32): ...@@ -102,11 +110,10 @@ def radius_graph(x, r, batch=None, loop=False, max_num_neighbors=32):
>>> edge_index = radius_graph(x, r=1.5, batch=batch, loop=False) >>> edge_index = radius_graph(x, r=1.5, batch=batch, loop=False)
""" """
edge_index = radius(x, x, r, batch, batch, max_num_neighbors + 1) assert flow in ['source_to_target', 'target_to_source']
row, col = edge_index row, col = radius(x, x, r, batch, batch, max_num_neighbors + 1)
row, col = (col, row) if flow == 'source_to_target' else (row, col)
if not loop: if not loop:
row, col = edge_index
mask = row != col mask = row != col
row, col = row[mask], col[mask] row, col = row[mask], col[mask]
edge_index = torch.stack([row, col], dim=0) return torch.stack([row, col], dim=0)
return edge_index
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment