Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
2f983abe
"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "90eac14f720cf66ca1e28f1cc4af32df44806bc7"
Unverified
Commit
2f983abe
authored
Jun 28, 2020
by
Quan (Andy) Gan
Committed by
GitHub
Jun 28, 2020
Browse files
fix knn graph bugs (#1716)
parent
8b539079
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
29 additions
and
2 deletions
+29
-2
python/dgl/transform.py
python/dgl/transform.py
+2
-2
tests/pytorch/test_geometry.py
tests/pytorch/test_geometry.py
+27
-0
No files found.
python/dgl/transform.py
View file @
2f983abe
...
@@ -93,7 +93,7 @@ def knn_graph(x, k):
...
@@ -93,7 +93,7 @@ def knn_graph(x, k):
src
=
F
.
reshape
(
src
,
(
-
1
,))
src
=
F
.
reshape
(
src
,
(
-
1
,))
adj
=
sparse
.
csr_matrix
(
adj
=
sparse
.
csr_matrix
(
(
F
.
asnumpy
(
F
.
zeros_like
(
dst
)
+
1
),
(
F
.
asnumpy
(
dst
),
F
.
asnumpy
(
src
))),
(
F
.
asnumpy
(
F
.
zeros_like
(
dst
)
+
1
),
(
F
.
asnumpy
(
dst
),
F
.
asnumpy
(
src
))),
shape
=
(
n_
points
,
n_points
))
shape
=
(
n_
samples
*
n_points
,
n_samples
*
n_points
))
g
=
DGLGraph
(
adj
,
readonly
=
True
)
g
=
DGLGraph
(
adj
,
readonly
=
True
)
return
g
return
g
...
@@ -129,7 +129,7 @@ def segmented_knn_graph(x, k, segs):
...
@@ -129,7 +129,7 @@ def segmented_knn_graph(x, k, segs):
h_list
=
F
.
split
(
x
,
segs
,
0
)
h_list
=
F
.
split
(
x
,
segs
,
0
)
dst
=
[
dst
=
[
F
.
argtopk
(
pairwise_squared_distance
(
h_g
),
k
,
1
,
descending
=
False
)
+
F
.
argtopk
(
pairwise_squared_distance
(
h_g
),
k
,
1
,
descending
=
False
)
+
offset
[
i
]
int
(
offset
[
i
]
)
for
i
,
h_g
in
enumerate
(
h_list
)]
for
i
,
h_g
in
enumerate
(
h_list
)]
dst
=
F
.
cat
(
dst
,
0
)
dst
=
F
.
cat
(
dst
,
0
)
src
=
F
.
arange
(
0
,
n_total_points
).
unsqueeze
(
1
).
expand
(
n_total_points
,
k
)
src
=
F
.
arange
(
0
,
n_total_points
).
unsqueeze
(
1
).
expand
(
n_total_points
,
k
)
...
...
tests/pytorch/test_geometry.py
View file @
2f983abe
import
torch
as
th
import
torch
as
th
import
dgl.nn
from
dgl.geometry.pytorch
import
FarthestPointSampler
from
dgl.geometry.pytorch
import
FarthestPointSampler
import
backend
as
F
import
backend
as
F
import
numpy
as
np
import
numpy
as
np
...
@@ -17,5 +18,31 @@ def test_fps():
...
@@ -17,5 +18,31 @@ def test_fps():
assert
res
.
shape
[
1
]
==
sample_points
assert
res
.
shape
[
1
]
==
sample_points
assert
res
.
sum
()
>
0
assert
res
.
sum
()
>
0
def
test_knn
():
x
=
th
.
randn
(
8
,
3
)
kg
=
dgl
.
nn
.
KNNGraph
(
3
)
d
=
th
.
cdist
(
x
,
x
)
def
check_knn
(
g
,
x
,
start
,
end
):
for
v
in
range
(
start
,
end
):
src
,
_
=
g
.
in_edges
(
v
)
src
=
set
(
src
.
numpy
())
i
=
v
-
start
src_ans
=
set
(
th
.
topk
(
d
[
start
:
end
,
start
:
end
][
i
],
3
,
largest
=
False
)[
1
].
numpy
()
+
start
)
assert
src
==
src_ans
g
=
kg
(
x
)
check_knn
(
g
,
x
,
0
,
8
)
g
=
kg
(
x
.
view
(
2
,
4
,
3
))
check_knn
(
g
,
x
,
0
,
4
)
check_knn
(
g
,
x
,
4
,
8
)
kg
=
dgl
.
nn
.
SegmentedKNNGraph
(
3
)
g
=
kg
(
x
,
[
3
,
5
])
check_knn
(
g
,
x
,
0
,
3
)
check_knn
(
g
,
x
,
3
,
8
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
test_fps
()
test_fps
()
test_knn
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment