Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-cluster
Commits
f371e49e
Commit
f371e49e
authored
Nov 27, 2018
by
rusty1s
Browse files
self-loops flag
parent
377ad11e
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
22 additions
and
12 deletions
+22
-12
cuda/knn_kernel.cu
cuda/knn_kernel.cu
+3
-2
torch_cluster/knn.py
torch_cluster/knn.py
+10
-6
torch_cluster/radius.py
torch_cluster/radius.py
+9
-4
No files found.
cuda/knn_kernel.cu
View file @
f371e49e
...
@@ -64,7 +64,7 @@ at::Tensor knn_cuda(at::Tensor x, at::Tensor y, size_t k, at::Tensor batch_x,
...
@@ -64,7 +64,7 @@ at::Tensor knn_cuda(at::Tensor x, at::Tensor y, size_t k, at::Tensor batch_x,
auto
dist
=
at
::
full
(
y
.
size
(
0
)
*
k
,
1e38
,
y
.
options
());
auto
dist
=
at
::
full
(
y
.
size
(
0
)
*
k
,
1e38
,
y
.
options
());
auto
row
=
at
::
empty
(
y
.
size
(
0
)
*
k
,
batch_y
.
options
());
auto
row
=
at
::
empty
(
y
.
size
(
0
)
*
k
,
batch_y
.
options
());
auto
col
=
at
::
empty
(
y
.
size
(
0
)
*
k
,
batch_y
.
options
());
auto
col
=
at
::
full
(
y
.
size
(
0
)
*
k
,
-
1
,
batch_y
.
options
());
AT_DISPATCH_FLOATING_TYPES
(
x
.
type
(),
"knn_kernel"
,
[
&
]
{
AT_DISPATCH_FLOATING_TYPES
(
x
.
type
(),
"knn_kernel"
,
[
&
]
{
knn_kernel
<
scalar_t
><<<
batch_size
,
THREADS
>>>
(
knn_kernel
<
scalar_t
><<<
batch_size
,
THREADS
>>>
(
...
@@ -73,5 +73,6 @@ at::Tensor knn_cuda(at::Tensor x, at::Tensor y, size_t k, at::Tensor batch_x,
...
@@ -73,5 +73,6 @@ at::Tensor knn_cuda(at::Tensor x, at::Tensor y, size_t k, at::Tensor batch_x,
col
.
data
<
int64_t
>
(),
k
,
x
.
size
(
1
));
col
.
data
<
int64_t
>
(),
k
,
x
.
size
(
1
));
});
});
return
at
::
stack
({
row
,
col
},
0
);
auto
mask
=
col
!=
-
1
;
return
at
::
stack
({
row
.
masked_select
(
mask
),
col
.
masked_select
(
mask
)},
0
);
}
}
torch_cluster/knn.py
View file @
f371e49e
...
@@ -51,7 +51,7 @@ def knn(x, y, k, batch_x=None, batch_y=None):
...
@@ -51,7 +51,7 @@ def knn(x, y, k, batch_x=None, batch_y=None):
return
assign_index
return
assign_index
def
knn_graph
(
x
,
k
,
batch
=
None
):
def
knn_graph
(
x
,
k
,
batch
=
None
,
loop
=
False
):
"""Finds for each element in `x` the `k` nearest points.
"""Finds for each element in `x` the `k` nearest points.
Args:
Args:
...
@@ -62,6 +62,8 @@ def knn_graph(x, k, batch=None):
...
@@ -62,6 +62,8 @@ def knn_graph(x, k, batch=None):
example. If not :obj:`None`, points in the same example need to
example. If not :obj:`None`, points in the same example need to
have contiguous memory layout and :obj:`batch` needs to be
have contiguous memory layout and :obj:`batch` needs to be
ascending. (default: :obj:`None`)
ascending. (default: :obj:`None`)
loop (bool, optional): If :obj:`True`, the graph will contain
self-loops. (default: :obj:`False`)
:rtype: :class:`LongTensor`
:rtype: :class:`LongTensor`
...
@@ -72,8 +74,10 @@ def knn_graph(x, k, batch=None):
...
@@ -72,8 +74,10 @@ def knn_graph(x, k, batch=None):
>>> out = knn_graph(x, 2, batch)
>>> out = knn_graph(x, 2, batch)
"""
"""
edge_index
=
knn
(
x
,
x
,
k
+
1
,
batch
,
batch
)
edge_index
=
knn
(
x
,
x
,
k
if
loop
else
k
+
1
,
batch
,
batch
)
row
,
col
=
edge_index
if
not
loop
:
mask
=
row
!=
col
row
,
col
=
edge_index
row
,
col
=
row
[
mask
],
col
[
mask
]
mask
=
row
!=
col
return
torch
.
stack
([
row
,
col
],
dim
=
0
)
row
,
col
=
row
[
mask
],
col
[
mask
]
edge_index
=
torch
.
stack
([
row
,
col
],
dim
=
0
)
return
edge_index
torch_cluster/radius.py
View file @
f371e49e
...
@@ -53,7 +53,7 @@ def radius(x, y, r, batch_x=None, batch_y=None, max_num_neighbors=32):
...
@@ -53,7 +53,7 @@ def radius(x, y, r, batch_x=None, batch_y=None, max_num_neighbors=32):
return
assign_index
return
assign_index
def
radius_graph
(
x
,
r
,
batch
=
None
,
max_num_neighbors
=
32
):
def
radius_graph
(
x
,
r
,
batch
=
None
,
loop
=
False
,
max_num_neighbors
=
32
):
"""Finds for each element in `x` all points in `x` within distance `r`.
"""Finds for each element in `x` all points in `x` within distance `r`.
Args:
Args:
...
@@ -64,6 +64,8 @@ def radius_graph(x, r, batch=None, max_num_neighbors=32):
...
@@ -64,6 +64,8 @@ def radius_graph(x, r, batch=None, max_num_neighbors=32):
example. If not :obj:`None`, points in the same example need to
example. If not :obj:`None`, points in the same example need to
have contiguous memory layout and :obj:`batch` needs to be
have contiguous memory layout and :obj:`batch` needs to be
ascending. (default: :obj:`None`)
ascending. (default: :obj:`None`)
loop (bool, optional): If :obj:`True`, the graph will contain
self-loops. (default: :obj:`False`)
max_num_neighbors (int, optional): The maximum number of neighbors to
max_num_neighbors (int, optional): The maximum number of neighbors to
return for each element in `y`. (default: :obj:`32`)
return for each element in `y`. (default: :obj:`32`)
...
@@ -78,6 +80,9 @@ def radius_graph(x, r, batch=None, max_num_neighbors=32):
...
@@ -78,6 +80,9 @@ def radius_graph(x, r, batch=None, max_num_neighbors=32):
edge_index
=
radius
(
x
,
x
,
r
,
batch
,
batch
,
max_num_neighbors
+
1
)
edge_index
=
radius
(
x
,
x
,
r
,
batch
,
batch
,
max_num_neighbors
+
1
)
row
,
col
=
edge_index
row
,
col
=
edge_index
mask
=
row
!=
col
if
not
loop
:
row
,
col
=
row
[
mask
],
col
[
mask
]
row
,
col
=
edge_index
return
torch
.
stack
([
row
,
col
],
dim
=
0
)
mask
=
row
!=
col
row
,
col
=
row
[
mask
],
col
[
mask
]
edge_index
=
torch
.
stack
([
row
,
col
],
dim
=
0
)
return
edge_index
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment