Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-cluster
Commits
1abbba60
Commit
1abbba60
authored
Jul 16, 2020
by
rusty1s
Browse files
add nvcc flags [ci skip]
parents
f127bd3a
c8e167a7
Changes
23
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
100 additions
and
139 deletions
+100
-139
torch_cluster/__init__.py
torch_cluster/__init__.py
+1
-1
torch_cluster/knn.py
torch_cluster/knn.py
+52
-75
torch_cluster/radius.py
torch_cluster/radius.py
+47
-63
No files found.
torch_cluster/__init__.py
View file @
1abbba60
...
@@ -3,7 +3,7 @@ import os.path as osp
...
@@ -3,7 +3,7 @@ import os.path as osp
import
torch
import
torch
__version__
=
'1.5.
4
'
__version__
=
'1.5.
5
'
for
library
in
[
for
library
in
[
'_version'
,
'_grid'
,
'_graclus'
,
'_fps'
,
'_rw'
,
'_sampler'
,
'_nearest'
,
'_version'
,
'_grid'
,
'_graclus'
,
'_fps'
,
'_rw'
,
'_sampler'
,
'_nearest'
,
...
...
torch_cluster/knn.py
View file @
1abbba60
from
typing
import
Optional
from
typing
import
Optional
import
torch
import
torch
import
scipy.spatial
@
torch
.
jit
.
script
def
knn
(
x
:
torch
.
Tensor
,
y
:
torch
.
Tensor
,
k
:
int
,
def
knn
(
x
:
torch
.
Tensor
,
y
:
torch
.
Tensor
,
k
:
int
,
batch_x
:
Optional
[
torch
.
Tensor
]
=
None
,
batch_x
:
Optional
[
torch
.
Tensor
]
=
None
,
batch_y
:
Optional
[
torch
.
Tensor
]
=
None
,
batch_y
:
Optional
[
torch
.
Tensor
]
=
None
,
cosine
:
bool
=
False
,
cosine
:
bool
=
False
)
->
torch
.
Tensor
:
num_workers
:
int
=
1
)
->
torch
.
Tensor
:
r
"""Finds for each element in :obj:`y` the :obj:`k` nearest points in
r
"""Finds for each element in :obj:`y` the :obj:`k` nearest points in
:obj:`x`.
:obj:`x`.
...
@@ -19,13 +19,18 @@ def knn(x: torch.Tensor, y: torch.Tensor, k: int,
...
@@ -19,13 +19,18 @@ def knn(x: torch.Tensor, y: torch.Tensor, k: int,
k (int): The number of neighbors.
k (int): The number of neighbors.
batch_x (LongTensor, optional): Batch vector
batch_x (LongTensor, optional): Batch vector
:math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
:math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
node to a specific example. (default: :obj:`None`)
node to a specific example. :obj:`batch_x` needs to be sorted.
(default: :obj:`None`)
batch_y (LongTensor, optional): Batch vector
batch_y (LongTensor, optional): Batch vector
:math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^M`, which assigns each
:math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^M`, which assigns each
node to a specific example. (default: :obj:`None`)
node to a specific example. :obj:`batch_y` needs to be sorted.
cosine (boolean, optional): If :obj:`True`, will use the cosine
(default: :obj:`None`)
distance instead of euclidean distance to find nearest neighbors.
cosine (boolean, optional): If :obj:`True`, will use the Cosine
(default: :obj:`False`)
distance instead of the Euclidean distance to find nearest
neighbors. (default: :obj:`False`)
num_workers (int): Number of workers to use for computation. Has no
effect in case :obj:`batch_x` or :obj:`batch_y` is not
:obj:`None`, or the input lies on the GPU. (default: :obj:`1`)
:rtype: :class:`LongTensor`
:rtype: :class:`LongTensor`
...
@@ -37,83 +42,44 @@ def knn(x: torch.Tensor, y: torch.Tensor, k: int,
...
@@ -37,83 +42,44 @@ def knn(x: torch.Tensor, y: torch.Tensor, k: int,
x = torch.Tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]])
x = torch.Tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]])
batch_x = torch.tensor([0, 0, 0, 0])
batch_x = torch.tensor([0, 0, 0, 0])
y = torch.Tensor([[-1, 0], [1, 0]])
y = torch.Tensor([[-1, 0], [1, 0]])
batch_
x
= torch.tensor([0, 0])
batch_
y
= torch.tensor([0, 0])
assign_index = knn(x, y, 2, batch_x, batch_y)
assign_index = knn(x, y, 2, batch_x, batch_y)
"""
"""
x
=
x
.
view
(
-
1
,
1
)
if
x
.
dim
()
==
1
else
x
x
=
x
.
view
(
-
1
,
1
)
if
x
.
dim
()
==
1
else
x
y
=
y
.
view
(
-
1
,
1
)
if
y
.
dim
()
==
1
else
y
y
=
y
.
view
(
-
1
,
1
)
if
y
.
dim
()
==
1
else
y
x
,
y
=
x
.
contiguous
(),
y
.
contiguous
()
if
x
.
is_cuda
:
ptr_x
:
Optional
[
torch
.
Tensor
]
=
None
if
batch_x
is
not
None
:
if
batch_x
is
not
None
:
assert
x
.
size
(
0
)
==
batch_x
.
numel
()
assert
x
.
size
(
0
)
==
batch_x
.
numel
()
batch_size
=
int
(
batch_x
.
max
())
+
1
batch_size
=
int
(
batch_x
.
max
())
+
1
deg
=
x
.
new_zeros
(
batch_size
,
dtype
=
torch
.
long
)
deg
.
scatter_add_
(
0
,
batch_x
,
torch
.
ones_like
(
batch_x
))
ptr_x
=
deg
.
new_zeros
(
batch_size
+
1
)
torch
.
cumsum
(
deg
,
0
,
out
=
ptr_x
[
1
:])
else
:
ptr_x
=
torch
.
tensor
([
0
,
x
.
size
(
0
)],
device
=
x
.
device
)
if
batch_y
is
not
None
:
assert
y
.
size
(
0
)
==
batch_y
.
numel
()
batch_size
=
int
(
batch_y
.
max
())
+
1
deg
=
y
.
new_zeros
(
batch_size
,
dtype
=
torch
.
long
)
deg
.
scatter_add_
(
0
,
batch_y
,
torch
.
ones_like
(
batch_y
))
ptr_y
=
deg
.
new_zeros
(
batch_size
+
1
)
torch
.
cumsum
(
deg
,
0
,
out
=
ptr_y
[
1
:])
else
:
ptr_y
=
torch
.
tensor
([
0
,
y
.
size
(
0
)],
device
=
y
.
device
)
return
torch
.
ops
.
torch_cluster
.
knn
(
x
,
y
,
ptr_x
,
ptr_y
,
k
,
cosine
)
else
:
if
batch_x
is
None
:
batch_x
=
x
.
new_zeros
(
x
.
size
(
0
),
dtype
=
torch
.
long
)
if
batch_y
is
None
:
batch_y
=
y
.
new_zeros
(
y
.
size
(
0
),
dtype
=
torch
.
long
)
assert
x
.
dim
()
==
2
and
batch_x
.
dim
()
==
1
deg
=
x
.
new_zeros
(
batch_size
,
dtype
=
torch
.
long
)
assert
y
.
dim
()
==
2
and
batch_y
.
dim
()
==
1
deg
.
scatter_add_
(
0
,
batch_x
,
torch
.
ones_like
(
batch_x
))
assert
x
.
size
(
1
)
==
y
.
size
(
1
)
assert
x
.
size
(
0
)
==
batch_x
.
size
(
0
)
assert
y
.
size
(
0
)
==
batch_y
.
size
(
0
)
if
cosine
:
ptr_x
=
deg
.
new_zeros
(
batch_size
+
1
)
raise
NotImplementedError
(
'`cosine` argument not supported on CPU'
)
torch
.
cumsum
(
deg
,
0
,
out
=
ptr_x
[
1
:]
)
# Translate and rescale x and y to [0, 1].
ptr_y
:
Optional
[
torch
.
Tensor
]
=
None
min_xy
=
min
(
x
.
min
().
item
(),
y
.
min
().
item
())
if
batch_y
is
not
None
:
x
,
y
=
x
-
min_xy
,
y
-
min_xy
assert
y
.
size
(
0
)
==
batch_y
.
numel
()
batch_size
=
int
(
batch_y
.
max
())
+
1
max_xy
=
max
(
x
.
max
().
item
(),
y
.
max
().
item
())
deg
=
y
.
new_zeros
(
batch_size
,
dtype
=
torch
.
long
)
x
.
div_
(
max_xy
)
deg
.
scatter_add_
(
0
,
batch_y
,
torch
.
ones_like
(
batch_y
))
y
.
div_
(
max_xy
)
# Concat batch/features to ensure no cross-links between examples.
ptr_y
=
deg
.
new_zeros
(
batch_size
+
1
)
x
=
torch
.
cat
([
x
,
2
*
x
.
size
(
1
)
*
batch_x
.
view
(
-
1
,
1
).
to
(
x
.
dtype
)],
-
1
)
torch
.
cumsum
(
deg
,
0
,
out
=
ptr_y
[
1
:])
y
=
torch
.
cat
([
y
,
2
*
y
.
size
(
1
)
*
batch_y
.
view
(
-
1
,
1
).
to
(
y
.
dtype
)],
-
1
)
tree
=
scipy
.
spatial
.
cKDTree
(
x
.
detach
().
numpy
())
return
torch
.
ops
.
torch_cluster
.
knn
(
x
,
y
,
ptr_x
,
ptr_y
,
k
,
cosine
,
dist
,
col
=
tree
.
query
(
y
.
detach
().
cpu
(),
k
=
k
,
num_workers
)
distance_upper_bound
=
x
.
size
(
1
))
dist
=
torch
.
from_numpy
(
dist
).
to
(
x
.
dtype
)
col
=
torch
.
from_numpy
(
col
).
to
(
torch
.
long
)
row
=
torch
.
arange
(
col
.
size
(
0
),
dtype
=
torch
.
long
)
row
=
row
.
view
(
-
1
,
1
).
repeat
(
1
,
k
)
mask
=
~
torch
.
isinf
(
dist
).
view
(
-
1
)
row
,
col
=
row
.
view
(
-
1
)[
mask
],
col
.
view
(
-
1
)[
mask
]
return
torch
.
stack
([
row
,
col
],
dim
=
0
)
@
torch
.
jit
.
script
def
knn_graph
(
x
:
torch
.
Tensor
,
k
:
int
,
batch
:
Optional
[
torch
.
Tensor
]
=
None
,
def
knn_graph
(
x
:
torch
.
Tensor
,
k
:
int
,
batch
:
Optional
[
torch
.
Tensor
]
=
None
,
loop
:
bool
=
False
,
flow
:
str
=
'source_to_target'
,
loop
:
bool
=
False
,
flow
:
str
=
'source_to_target'
,
cosine
:
bool
=
False
)
->
torch
.
Tensor
:
cosine
:
bool
=
False
,
num_workers
:
int
=
1
)
->
torch
.
Tensor
:
r
"""Computes graph edges to the nearest :obj:`k` points.
r
"""Computes graph edges to the nearest :obj:`k` points.
Args:
Args:
...
@@ -122,15 +88,19 @@ def knn_graph(x: torch.Tensor, k: int, batch: Optional[torch.Tensor] = None,
...
@@ -122,15 +88,19 @@ def knn_graph(x: torch.Tensor, k: int, batch: Optional[torch.Tensor] = None,
k (int): The number of neighbors.
k (int): The number of neighbors.
batch (LongTensor, optional): Batch vector
batch (LongTensor, optional): Batch vector
:math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
:math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
node to a specific example. (default: :obj:`None`)
node to a specific example. :obj:`batch` needs to be sorted.
(default: :obj:`None`)
loop (bool, optional): If :obj:`True`, the graph will contain
loop (bool, optional): If :obj:`True`, the graph will contain
self-loops. (default: :obj:`False`)
self-loops. (default: :obj:`False`)
flow (string, optional): The flow direction when us
ing
in combination
flow (string, optional): The flow direction when us
ed
in combination
with message passing (:obj:`"source_to_target"` or
with message passing (:obj:`"source_to_target"` or
:obj:`"target_to_source"`). (default: :obj:`"source_to_target"`)
:obj:`"target_to_source"`). (default: :obj:`"source_to_target"`)
cosine (boolean, optional): If :obj:`True`, will use the
c
osine
cosine (boolean, optional): If :obj:`True`, will use the
C
osine
distance instead of
e
uclidean distance to find nearest neighbors.
distance instead of
E
uclidean distance to find nearest neighbors.
(default: :obj:`False`)
(default: :obj:`False`)
num_workers (int): Number of workers to use for computation. Has no
effect in case :obj:`batch` is not :obj:`None`, or the input lies
on the GPU. (default: :obj:`1`)
:rtype: :class:`LongTensor`
:rtype: :class:`LongTensor`
...
@@ -145,9 +115,16 @@ def knn_graph(x: torch.Tensor, k: int, batch: Optional[torch.Tensor] = None,
...
@@ -145,9 +115,16 @@ def knn_graph(x: torch.Tensor, k: int, batch: Optional[torch.Tensor] = None,
"""
"""
assert
flow
in
[
'source_to_target'
,
'target_to_source'
]
assert
flow
in
[
'source_to_target'
,
'target_to_source'
]
row
,
col
=
knn
(
x
,
x
,
k
if
loop
else
k
+
1
,
batch
,
batch
,
cosine
=
cosine
)
edge_index
=
knn
(
x
,
x
,
k
if
loop
else
k
+
1
,
batch
,
batch
,
cosine
,
row
,
col
=
(
col
,
row
)
if
flow
==
'source_to_target'
else
(
row
,
col
)
num_workers
)
if
flow
==
'source_to_target'
:
row
,
col
=
edge_index
[
1
],
edge_index
[
0
]
else
:
row
,
col
=
edge_index
[
0
],
edge_index
[
1
]
if
not
loop
:
if
not
loop
:
mask
=
row
!=
col
mask
=
row
!=
col
row
,
col
=
row
[
mask
],
col
[
mask
]
row
,
col
=
row
[
mask
],
col
[
mask
]
return
torch
.
stack
([
row
,
col
],
dim
=
0
)
return
torch
.
stack
([
row
,
col
],
dim
=
0
)
torch_cluster/radius.py
View file @
1abbba60
from
typing
import
Optional
from
typing
import
Optional
import
torch
import
torch
import
scipy.spatial
@
torch
.
jit
.
script
@
torch
.
jit
.
script
def
sample
(
col
:
torch
.
Tensor
,
count
:
int
)
->
torch
.
Tensor
:
if
col
.
size
(
0
)
>
count
:
col
=
col
[
torch
.
randperm
(
col
.
size
(
0
),
dtype
=
torch
.
long
)][:
count
]
return
col
def
radius
(
x
:
torch
.
Tensor
,
y
:
torch
.
Tensor
,
r
:
float
,
def
radius
(
x
:
torch
.
Tensor
,
y
:
torch
.
Tensor
,
r
:
float
,
batch_x
:
Optional
[
torch
.
Tensor
]
=
None
,
batch_x
:
Optional
[
torch
.
Tensor
]
=
None
,
batch_y
:
Optional
[
torch
.
Tensor
]
=
None
,
batch_y
:
Optional
[
torch
.
Tensor
]
=
None
,
max_num_neighbors
:
int
=
32
,
max_num_neighbo
rs
:
int
=
32
)
->
torch
.
Tensor
:
num_worke
rs
:
int
=
1
)
->
torch
.
Tensor
:
r
"""Finds for each element in :obj:`y` all points in :obj:`x` within
r
"""Finds for each element in :obj:`y` all points in :obj:`x` within
distance :obj:`r`.
distance :obj:`r`.
...
@@ -26,12 +19,17 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float,
...
@@ -26,12 +19,17 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float,
r (float): The radius.
r (float): The radius.
batch_x (LongTensor, optional): Batch vector
batch_x (LongTensor, optional): Batch vector
:math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
:math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
node to a specific example. (default: :obj:`None`)
node to a specific example. :obj:`batch_x` needs to be sorted.
(default: :obj:`None`)
batch_y (LongTensor, optional): Batch vector
batch_y (LongTensor, optional): Batch vector
:math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^M`, which assigns each
:math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^M`, which assigns each
node to a specific example. (default: :obj:`None`)
node to a specific example. :obj:`batch_y` needs to be sorted.
(default: :obj:`None`)
max_num_neighbors (int, optional): The maximum number of neighbors to
max_num_neighbors (int, optional): The maximum number of neighbors to
return for each element in :obj:`y`. (default: :obj:`32`)
return for each element in :obj:`y`. (default: :obj:`32`)
num_workers (int): Number of workers to use for computation. Has no
effect in case :obj:`batch_x` or :obj:`batch_y` is not
:obj:`None`, or the input lies on the GPU. (default: :obj:`1`)
.. code-block:: python
.. code-block:: python
...
@@ -47,63 +45,39 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float,
...
@@ -47,63 +45,39 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float,
x
=
x
.
view
(
-
1
,
1
)
if
x
.
dim
()
==
1
else
x
x
=
x
.
view
(
-
1
,
1
)
if
x
.
dim
()
==
1
else
x
y
=
y
.
view
(
-
1
,
1
)
if
y
.
dim
()
==
1
else
y
y
=
y
.
view
(
-
1
,
1
)
if
y
.
dim
()
==
1
else
y
x
,
y
=
x
.
contiguous
(),
y
.
contiguous
()
if
x
.
is_cuda
:
ptr_x
:
Optional
[
torch
.
Tensor
]
=
None
if
batch_x
is
not
None
:
if
batch_x
is
not
None
:
assert
x
.
size
(
0
)
==
batch_x
.
numel
()
assert
x
.
size
(
0
)
==
batch_x
.
numel
()
batch_size
=
int
(
batch_x
.
max
())
+
1
batch_size
=
int
(
batch_x
.
max
())
+
1
deg
=
x
.
new_zeros
(
batch_size
,
dtype
=
torch
.
long
)
deg
.
scatter_add_
(
0
,
batch_x
,
torch
.
ones_like
(
batch_x
))
ptr_x
=
deg
.
new_zeros
(
batch_size
+
1
)
deg
=
x
.
new_zeros
(
batch_size
,
dtype
=
torch
.
long
)
torch
.
cumsum
(
deg
,
0
,
out
=
ptr_x
[
1
:])
deg
.
scatter_add_
(
0
,
batch_x
,
torch
.
ones_like
(
batch_x
))
else
:
ptr_x
=
torch
.
tensor
([
0
,
x
.
size
(
0
)],
device
=
x
.
device
)
if
batch_y
is
not
None
:
ptr_x
=
deg
.
new_zeros
(
batch_size
+
1
)
assert
y
.
size
(
0
)
==
batch_y
.
numel
()
torch
.
cumsum
(
deg
,
0
,
out
=
ptr_x
[
1
:])
batch_size
=
int
(
batch_y
.
max
())
+
1
deg
=
y
.
new_zeros
(
batch_size
,
dtype
=
torch
.
long
)
deg
.
scatter_add_
(
0
,
batch_y
,
torch
.
ones_like
(
batch_y
))
ptr_y
=
deg
.
new_zeros
(
batch_size
+
1
)
torch
.
cumsum
(
deg
,
0
,
out
=
ptr_y
[
1
:])
else
:
ptr_y
=
torch
.
tensor
([
0
,
y
.
size
(
0
)],
device
=
y
.
device
)
return
torch
.
ops
.
torch_cluster
.
radius
(
x
,
y
,
ptr_x
,
ptr_y
,
r
,
max_num_neighbors
)
else
:
if
batch_x
is
None
:
batch_x
=
x
.
new_zeros
(
x
.
size
(
0
),
dtype
=
torch
.
long
)
if
batch_y
is
None
:
ptr_y
:
Optional
[
torch
.
Tensor
]
=
None
batch_y
=
y
.
new_zeros
(
y
.
size
(
0
),
dtype
=
torch
.
long
)
if
batch_y
is
not
None
:
assert
y
.
size
(
0
)
==
batch_y
.
numel
()
batch_size
=
int
(
batch_y
.
max
())
+
1
assert
x
.
dim
()
==
2
and
batch_x
.
dim
()
==
1
deg
=
y
.
new_zeros
(
batch_size
,
dtype
=
torch
.
long
)
assert
y
.
dim
()
==
2
and
batch_y
.
dim
()
==
1
deg
.
scatter_add_
(
0
,
batch_y
,
torch
.
ones_like
(
batch_y
))
assert
x
.
size
(
1
)
==
y
.
size
(
1
)
assert
x
.
size
(
0
)
==
batch_x
.
size
(
0
)
assert
y
.
size
(
0
)
==
batch_y
.
size
(
0
)
x
=
torch
.
cat
([
x
,
2
*
r
*
batch_x
.
view
(
-
1
,
1
).
to
(
x
.
dtype
)],
dim
=-
1
)
ptr_y
=
deg
.
new_zeros
(
batch_size
+
1
)
y
=
torch
.
c
at
([
y
,
2
*
r
*
batch_y
.
view
(
-
1
,
1
).
to
(
y
.
dtype
)],
dim
=-
1
)
torch
.
c
umsum
(
deg
,
0
,
out
=
ptr_y
[
1
:]
)
tree
=
scipy
.
spatial
.
cKDTree
(
x
.
detach
().
numpy
())
return
torch
.
ops
.
torch_cluster
.
radius
(
x
,
y
,
ptr_x
,
ptr_y
,
r
,
col
=
tree
.
query_ball_point
(
y
.
detach
().
numpy
(),
r
)
max_num_neighbors
,
num_workers
)
col
=
[
torch
.
tensor
(
c
,
dtype
=
torch
.
long
)
for
c
in
col
]
col
=
[
sample
(
c
,
max_num_neighbors
)
for
c
in
col
]
row
=
[
torch
.
full_like
(
c
,
i
)
for
i
,
c
in
enumerate
(
col
)]
row
,
col
=
torch
.
cat
(
row
,
dim
=
0
),
torch
.
cat
(
col
,
dim
=
0
)
mask
=
col
<
int
(
tree
.
n
)
return
torch
.
stack
([
row
[
mask
],
col
[
mask
]],
dim
=
0
)
@
torch
.
jit
.
script
def
radius_graph
(
x
:
torch
.
Tensor
,
r
:
float
,
def
radius_graph
(
x
:
torch
.
Tensor
,
r
:
float
,
batch
:
Optional
[
torch
.
Tensor
]
=
None
,
loop
:
bool
=
False
,
batch
:
Optional
[
torch
.
Tensor
]
=
None
,
loop
:
bool
=
False
,
max_num_neighbors
:
int
=
32
,
max_num_neighbors
:
int
=
32
,
flow
:
str
=
'source_to_target'
,
flow
:
str
=
'source_to_target'
)
->
torch
.
Tensor
:
num_workers
:
int
=
1
)
->
torch
.
Tensor
:
r
"""Computes graph edges to all points within a given distance.
r
"""Computes graph edges to all points within a given distance.
Args:
Args:
...
@@ -112,14 +86,18 @@ def radius_graph(x: torch.Tensor, r: float,
...
@@ -112,14 +86,18 @@ def radius_graph(x: torch.Tensor, r: float,
r (float): The radius.
r (float): The radius.
batch (LongTensor, optional): Batch vector
batch (LongTensor, optional): Batch vector
:math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
:math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
node to a specific example. (default: :obj:`None`)
node to a specific example. :obj:`batch` needs to be sorted.
(default: :obj:`None`)
loop (bool, optional): If :obj:`True`, the graph will contain
loop (bool, optional): If :obj:`True`, the graph will contain
self-loops. (default: :obj:`False`)
self-loops. (default: :obj:`False`)
max_num_neighbors (int, optional): The maximum number of neighbors to
max_num_neighbors (int, optional): The maximum number of neighbors to
return for each element
in :obj:`y`
. (default: :obj:`32`)
return for each element. (default: :obj:`32`)
flow (string, optional): The flow direction when us
ing
in combination
flow (string, optional): The flow direction when us
ed
in combination
with message passing (:obj:`"source_to_target"` or
with message passing (:obj:`"source_to_target"` or
:obj:`"target_to_source"`). (default: :obj:`"source_to_target"`)
:obj:`"target_to_source"`). (default: :obj:`"source_to_target"`)
num_workers (int): Number of workers to use for computation. Has no
effect in case :obj:`batch` is not :obj:`None`, or the input lies
on the GPU. (default: :obj:`1`)
:rtype: :class:`LongTensor`
:rtype: :class:`LongTensor`
...
@@ -134,10 +112,16 @@ def radius_graph(x: torch.Tensor, r: float,
...
@@ -134,10 +112,16 @@ def radius_graph(x: torch.Tensor, r: float,
"""
"""
assert
flow
in
[
'source_to_target'
,
'target_to_source'
]
assert
flow
in
[
'source_to_target'
,
'target_to_source'
]
row
,
col
=
radius
(
x
,
x
,
r
,
batch
,
batch
,
edge_index
=
radius
(
x
,
x
,
r
,
batch
,
batch
,
max_num_neighbors
if
loop
else
max_num_neighbors
+
1
)
max_num_neighbors
if
loop
else
max_num_neighbors
+
1
,
row
,
col
=
(
col
,
row
)
if
flow
==
'source_to_target'
else
(
row
,
col
)
num_workers
)
if
flow
==
'source_to_target'
:
row
,
col
=
edge_index
[
1
],
edge_index
[
0
]
else
:
row
,
col
=
edge_index
[
0
],
edge_index
[
1
]
if
not
loop
:
if
not
loop
:
mask
=
row
!=
col
mask
=
row
!=
col
row
,
col
=
row
[
mask
],
col
[
mask
]
row
,
col
=
row
[
mask
],
col
[
mask
]
return
torch
.
stack
([
row
,
col
],
dim
=
0
)
return
torch
.
stack
([
row
,
col
],
dim
=
0
)
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment