Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-cluster
Commits
10049daf
Commit
10049daf
authored
Nov 05, 2021
by
rusty1s
Browse files
fix knn/radius for batches with zero-point examples
parent
ae639fd6
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
19 additions
and
0 deletions
+19
-0
csrc/cpu/knn_cpu.cpp
csrc/cpu/knn_cpu.cpp
+3
-0
csrc/cpu/radius_cpu.cpp
csrc/cpu/radius_cpu.cpp
+3
-0
test/test_knn.py
test/test_knn.py
+6
-0
test/test_radius.py
test/test_radius.py
+7
-0
No files found.
csrc/cpu/knn_cpu.cpp
View file @
10049daf
...
@@ -67,6 +67,9 @@ torch::Tensor knn_cpu(torch::Tensor x, torch::Tensor y,
...
@@ -67,6 +67,9 @@ torch::Tensor knn_cpu(torch::Tensor x, torch::Tensor y,
auto
x_start
=
ptr_x_data
[
b
],
x_end
=
ptr_x_data
[
b
+
1
];
auto
x_start
=
ptr_x_data
[
b
],
x_end
=
ptr_x_data
[
b
+
1
];
auto
y_start
=
ptr_y_data
[
b
],
y_end
=
ptr_y_data
[
b
+
1
];
auto
y_start
=
ptr_y_data
[
b
],
y_end
=
ptr_y_data
[
b
+
1
];
if
(
x_start
==
x_end
||
y_start
==
y_end
)
continue
;
vec_t
pts
(
x_end
-
x_start
);
vec_t
pts
(
x_end
-
x_start
);
for
(
int64_t
i
=
0
;
i
<
x_end
-
x_start
;
i
++
)
{
for
(
int64_t
i
=
0
;
i
<
x_end
-
x_start
;
i
++
)
{
pts
[
i
].
resize
(
x
.
size
(
1
));
pts
[
i
].
resize
(
x
.
size
(
1
));
...
...
csrc/cpu/radius_cpu.cpp
View file @
10049daf
...
@@ -70,6 +70,9 @@ torch::Tensor radius_cpu(torch::Tensor x, torch::Tensor y,
...
@@ -70,6 +70,9 @@ torch::Tensor radius_cpu(torch::Tensor x, torch::Tensor y,
auto
x_start
=
ptr_x_data
[
b
],
x_end
=
ptr_x_data
[
b
+
1
];
auto
x_start
=
ptr_x_data
[
b
],
x_end
=
ptr_x_data
[
b
+
1
];
auto
y_start
=
ptr_y_data
[
b
],
y_end
=
ptr_y_data
[
b
+
1
];
auto
y_start
=
ptr_y_data
[
b
],
y_end
=
ptr_y_data
[
b
+
1
];
if
(
x_start
==
x_end
||
y_start
==
y_end
)
continue
;
vec_t
pts
(
x_end
-
x_start
);
vec_t
pts
(
x_end
-
x_start
);
for
(
int64_t
i
=
0
;
i
<
x_end
-
x_start
;
i
++
)
{
for
(
int64_t
i
=
0
;
i
<
x_end
-
x_start
;
i
++
)
{
pts
[
i
].
resize
(
x
.
size
(
1
));
pts
[
i
].
resize
(
x
.
size
(
1
));
...
...
test/test_knn.py
View file @
10049daf
...
@@ -42,6 +42,12 @@ def test_knn(dtype, device):
...
@@ -42,6 +42,12 @@ def test_knn(dtype, device):
edge_index
=
knn
(
x
,
y
,
2
,
batch_x
,
batch_y
,
cosine
=
True
)
edge_index
=
knn
(
x
,
y
,
2
,
batch_x
,
batch_y
,
cosine
=
True
)
assert
to_set
(
edge_index
)
==
set
([(
0
,
2
),
(
0
,
3
),
(
1
,
4
),
(
1
,
5
)])
assert
to_set
(
edge_index
)
==
set
([(
0
,
2
),
(
0
,
3
),
(
1
,
4
),
(
1
,
5
)])
# Skipping a batch
batch_x
=
tensor
([
0
,
0
,
0
,
0
,
2
,
2
,
2
,
2
],
torch
.
long
,
device
)
batch_y
=
tensor
([
0
,
2
],
torch
.
long
,
device
)
edge_index
=
knn
(
x
,
y
,
2
,
batch_x
,
batch_y
)
assert
to_set
(
edge_index
)
==
set
([(
0
,
2
),
(
0
,
3
),
(
1
,
4
),
(
1
,
5
)])
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
(
grad_dtypes
,
devices
))
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
(
grad_dtypes
,
devices
))
def
test_knn_graph
(
dtype
,
device
):
def
test_knn_graph
(
dtype
,
device
):
...
...
test/test_radius.py
View file @
10049daf
...
@@ -40,6 +40,13 @@ def test_radius(dtype, device):
...
@@ -40,6 +40,13 @@ def test_radius(dtype, device):
assert
to_set
(
edge_index
)
==
set
([(
0
,
0
),
(
0
,
1
),
(
0
,
2
),
(
0
,
3
),
(
1
,
5
),
assert
to_set
(
edge_index
)
==
set
([(
0
,
0
),
(
0
,
1
),
(
0
,
2
),
(
0
,
3
),
(
1
,
5
),
(
1
,
6
)])
(
1
,
6
)])
# Skipping a batch
batch_x
=
tensor
([
0
,
0
,
0
,
0
,
2
,
2
,
2
,
2
],
torch
.
long
,
device
)
batch_y
=
tensor
([
0
,
2
],
torch
.
long
,
device
)
edge_index
=
radius
(
x
,
y
,
2
,
batch_x
,
batch_y
,
max_num_neighbors
=
4
)
assert
to_set
(
edge_index
)
==
set
([(
0
,
0
),
(
0
,
1
),
(
0
,
2
),
(
0
,
3
),
(
1
,
5
),
(
1
,
6
)])
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
(
grad_dtypes
,
devices
))
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
(
grad_dtypes
,
devices
))
def
test_radius_graph
(
dtype
,
device
):
def
test_radius_graph
(
dtype
,
device
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment