OpenDAS / torch-cluster · Commits

Unverified commit b75a19e8, authored Jun 22, 2020 by Matthias Fey, committed by GitHub on Jun 22, 2020

Merge pull request #68 from liaopeiyuan/cpu_radius

C++ CPU for radius and knn

Parents: 32fa3257, 29f97162
Showing 2 changed files with 93 additions and 132 deletions:

    torch_cluster/knn.py     +48 -71
    torch_cluster/radius.py  +45 -61
torch_cluster/knn.py  (view file @ b75a19e8)

 from typing import Optional
 
 import torch
-import scipy.spatial
 
 
+@torch.jit.script
 def knn(x: torch.Tensor, y: torch.Tensor, k: int,
         batch_x: Optional[torch.Tensor] = None,
         batch_y: Optional[torch.Tensor] = None,
-        cosine: bool = False) -> torch.Tensor:
+        cosine: bool = False,
+        num_workers: int = 1) -> torch.Tensor:
     r"""Finds for each element in :obj:`y` the :obj:`k` nearest points in
     :obj:`x`.
...
@@ -19,13 +19,18 @@ def knn(x: torch.Tensor, y: torch.Tensor, k: int,
         k (int): The number of neighbors.
         batch_x (LongTensor, optional): Batch vector
             :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
-            node to a specific example. (default: :obj:`None`)
+            node to a specific example. :obj:`batch_x` needs to be sorted.
+            (default: :obj:`None`)
         batch_y (LongTensor, optional): Batch vector
             :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^M`, which assigns each
-            node to a specific example. (default: :obj:`None`)
+            node to a specific example. :obj:`batch_y` needs to be sorted.
+            (default: :obj:`None`)
-        cosine (boolean, optional): If :obj:`True`, will use the cosine
-            distance instead of euclidean distance to find nearest neighbors.
-            (default: :obj:`False`)
+        cosine (boolean, optional): If :obj:`True`, will use the Cosine
+            distance instead of the Euclidean distance to find nearest
+            neighbors. (default: :obj:`False`)
+        num_workers (int): Number of workers to use for computation. Has no
+            effect in case :obj:`batch_x` or :obj:`batch_y` is not
+            :obj:`None`, or the input lies on the GPU. (default: :obj:`1`)
 
     :rtype: :class:`LongTensor`
...
@@ -43,77 +48,38 @@ def knn(x: torch.Tensor, y: torch.Tensor, k: int,
     x = x.view(-1, 1) if x.dim() == 1 else x
     y = y.view(-1, 1) if y.dim() == 1 else y
+    x, y = x.contiguous(), y.contiguous()
 
-    if x.is_cuda:
-        if batch_x is not None:
-            assert x.size(0) == batch_x.numel()
-            batch_size = int(batch_x.max()) + 1
-            deg = x.new_zeros(batch_size, dtype=torch.long)
-            deg.scatter_add_(0, batch_x, torch.ones_like(batch_x))
-            ptr_x = deg.new_zeros(batch_size + 1)
-            torch.cumsum(deg, 0, out=ptr_x[1:])
-        else:
-            ptr_x = torch.tensor([0, x.size(0)], device=x.device)
-        if batch_y is not None:
-            assert y.size(0) == batch_y.numel()
-            batch_size = int(batch_y.max()) + 1
-            deg = y.new_zeros(batch_size, dtype=torch.long)
-            deg.scatter_add_(0, batch_y, torch.ones_like(batch_y))
-            ptr_y = deg.new_zeros(batch_size + 1)
-            torch.cumsum(deg, 0, out=ptr_y[1:])
-        else:
-            ptr_y = torch.tensor([0, y.size(0)], device=y.device)
-        return torch.ops.torch_cluster.knn(x, y, ptr_x, ptr_y, k, cosine)
-    else:
-        if batch_x is None:
-            batch_x = x.new_zeros(x.size(0), dtype=torch.long)
-        if batch_y is None:
-            batch_y = y.new_zeros(y.size(0), dtype=torch.long)
-
-        assert x.dim() == 2 and batch_x.dim() == 1
-        assert y.dim() == 2 and batch_y.dim() == 1
-        assert x.size(1) == y.size(1)
-        assert x.size(0) == batch_x.size(0)
-        assert y.size(0) == batch_y.size(0)
-
-        if cosine:
-            raise NotImplementedError('`cosine` argument not supported on CPU')
-
-        # Translate and rescale x and y to [0, 1].
-        min_xy = min(x.min().item(), y.min().item())
-        x, y = x - min_xy, y - min_xy
-        max_xy = max(x.max().item(), y.max().item())
-        x.div_(max_xy)
-        y.div_(max_xy)
-
-        # Concat batch/features to ensure no cross-links between examples.
-        x = torch.cat([x, 2 * x.size(1) * batch_x.view(-1, 1).to(x.dtype)], -1)
-        y = torch.cat([y, 2 * y.size(1) * batch_y.view(-1, 1).to(y.dtype)], -1)
-
-        tree = scipy.spatial.cKDTree(x.detach().numpy())
-        dist, col = tree.query(y.detach().cpu(), k=k,
-                               distance_upper_bound=x.size(1))
-        dist = torch.from_numpy(dist).to(x.dtype)
-        col = torch.from_numpy(col).to(torch.long)
-        row = torch.arange(col.size(0), dtype=torch.long)
-        row = row.view(-1, 1).repeat(1, k)
-        mask = ~torch.isinf(dist).view(-1)
-        row, col = row.view(-1)[mask], col.view(-1)[mask]
-        return torch.stack([row, col], dim=0)
+    ptr_x: Optional[torch.Tensor] = None
+    if batch_x is not None:
+        assert x.size(0) == batch_x.numel()
+        batch_size = int(batch_x.max()) + 1
+        deg = x.new_zeros(batch_size, dtype=torch.long)
+        deg.scatter_add_(0, batch_x, torch.ones_like(batch_x))
+        ptr_x = deg.new_zeros(batch_size + 1)
+        torch.cumsum(deg, 0, out=ptr_x[1:])
+
+    ptr_y: Optional[torch.Tensor] = None
+    if batch_y is not None:
+        assert y.size(0) == batch_y.numel()
+        batch_size = int(batch_y.max()) + 1
+        deg = y.new_zeros(batch_size, dtype=torch.long)
+        deg.scatter_add_(0, batch_y, torch.ones_like(batch_y))
+        ptr_y = deg.new_zeros(batch_size + 1)
+        torch.cumsum(deg, 0, out=ptr_y[1:])
+
+    return torch.ops.torch_cluster.knn(x, y, ptr_x, ptr_y, k, cosine,
+                                       num_workers)
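
Note: the rewritten function hands the C++/CUDA operator a CSR-style pointer tensor built from a sorted batch vector rather than the batch vector itself. A standalone sketch of what the scatter_add_/cumsum lines above compute, using an invented batch vector:

.. code-block:: python

    import torch

    batch_x = torch.tensor([0, 0, 1, 1, 1])            # sorted batch vector (made up)

    batch_size = int(batch_x.max()) + 1                # 2 examples
    deg = batch_x.new_zeros(batch_size)                # number of nodes per example
    deg.scatter_add_(0, batch_x, torch.ones_like(batch_x))
    print(deg)                                         # tensor([2, 3])

    ptr_x = deg.new_zeros(batch_size + 1)              # ptr_x[i]:ptr_x[i + 1] spans the
    torch.cumsum(deg, 0, out=ptr_x[1:])                # rows of example i in x
    print(ptr_x)                                       # tensor([0, 2, 5])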
 
 
+@torch.jit.script
 def knn_graph(x: torch.Tensor, k: int, batch: Optional[torch.Tensor] = None,
               loop: bool = False, flow: str = 'source_to_target',
-              cosine: bool = False) -> torch.Tensor:
+              cosine: bool = False,
+              num_workers: int = 1) -> torch.Tensor:
     r"""Computes graph edges to the nearest :obj:`k` points.
 
     Args:
...
@@ -122,7 +88,8 @@ def knn_graph(x: torch.Tensor, k: int, batch: Optional[torch.Tensor] = None,
         k (int): The number of neighbors.
         batch (LongTensor, optional): Batch vector
             :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
-            node to a specific example. (default: :obj:`None`)
+            node to a specific example. :obj:`batch` needs to be sorted.
+            (default: :obj:`None`)
         loop (bool, optional): If :obj:`True`, the graph will contain
             self-loops. (default: :obj:`False`)
         flow (string, optional): The flow direction when using in combination
...
@@ -131,6 +98,9 @@ def knn_graph(x: torch.Tensor, k: int, batch: Optional[torch.Tensor] = None,
         cosine (boolean, optional): If :obj:`True`, will use the cosine
             distance instead of euclidean distance to find nearest neighbors.
             (default: :obj:`False`)
+        num_workers (int): Number of workers to use for computation. Has no
+            effect in case :obj:`batch` is not :obj:`None`, or the input lies
+            on the GPU. (default: :obj:`1`)
 
     :rtype: :class:`LongTensor`
...
@@ -145,9 +115,16 @@ def knn_graph(x: torch.Tensor, k: int, batch: Optional[torch.Tensor] = None,
     """
 
     assert flow in ['source_to_target', 'target_to_source']
-    row, col = knn(x, x, k if loop else k + 1, batch, batch, cosine=cosine)
-    row, col = (col, row) if flow == 'source_to_target' else (row, col)
+    edge_index = knn(x, x, k if loop else k + 1, batch, batch, cosine,
+                     num_workers)
+
+    if flow == 'source_to_target':
+        row, col = edge_index[1], edge_index[0]
+    else:
+        row, col = edge_index[0], edge_index[1]
+
     if not loop:
         mask = row != col
         row, col = row[mask], col[mask]
+
     return torch.stack([row, col], dim=0)
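
A minimal usage sketch for the updated knn/knn_graph signatures; the toy tensors below are invented for illustration and mirror the style of the docstring examples:

.. code-block:: python

    import torch
    from torch_cluster import knn, knn_graph

    x = torch.tensor([[-1.0, -1.0], [-1.0, 1.0], [1.0, -1.0], [1.0, 1.0]])
    batch_x = torch.tensor([0, 0, 0, 0])
    y = torch.tensor([[-1.0, 0.0], [1.0, 0.0]])
    batch_y = torch.tensor([0, 0])

    # For each point in y, the indices of its 2 nearest points in x,
    # returned as a [2, num_matches] LongTensor pairing y and x indices.
    assign_index = knn(x, y, 2, batch_x, batch_y)

    # k-NN graph over x itself, without self-loops.  num_workers is only
    # used by the CPU kernel when no batch vector is passed.
    edge_index = knn_graph(x, k=2, batch=batch_x, loop=False)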
torch_cluster/radius.py  (view file @ b75a19e8)

 from typing import Optional
 
 import torch
-import scipy.spatial
 
 
-@torch.jit.script
-def sample(col: torch.Tensor, count: int) -> torch.Tensor:
-    if col.size(0) > count:
-        col = col[torch.randperm(col.size(0), dtype=torch.long)][:count]
-    return col
-
-
+@torch.jit.script
 def radius(x: torch.Tensor, y: torch.Tensor, r: float,
            batch_x: Optional[torch.Tensor] = None,
            batch_y: Optional[torch.Tensor] = None,
-           max_num_neighbors: int = 32) -> torch.Tensor:
+           max_num_neighbors: int = 32,
+           num_workers: int = 1) -> torch.Tensor:
     r"""Finds for each element in :obj:`y` all points in :obj:`x` within
     distance :obj:`r`.
...
@@ -26,12 +19,17 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float,
         r (float): The radius.
         batch_x (LongTensor, optional): Batch vector
             :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
-            node to a specific example. (default: :obj:`None`)
+            node to a specific example. :obj:`batch_x` needs to be sorted.
+            (default: :obj:`None`)
         batch_y (LongTensor, optional): Batch vector
             :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^M`, which assigns each
-            node to a specific example. (default: :obj:`None`)
+            node to a specific example. :obj:`batch_y` needs to be sorted.
+            (default: :obj:`None`)
         max_num_neighbors (int, optional): The maximum number of neighbors to
             return for each element in :obj:`y`. (default: :obj:`32`)
+        num_workers (int): Number of workers to use for computation. Has no
+            effect in case :obj:`batch_x` or :obj:`batch_y` is not
+            :obj:`None`, or the input lies on the GPU. (default: :obj:`1`)
 
     .. code-block:: python
...
@@ -47,63 +45,39 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float,
     x = x.view(-1, 1) if x.dim() == 1 else x
     y = y.view(-1, 1) if y.dim() == 1 else y
+    x, y = x.contiguous(), y.contiguous()
 
-    if x.is_cuda:
-        if batch_x is not None:
-            assert x.size(0) == batch_x.numel()
-            batch_size = int(batch_x.max()) + 1
-            deg = x.new_zeros(batch_size, dtype=torch.long)
-            deg.scatter_add_(0, batch_x, torch.ones_like(batch_x))
-            ptr_x = deg.new_zeros(batch_size + 1)
-            torch.cumsum(deg, 0, out=ptr_x[1:])
-        else:
-            ptr_x = torch.tensor([0, x.size(0)], device=x.device)
-        if batch_y is not None:
-            assert y.size(0) == batch_y.numel()
-            batch_size = int(batch_y.max()) + 1
-            deg = y.new_zeros(batch_size, dtype=torch.long)
-            deg.scatter_add_(0, batch_y, torch.ones_like(batch_y))
-            ptr_y = deg.new_zeros(batch_size + 1)
-            torch.cumsum(deg, 0, out=ptr_y[1:])
-        else:
-            ptr_y = torch.tensor([0, y.size(0)], device=y.device)
-        return torch.ops.torch_cluster.radius(x, y, ptr_x, ptr_y, r,
-                                              max_num_neighbors)
-    else:
-        if batch_x is None:
-            batch_x = x.new_zeros(x.size(0), dtype=torch.long)
-        if batch_y is None:
-            batch_y = y.new_zeros(y.size(0), dtype=torch.long)
-
-        assert x.dim() == 2 and batch_x.dim() == 1
-        assert y.dim() == 2 and batch_y.dim() == 1
-        assert x.size(1) == y.size(1)
-        assert x.size(0) == batch_x.size(0)
-        assert y.size(0) == batch_y.size(0)
-
-        x = torch.cat([x, 2 * r * batch_x.view(-1, 1).to(x.dtype)], dim=-1)
-        y = torch.cat([y, 2 * r * batch_y.view(-1, 1).to(y.dtype)], dim=-1)
-
-        tree = scipy.spatial.cKDTree(x.detach().numpy())
-        col = tree.query_ball_point(y.detach().numpy(), r)
-        col = [torch.tensor(c, dtype=torch.long) for c in col]
-        col = [sample(c, max_num_neighbors) for c in col]
-        row = [torch.full_like(c, i) for i, c in enumerate(col)]
-        row, col = torch.cat(row, dim=0), torch.cat(col, dim=0)
-        mask = col < int(tree.n)
-        return torch.stack([row[mask], col[mask]], dim=0)
+    ptr_x: Optional[torch.Tensor] = None
+    if batch_x is not None:
+        assert x.size(0) == batch_x.numel()
+        batch_size = int(batch_x.max()) + 1
+        deg = x.new_zeros(batch_size, dtype=torch.long)
+        deg.scatter_add_(0, batch_x, torch.ones_like(batch_x))
+        ptr_x = deg.new_zeros(batch_size + 1)
+        torch.cumsum(deg, 0, out=ptr_x[1:])
+
+    ptr_y: Optional[torch.Tensor] = None
+    if batch_y is not None:
+        assert y.size(0) == batch_y.numel()
+        batch_size = int(batch_y.max()) + 1
+        deg = y.new_zeros(batch_size, dtype=torch.long)
+        deg.scatter_add_(0, batch_y, torch.ones_like(batch_y))
+        ptr_y = deg.new_zeros(batch_size + 1)
+        torch.cumsum(deg, 0, out=ptr_y[1:])
+
+    return torch.ops.torch_cluster.radius(x, y, ptr_x, ptr_y, r,
+                                          max_num_neighbors, num_workers)
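
A minimal usage sketch for the updated radius signature; the toy tensors are invented for illustration, and with this commit the same call now runs through the C++ kernel on CPU instead of the scipy fallback:

.. code-block:: python

    import torch
    from torch_cluster import radius

    x = torch.tensor([[-1.0, -1.0], [-1.0, 1.0], [1.0, -1.0], [1.0, 1.0]])
    batch_x = torch.tensor([0, 0, 0, 0])
    y = torch.tensor([[-1.0, 0.0], [1.0, 0.0]])
    batch_y = torch.tensor([0, 0])

    # All points in x within distance 1.5 of each point in y, capped at
    # max_num_neighbors matches per query point.
    assign_index = radius(x, y, 1.5, batch_x, batch_y, max_num_neighbors=4)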
 
 
+@torch.jit.script
 def radius_graph(x: torch.Tensor, r: float,
                  batch: Optional[torch.Tensor] = None, loop: bool = False,
                  max_num_neighbors: int = 32,
-                 flow: str = 'source_to_target') -> torch.Tensor:
+                 flow: str = 'source_to_target',
+                 num_workers: int = 1) -> torch.Tensor:
     r"""Computes graph edges to all points within a given distance.
 
     Args:
...
@@ -112,7 +86,8 @@ def radius_graph(x: torch.Tensor, r: float,
         r (float): The radius.
         batch (LongTensor, optional): Batch vector
             :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
-            node to a specific example. (default: :obj:`None`)
+            node to a specific example. :obj:`batch` needs to be sorted.
+            (default: :obj:`None`)
         loop (bool, optional): If :obj:`True`, the graph will contain
             self-loops. (default: :obj:`False`)
         max_num_neighbors (int, optional): The maximum number of neighbors to
...
@@ -120,6 +95,9 @@ def radius_graph(x: torch.Tensor, r: float,
         flow (string, optional): The flow direction when using in combination
             with message passing (:obj:`"source_to_target"` or
             :obj:`"target_to_source"`). (default: :obj:`"source_to_target"`)
+        num_workers (int): Number of workers to use for computation. Has no
+            effect in case :obj:`batch` is not :obj:`None`, or the input lies
+            on the GPU. (default: :obj:`1`)
 
     :rtype: :class:`LongTensor`
...
@@ -134,10 +112,16 @@ def radius_graph(x: torch.Tensor, r: float,
     """
 
     assert flow in ['source_to_target', 'target_to_source']
-    row, col = radius(x, x, r, batch, batch,
-                      max_num_neighbors if loop else max_num_neighbors + 1)
-    row, col = (col, row) if flow == 'source_to_target' else (row, col)
+    edge_index = radius(x, x, r, batch, batch,
+                        max_num_neighbors if loop else max_num_neighbors + 1,
+                        num_workers)
+
+    if flow == 'source_to_target':
+        row, col = edge_index[1], edge_index[0]
+    else:
+        row, col = edge_index[0], edge_index[1]
+
     if not loop:
         mask = row != col
         row, col = row[mask], col[mask]
     return torch.stack([row, col], dim=0)
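
And a usage sketch for the updated radius_graph signature (toy tensors invented for illustration):

.. code-block:: python

    import torch
    from torch_cluster import radius_graph

    x = torch.tensor([[-1.0, -1.0], [-1.0, 1.0], [1.0, -1.0], [1.0, 1.0]])
    batch = torch.tensor([0, 0, 0, 0])

    # Edges between all pairs of points closer than r = 1.5, without
    # self-loops; CPU tensors now dispatch to the C++ kernel as well.
    edge_index = radius_graph(x, r=1.5, batch=batch, loop=False)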