"tests/vscode:/vscode.git/clone" did not exist on "571340dac63d3a09e5d66d45244f9f13bb175d00"
Commit f371e49e authored by rusty1s's avatar rusty1s
Browse files

self-loops flag

parent 377ad11e
......@@ -64,7 +64,7 @@ at::Tensor knn_cuda(at::Tensor x, at::Tensor y, size_t k, at::Tensor batch_x,
auto dist = at::full(y.size(0) * k, 1e38, y.options());
auto row = at::empty(y.size(0) * k, batch_y.options());
auto col = at::empty(y.size(0) * k, batch_y.options());
auto col = at::full(y.size(0) * k, -1, batch_y.options());
AT_DISPATCH_FLOATING_TYPES(x.type(), "knn_kernel", [&] {
knn_kernel<scalar_t><<<batch_size, THREADS>>>(
......@@ -73,5 +73,6 @@ at::Tensor knn_cuda(at::Tensor x, at::Tensor y, size_t k, at::Tensor batch_x,
col.data<int64_t>(), k, x.size(1));
});
return at::stack({row, col}, 0);
auto mask = col != -1;
return at::stack({row.masked_select(mask), col.masked_select(mask)}, 0);
}
......@@ -51,7 +51,7 @@ def knn(x, y, k, batch_x=None, batch_y=None):
return assign_index
def knn_graph(x, k, batch=None):
def knn_graph(x, k, batch=None, loop=False):
"""Finds for each element in `x` the `k` nearest points.
Args:
......@@ -62,6 +62,8 @@ def knn_graph(x, k, batch=None):
example. If not :obj:`None`, points in the same example need to
have contiguous memory layout and :obj:`batch` needs to be
ascending. (default: :obj:`None`)
loop (bool, optional): If :obj:`True`, the graph will contain
self-loops. (default: :obj:`False`)
:rtype: :class:`LongTensor`
......@@ -72,8 +74,10 @@ def knn_graph(x, k, batch=None):
>>> out = knn_graph(x, 2, batch)
"""
edge_index = knn(x, x, k + 1, batch, batch)
edge_index = knn(x, x, k if loop else k + 1, batch, batch)
if not loop:
row, col = edge_index
mask = row != col
row, col = row[mask], col[mask]
return torch.stack([row, col], dim=0)
edge_index = torch.stack([row, col], dim=0)
return edge_index
......@@ -53,7 +53,7 @@ def radius(x, y, r, batch_x=None, batch_y=None, max_num_neighbors=32):
return assign_index
def radius_graph(x, r, batch=None, max_num_neighbors=32):
def radius_graph(x, r, batch=None, loop=False, max_num_neighbors=32):
"""Finds for each element in `x` all points in `x` within distance `r`.
Args:
......@@ -64,6 +64,8 @@ def radius_graph(x, r, batch=None, max_num_neighbors=32):
example. If not :obj:`None`, points in the same example need to
have contiguous memory layout and :obj:`batch` needs to be
ascending. (default: :obj:`None`)
loop (bool, optional): If :obj:`True`, the graph will contain
self-loops. (default: :obj:`False`)
max_num_neighbors (int, optional): The maximum number of neighbors to
return for each element in `y`. (default: :obj:`32`)
......@@ -77,7 +79,10 @@ def radius_graph(x, r, batch=None, max_num_neighbors=32):
"""
edge_index = radius(x, x, r, batch, batch, max_num_neighbors + 1)
row, col = edge_index
if not loop:
row, col = edge_index
mask = row != col
row, col = row[mask], col[mask]
return torch.stack([row, col], dim=0)
edge_index = torch.stack([row, col], dim=0)
return edge_index
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment