Commit f371e49e authored by rusty1s's avatar rusty1s
Browse files

self-loops flag

parent 377ad11e
...@@ -64,7 +64,7 @@ at::Tensor knn_cuda(at::Tensor x, at::Tensor y, size_t k, at::Tensor batch_x, ...@@ -64,7 +64,7 @@ at::Tensor knn_cuda(at::Tensor x, at::Tensor y, size_t k, at::Tensor batch_x,
auto dist = at::full(y.size(0) * k, 1e38, y.options()); auto dist = at::full(y.size(0) * k, 1e38, y.options());
auto row = at::empty(y.size(0) * k, batch_y.options()); auto row = at::empty(y.size(0) * k, batch_y.options());
auto col = at::empty(y.size(0) * k, batch_y.options()); auto col = at::full(y.size(0) * k, -1, batch_y.options());
AT_DISPATCH_FLOATING_TYPES(x.type(), "knn_kernel", [&] { AT_DISPATCH_FLOATING_TYPES(x.type(), "knn_kernel", [&] {
knn_kernel<scalar_t><<<batch_size, THREADS>>>( knn_kernel<scalar_t><<<batch_size, THREADS>>>(
...@@ -73,5 +73,6 @@ at::Tensor knn_cuda(at::Tensor x, at::Tensor y, size_t k, at::Tensor batch_x, ...@@ -73,5 +73,6 @@ at::Tensor knn_cuda(at::Tensor x, at::Tensor y, size_t k, at::Tensor batch_x,
col.data<int64_t>(), k, x.size(1)); col.data<int64_t>(), k, x.size(1));
}); });
return at::stack({row, col}, 0); auto mask = col != -1;
return at::stack({row.masked_select(mask), col.masked_select(mask)}, 0);
} }
...@@ -51,7 +51,7 @@ def knn(x, y, k, batch_x=None, batch_y=None): ...@@ -51,7 +51,7 @@ def knn(x, y, k, batch_x=None, batch_y=None):
return assign_index return assign_index
def knn_graph(x, k, batch=None): def knn_graph(x, k, batch=None, loop=False):
"""Finds for each element in `x` the `k` nearest points. """Finds for each element in `x` the `k` nearest points.
Args: Args:
...@@ -62,6 +62,8 @@ def knn_graph(x, k, batch=None): ...@@ -62,6 +62,8 @@ def knn_graph(x, k, batch=None):
example. If not :obj:`None`, points in the same example need to example. If not :obj:`None`, points in the same example need to
have contiguous memory layout and :obj:`batch` needs to be have contiguous memory layout and :obj:`batch` needs to be
ascending. (default: :obj:`None`) ascending. (default: :obj:`None`)
loop (bool, optional): If :obj:`True`, the graph will contain
self-loops. (default: :obj:`False`)
:rtype: :class:`LongTensor` :rtype: :class:`LongTensor`
...@@ -72,8 +74,10 @@ def knn_graph(x, k, batch=None): ...@@ -72,8 +74,10 @@ def knn_graph(x, k, batch=None):
>>> out = knn_graph(x, 2, batch) >>> out = knn_graph(x, 2, batch)
""" """
edge_index = knn(x, x, k + 1, batch, batch) edge_index = knn(x, x, k if loop else k + 1, batch, batch)
row, col = edge_index if not loop:
mask = row != col row, col = edge_index
row, col = row[mask], col[mask] mask = row != col
return torch.stack([row, col], dim=0) row, col = row[mask], col[mask]
edge_index = torch.stack([row, col], dim=0)
return edge_index
...@@ -53,7 +53,7 @@ def radius(x, y, r, batch_x=None, batch_y=None, max_num_neighbors=32): ...@@ -53,7 +53,7 @@ def radius(x, y, r, batch_x=None, batch_y=None, max_num_neighbors=32):
return assign_index return assign_index
def radius_graph(x, r, batch=None, max_num_neighbors=32): def radius_graph(x, r, batch=None, loop=False, max_num_neighbors=32):
"""Finds for each element in `x` all points in `x` within distance `r`. """Finds for each element in `x` all points in `x` within distance `r`.
Args: Args:
...@@ -64,6 +64,8 @@ def radius_graph(x, r, batch=None, max_num_neighbors=32): ...@@ -64,6 +64,8 @@ def radius_graph(x, r, batch=None, max_num_neighbors=32):
example. If not :obj:`None`, points in the same example need to example. If not :obj:`None`, points in the same example need to
have contiguous memory layout and :obj:`batch` needs to be have contiguous memory layout and :obj:`batch` needs to be
ascending. (default: :obj:`None`) ascending. (default: :obj:`None`)
loop (bool, optional): If :obj:`True`, the graph will contain
self-loops. (default: :obj:`False`)
max_num_neighbors (int, optional): The maximum number of neighbors to max_num_neighbors (int, optional): The maximum number of neighbors to
return for each element in `y`. (default: :obj:`32`) return for each element in `y`. (default: :obj:`32`)
...@@ -78,6 +80,9 @@ def radius_graph(x, r, batch=None, max_num_neighbors=32): ...@@ -78,6 +80,9 @@ def radius_graph(x, r, batch=None, max_num_neighbors=32):
edge_index = radius(x, x, r, batch, batch, max_num_neighbors + 1) edge_index = radius(x, x, r, batch, batch, max_num_neighbors + 1)
row, col = edge_index row, col = edge_index
mask = row != col if not loop:
row, col = row[mask], col[mask] row, col = edge_index
return torch.stack([row, col], dim=0) mask = row != col
row, col = row[mask], col[mask]
edge_index = torch.stack([row, col], dim=0)
return edge_index
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment