Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-cluster
Commits
66eef3b3
Commit
66eef3b3
authored
Nov 15, 2018
by
rusty1s
Browse files
removed undirected call to pytorch geometric
parent
50909651
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
7 additions
and
19 deletions
+7
-19
torch_cluster/sample.py
torch_cluster/sample.py
+7
-19
No files found.
torch_cluster/sample.py
View file @
66eef3b3
import
torch
import
torch
from
torch_geometric.utils
import
to_undirected
from
sample_cuda
import
farthest_point_sampling
,
query_radius
,
query_knn
from
sample_cuda
import
farthest_point_sampling
,
query_radius
,
query_knn
...
@@ -62,8 +61,7 @@ def radius_query_edges(batch,
...
@@ -62,8 +61,7 @@ def radius_query_edges(batch,
query_pos
,
query_pos
,
radius
,
radius
,
max_num_neighbors
=
128
,
max_num_neighbors
=
128
,
include_self
=
True
,
include_self
=
True
):
undirected
=
False
):
if
not
pos
.
is_cuda
:
if
not
pos
.
is_cuda
:
raise
NotImplementedError
raise
NotImplementedError
assert
pos
.
is_cuda
and
batch
.
is_cuda
assert
pos
.
is_cuda
and
batch
.
is_cuda
...
@@ -91,19 +89,13 @@ def radius_query_edges(batch,
...
@@ -91,19 +89,13 @@ def radius_query_edges(batch,
return
col
return
col
edge_index
=
torch
.
stack
([
row
,
col
],
dim
=
0
)
edge_index
=
torch
.
stack
([
row
,
col
],
dim
=
0
)
if
undirected
:
return
to_undirected
(
edge_index
,
query_pos
.
size
(
0
))
return
edge_index
return
edge_index
def
radius_graph
(
batch
,
def
radius_graph
(
batch
,
pos
,
radius
,
max_num_neighbors
=
128
,
pos
,
include_self
=
False
):
radius
,
max_num_neighbors
=
128
,
include_self
=
False
,
undirected
=
False
):
return
radius_query_edges
(
batch
,
pos
,
batch
,
pos
,
radius
,
return
radius_query_edges
(
batch
,
pos
,
batch
,
pos
,
radius
,
max_num_neighbors
,
include_self
,
undirected
)
max_num_neighbors
,
include_self
)
def
knn_query_edges
(
batch
,
def
knn_query_edges
(
batch
,
...
@@ -111,8 +103,7 @@ def knn_query_edges(batch,
...
@@ -111,8 +103,7 @@ def knn_query_edges(batch,
query_batch
,
query_batch
,
query_pos
,
query_pos
,
num_neighbors
,
num_neighbors
,
include_self
=
True
,
include_self
=
True
):
undirected
=
False
):
if
not
pos
.
is_cuda
:
if
not
pos
.
is_cuda
:
raise
NotImplementedError
raise
NotImplementedError
assert
pos
.
is_cuda
and
batch
.
is_cuda
assert
pos
.
is_cuda
and
batch
.
is_cuda
...
@@ -140,11 +131,8 @@ def knn_query_edges(batch,
...
@@ -140,11 +131,8 @@ def knn_query_edges(batch,
col
=
view
[
view
!=
-
1
]
col
=
view
[
view
!=
-
1
]
edge_index
=
torch
.
stack
([
row
,
col
],
dim
=
0
)
edge_index
=
torch
.
stack
([
row
,
col
],
dim
=
0
)
if
undirected
:
return
to_undirected
(
edge_index
,
query_pos
.
size
(
0
))
return
edge_index
return
edge_index
def
knn_graph
(
batch
,
pos
,
num_neighbors
,
include_self
=
False
,
undirected
=
False
):
def
knn_graph
(
batch
,
pos
,
num_neighbors
,
include_self
=
False
):
return
knn_query_edges
(
batch
,
pos
,
batch
,
pos
,
num_neighbors
,
include_self
,
return
knn_query_edges
(
batch
,
pos
,
batch
,
pos
,
num_neighbors
,
include_self
)
undirected
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment