OpenDAS / torch-cluster / Commits

Commit b8166f31, authored Jun 22, 2020 by rusty1s (parent: cd7dbf25)

linting and interface changes
Showing 4 changed files with 517 additions and 555 deletions:

  csrc/cpu/utils/cloud.h        +55  -62
  csrc/cpu/utils/neighbors.cpp  +378 -373
  torch_cluster/knn.py          +39  -56
  torch_cluster/radius.py       +45  -64
csrc/cpu/utils/cloud.h (100755 → 100644)

The change to this header is formatting only (include reordering and reflowed layout); the file at b8166f31 reads:

// Author: Peiyuan Liao (alexander_liao@outlook.com)
#pragma once

#include <ATen/ATen.h>

#include <algorithm>
#include <cmath>
#include <iomanip>
#include <iostream>
#include <map>
#include <numeric>
#include <time.h>
#include <unordered_map>
#include <vector>

template <typename scalar_t> struct PointCloud {
  std::vector<std::vector<scalar_t> *> pts;

  void set(std::vector<scalar_t> new_pts, int dim) {
    std::vector<std::vector<scalar_t> *> temp(new_pts.size() / dim);
    for (size_t i = 0; i < new_pts.size(); i++) {
      if (i % dim == 0) {
        std::vector<scalar_t> *point = new std::vector<scalar_t>(dim);
        for (size_t j = 0; j < (size_t)dim; j++) {
          (*point)[j] = new_pts[i + j];
        }
        temp[i / dim] = point;
      }
    }
    pts = temp;
  }

  void set_batch(std::vector<scalar_t> new_pts, size_t begin, long size,
                 int dim) {
    std::vector<std::vector<scalar_t> *> temp(size);
    for (size_t i = 0; i < (size_t)size; i++) {
      std::vector<scalar_t> *point = new std::vector<scalar_t>(dim);
      for (size_t j = 0; j < (size_t)dim; j++) {
        (*point)[j] = new_pts[dim * (begin + i) + j];
      }
      temp[i] = point;
    }
    pts = temp;
  }

  // Must return the number of data points.
  inline size_t kdtree_get_point_count() const { return pts.size(); }

  // Returns the dim'th component of the idx'th point in the class:
  inline scalar_t kdtree_get_pt(const size_t idx, const size_t dim) const {
    return (*pts[idx])[dim];
  }

  // Optional bounding-box computation: return false to default to a standard
  // bbox computation loop.
  // Return true if the BBOX was already computed by the class and returned in
  // "bb" so it can be avoided to redo it again. Look at bb.size() to find out
  // the expected dimensionality (e.g. 2 or 3 for point clouds)
  template <class BBOX> bool kdtree_get_bbox(BBOX & /* bb */) const {
    return false;
  }
};
csrc/cpu/utils/neighbors.cpp (100755 → 100644)

This diff is collapsed.
torch_cluster/knn.py

 from typing import Optional

 import torch
-import numpy as np


+@torch.jit.script
 def knn(x: torch.Tensor, y: torch.Tensor, k: int,
         batch_x: Optional[torch.Tensor] = None,
         batch_y: Optional[torch.Tensor] = None,
         cosine: bool = False,
-        n_threads: int = 1) -> torch.Tensor:
+        num_workers: int = 1) -> torch.Tensor:
     r"""Finds for each element in :obj:`y` the :obj:`k` nearest points in
     :obj:`x`.
...
@@ -19,13 +19,18 @@ def knn(x: torch.Tensor, y: torch.Tensor, k: int,
         k (int): The number of neighbors.
         batch_x (LongTensor, optional): Batch vector
             :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
-            node to a specific example. (default: :obj:`None`)
+            node to a specific example. :obj:`batch_x` needs to be sorted.
+            (default: :obj:`None`)
         batch_y (LongTensor, optional): Batch vector
             :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^M`, which assigns each
-            node to a specific example. (default: :obj:`None`)
-        cosine (boolean, optional): If :obj:`True`, will use the cosine
-            distance instead of euclidean distance to find nearest neighbors.
-            (default: :obj:`False`)
+            node to a specific example. :obj:`batch_y` needs to be sorted.
+            (default: :obj:`None`)
+        cosine (boolean, optional): If :obj:`True`, will use the Cosine
+            distance instead of the Euclidean distance to find nearest
+            neighbors. (default: :obj:`False`)
+        num_workers (int): Number of workers to use for computation. Has no
+            effect in case :obj:`batch_x` or :obj:`batch_y` is not
+            :obj:`None`, or the input lies on the GPU. (default: :obj:`1`)

     :rtype: :class:`LongTensor`
...
@@ -44,62 +49,36 @@ def knn(x: torch.Tensor, y: torch.Tensor, k: int,
     x = x.view(-1, 1) if x.dim() == 1 else x
     y = y.view(-1, 1) if y.dim() == 1 else y

-    def is_sorted(x):
-        return (np.diff(x.detach().cpu()) >= 0).all()
-
-    if x.is_cuda:
-        if batch_x is not None:
-            assert x.size(0) == batch_x.numel()
-            assert is_sorted(batch_x)
-            batch_size = int(batch_x.max()) + 1
-            deg = x.new_zeros(batch_size, dtype=torch.long)
-            deg.scatter_add_(0, batch_x, torch.ones_like(batch_x))
-            ptr_x = deg.new_zeros(batch_size + 1)
-            torch.cumsum(deg, 0, out=ptr_x[1:])
-        else:
-            ptr_x = torch.tensor([0, x.size(0)], device=x.device)
-        if batch_y is not None:
-            assert y.size(0) == batch_y.numel()
-            assert is_sorted(batch_y)
-            batch_size = int(batch_y.max()) + 1
-            deg = y.new_zeros(batch_size, dtype=torch.long)
-            deg.scatter_add_(0, batch_y, torch.ones_like(batch_y))
-            ptr_y = deg.new_zeros(batch_size + 1)
-            torch.cumsum(deg, 0, out=ptr_y[1:])
-        else:
-            ptr_y = torch.tensor([0, y.size(0)], device=y.device)
-        return torch.ops.torch_cluster.knn(x, y, ptr_x, ptr_y, k, cosine,
-                                           n_threads)
-    else:
-        assert x.dim() == 2
-        if batch_x is not None:
-            assert batch_x.dim() == 1
-            assert is_sorted(batch_x)
-            assert x.size(0) == batch_x.size(0)
-        assert y.dim() == 2
-        if batch_y is not None:
-            assert batch_y.dim() == 1
-            assert is_sorted(batch_y)
-            assert y.size(0) == batch_y.size(0)
-        assert x.size(1) == y.size(1)
-        if cosine:
-            raise NotImplementedError('`cosine` argument not supported on CPU')
-        return torch.ops.torch_cluster.knn(x, y, batch_x, batch_y, k, cosine,
-                                           n_threads)
+    if batch_x is not None:
+        assert x.size(0) == batch_x.numel()
+        batch_size = int(batch_x.max()) + 1
+        deg = x.new_zeros(batch_size, dtype=torch.long)
+        deg.scatter_add_(0, batch_x, torch.ones_like(batch_x))
+        ptr_x = deg.new_zeros(batch_size + 1)
+        torch.cumsum(deg, 0, out=ptr_x[1:])
+    else:
+        ptr_x = torch.tensor([0, x.size(0)], device=x.device)
+    if batch_y is not None:
+        assert y.size(0) == batch_y.numel()
+        batch_size = int(batch_y.max()) + 1
+        deg = y.new_zeros(batch_size, dtype=torch.long)
+        deg.scatter_add_(0, batch_y, torch.ones_like(batch_y))
+        ptr_y = deg.new_zeros(batch_size + 1)
+        torch.cumsum(deg, 0, out=ptr_y[1:])
+    else:
+        ptr_y = torch.tensor([0, y.size(0)], device=y.device)
+    return torch.ops.torch_cluster.knn(x, y, ptr_x, ptr_y, k, cosine,
+                                       num_workers)


+@torch.jit.script
 def knn_graph(x: torch.Tensor, k: int, batch: Optional[torch.Tensor] = None,
               loop: bool = False, flow: str = 'source_to_target',
-              cosine: bool = False, n_threads: int = 1) -> torch.Tensor:
+              cosine: bool = False, num_workers: int = 1) -> torch.Tensor:
     r"""Computes graph edges to the nearest :obj:`k` points.

     Args:
...
@@ -108,7 +87,8 @@ def knn_graph(x: torch.Tensor, k: int, batch: Optional[torch.Tensor] = None,
         k (int): The number of neighbors.
         batch (LongTensor, optional): Batch vector
             :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
-            node to a specific example. (default: :obj:`None`)
+            node to a specific example. :obj:`batch` needs to be sorted.
+            (default: :obj:`None`)
         loop (bool, optional): If :obj:`True`, the graph will contain
             self-loops. (default: :obj:`False`)
         flow (string, optional): The flow direction when using in combination
...
@@ -117,6 +97,9 @@ def knn_graph(x: torch.Tensor, k: int, batch: Optional[torch.Tensor] = None,
         cosine (boolean, optional): If :obj:`True`, will use the cosine
             distance instead of euclidean distance to find nearest neighbors.
             (default: :obj:`False`)
+        num_workers (int): Number of workers to use for computation. Has no
+            effect in case :obj:`batch` is not :obj:`None`, or the input lies
+            on the GPU. (default: :obj:`1`)

     :rtype: :class:`LongTensor`
...
@@ -131,8 +114,8 @@ def knn_graph(x: torch.Tensor, k: int, batch: Optional[torch.Tensor] = None,
     """
     assert flow in ['source_to_target', 'target_to_source']
-    row, col = knn(x, x, k if loop else k + 1, batch, batch, cosine=cosine,
-                   n_threads=n_threads)
+    row, col = knn(x, x, k if loop else k + 1, batch, batch, cosine,
+                   num_workers)
     row, col = (col, row) if flow == 'source_to_target' else (row, col)
     if not loop:
         mask = row != col
...
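The interface change above renames n_threads to num_workers, drops the numpy-based is_sorted runtime checks in favor of a documented requirement that batch vectors be sorted, and routes both CPU and GPU inputs through a CSR-style pointer built from the batch vector (degree count via scatter_add_, then a cumulative sum). The sketch below is a minimal usage example of the updated calls; it assumes a build of torch_cluster that includes this commit, and the toy coordinates and k value are invented for illustration.

# Minimal usage sketch (assumed setup: torch_cluster built from this commit).
import torch
from torch_cluster import knn, knn_graph

x = torch.tensor([[-1.0, -1.0], [-1.0, 1.0], [1.0, -1.0], [1.0, 1.0]])
y = torch.tensor([[-1.0, 0.0], [1.0, 0.0]])

# Batch vectors assign each row to an example and must be sorted,
# as the updated docstrings require.
batch_x = torch.tensor([0, 0, 0, 0])
batch_y = torch.tensor([0, 0])

# For each point in y, find its 2 nearest points in x.
# num_workers only matters for CPU inputs without batch vectors.
assign_index = knn(x, y, 2, batch_x, batch_y, num_workers=1)

# k-NN graph over x itself (self-loops excluded by default).
edge_index = knn_graph(x, k=2, batch=batch_x, loop=False)

# Internally (see the diff above), a sorted batch vector is converted to a
# CSR-style pointer before calling the fused operator:
deg = torch.zeros(int(batch_x.max()) + 1, dtype=torch.long)
deg.scatter_add_(0, batch_x, torch.ones_like(batch_x))
ptr_x = torch.zeros(deg.numel() + 1, dtype=torch.long)
torch.cumsum(deg, 0, out=ptr_x[1:])
# ptr_x == tensor([0, 4]) for the single-example batch above.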
torch_cluster/radius.py

 from typing import Optional

 import torch
-import numpy as np


+@torch.jit.script
 def radius(x: torch.Tensor, y: torch.Tensor, r: float,
            batch_x: Optional[torch.Tensor] = None,
            batch_y: Optional[torch.Tensor] = None,
            max_num_neighbors: int = 32,
-           n_threads: int = 1) -> torch.Tensor:
+           num_workers: int = 1) -> torch.Tensor:
     r"""Finds for each element in :obj:`y` all points in :obj:`x` within
     distance :obj:`r`.
...
@@ -16,17 +17,19 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float,
         y (Tensor): Node feature matrix
             :math:`\mathbf{Y} \in \mathbb{R}^{M \times F}`.
         r (float): The radius.
-        batch_x (LongTensor, optional): Batch vector (must be sorted)
+        batch_x (LongTensor, optional): Batch vector
             :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
-            node to a specific example. (default: :obj:`None`)
-        batch_y (LongTensor, optional): Batch vector (must be sorted)
+            node to a specific example. :obj:`batch_x` needs to be sorted.
+            (default: :obj:`None`)
+        batch_y (LongTensor, optional): Batch vector
             :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^M`, which assigns each
-            node to a specific example. (default: :obj:`None`)
+            node to a specific example. :obj:`batch_y` needs to be sorted.
+            (default: :obj:`None`)
         max_num_neighbors (int, optional): The maximum number of neighbors to
             return for each element in :obj:`y`. (default: :obj:`32`)
-        n_threads (int): number of threads when the input is on CPU. Note
-            that this has no effect when batch_x or batch_y is not None, or
-            x is on GPU. (default: :obj:`1`)
+        num_workers (int): Number of workers to use for computation. Has no
+            effect in case :obj:`batch_x` or :obj:`batch_y` is not
+            :obj:`None`, or the input lies on the GPU. (default: :obj:`1`)

     .. code-block:: python
...
@@ -43,71 +46,49 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float,
     x = x.view(-1, 1) if x.dim() == 1 else x
     y = y.view(-1, 1) if y.dim() == 1 else y

-    def is_sorted(x):
-        return (np.diff(x.detach().cpu()) >= 0).all()
-
-    if x.is_cuda:
-        if batch_x is not None:
-            assert x.size(0) == batch_x.numel()
-            assert is_sorted(batch_x)
-            batch_size = int(batch_x.max()) + 1
-            deg = x.new_zeros(batch_size, dtype=torch.long)
-            deg.scatter_add_(0, batch_x, torch.ones_like(batch_x))
-            ptr_x = deg.new_zeros(batch_size + 1)
-            torch.cumsum(deg, 0, out=ptr_x[1:])
-        else:
-            ptr_x = None
-        if batch_y is not None:
-            assert y.size(0) == batch_y.numel()
-            assert is_sorted(batch_y)
-            batch_size = int(batch_y.max()) + 1
-            deg = y.new_zeros(batch_size, dtype=torch.long)
-            deg.scatter_add_(0, batch_y, torch.ones_like(batch_y))
-            ptr_y = deg.new_zeros(batch_size + 1)
-            torch.cumsum(deg, 0, out=ptr_y[1:])
-        else:
-            ptr_y = None
-        result = torch.ops.torch_cluster.radius(x, y, ptr_x, ptr_y, r,
-                                                max_num_neighbors, n_threads)
-    else:
-        assert x.dim() == 2
-        if batch_x is not None:
-            assert batch_x.dim() == 1
-            assert is_sorted(batch_x)
-            assert x.size(0) == batch_x.size(0)
-        assert y.dim() == 2
-        if batch_y is not None:
-            assert batch_y.dim() == 1
-            assert is_sorted(batch_y)
-            assert y.size(0) == batch_y.size(0)
-        assert x.size(1) == y.size(1)
-        result = torch.ops.torch_cluster.radius(x, y, batch_x, batch_y, r,
-                                                max_num_neighbors, n_threads)
-    return result
+    if batch_x is not None:
+        assert x.size(0) == batch_x.numel()
+        batch_size = int(batch_x.max()) + 1
+        deg = x.new_zeros(batch_size, dtype=torch.long)
+        deg.scatter_add_(0, batch_x, torch.ones_like(batch_x))
+        ptr_x = deg.new_zeros(batch_size + 1)
+        torch.cumsum(deg, 0, out=ptr_x[1:])
+    else:
+        ptr_x = None
+    if batch_y is not None:
+        assert y.size(0) == batch_y.numel()
+        batch_size = int(batch_y.max()) + 1
+        deg = y.new_zeros(batch_size, dtype=torch.long)
+        deg.scatter_add_(0, batch_y, torch.ones_like(batch_y))
+        ptr_y = deg.new_zeros(batch_size + 1)
+        torch.cumsum(deg, 0, out=ptr_y[1:])
+    else:
+        ptr_y = None
+    return torch.ops.torch_cluster.radius(x, y, ptr_x, ptr_y, r,
+                                          max_num_neighbors, num_workers)


+@torch.jit.script
 def radius_graph(x: torch.Tensor, r: float,
                  batch: Optional[torch.Tensor] = None, loop: bool = False,
                  max_num_neighbors: int = 32,
                  flow: str = 'source_to_target',
-                 n_threads: int = 1) -> torch.Tensor:
+                 num_workers: int = 1) -> torch.Tensor:
     r"""Computes graph edges to all points within a given distance.

     Args:
         x (Tensor): Node feature matrix
             :math:`\mathbf{X} \in \mathbb{R}^{N \times F}`.
         r (float): The radius.
-        batch (LongTensor, optional): Batch vector (must be sorted)
+        batch (LongTensor, optional): Batch vector
             :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
-            node to a specific example. (default: :obj:`None`)
+            node to a specific example. :obj:`batch` needs to be sorted.
+            (default: :obj:`None`)
         loop (bool, optional): If :obj:`True`, the graph will contain
             self-loops. (default: :obj:`False`)
         max_num_neighbors (int, optional): The maximum number of neighbors to
...
@@ -115,9 +96,9 @@ def radius_graph(x: torch.Tensor, r: float,
         flow (string, optional): The flow direction when using in combination
             with message passing (:obj:`"source_to_target"` or
             :obj:`"target_to_source"`). (default: :obj:`"source_to_target"`)
-        n_threads (int): number of threads when the input is on CPU. Note
-            that this has no effect when batch_x or batch_y is not None, or
-            x is on GPU. (default: :obj:`1`)
+        num_workers (int): Number of workers to use for computation. Has no
+            effect in case :obj:`batch` is not :obj:`None`, or the input lies
+            on the GPU. (default: :obj:`1`)

     :rtype: :class:`LongTensor`
...
@@ -134,7 +115,7 @@ def radius_graph(x: torch.Tensor, r: float,
     assert flow in ['source_to_target', 'target_to_source']
     row, col = radius(x, x, r, batch, batch,
                       max_num_neighbors if loop else max_num_neighbors + 1,
-                      n_threads)
+                      num_workers)
     row, col = (col, row) if flow == 'source_to_target' else (row, col)
     if not loop:
         mask = row != col
...
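As in knn.py, radius.py replaces n_threads with num_workers and moves the sortedness requirement for batch vectors into the docstrings. Below is a minimal usage sketch of the updated calls, assuming a build of torch_cluster that includes this commit; the point coordinates and radii are invented for illustration.

# Minimal usage sketch (assumed setup: torch_cluster built from this commit).
import torch
from torch_cluster import radius, radius_graph

x = torch.tensor([[-1.0, -1.0], [-1.0, 1.0], [1.0, -1.0], [1.0, 1.0]])
y = torch.tensor([[-1.0, 0.0], [1.0, 0.0]])

# Sorted batch vectors, as the updated docstrings require.
batch_x = torch.tensor([0, 0, 0, 0])
batch_y = torch.tensor([0, 0])

# All points in x within distance 1.5 of each point in y,
# returning at most max_num_neighbors matches per query point.
assign_index = radius(x, y, 1.5, batch_x, batch_y, max_num_neighbors=32,
                      num_workers=1)

# Radius graph over x itself; self-loops are excluded by default.
edge_index = radius_graph(x, r=2.5, batch=batch_x, loop=False)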