"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "90eac14f720cf66ca1e28f1cc4af32df44806bc7"
Unverified Commit 2f983abe authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

fix knn graph bugs (#1716)

parent 8b539079
...@@ -93,7 +93,7 @@ def knn_graph(x, k): ...@@ -93,7 +93,7 @@ def knn_graph(x, k):
src = F.reshape(src, (-1,)) src = F.reshape(src, (-1,))
adj = sparse.csr_matrix( adj = sparse.csr_matrix(
(F.asnumpy(F.zeros_like(dst) + 1), (F.asnumpy(dst), F.asnumpy(src))), (F.asnumpy(F.zeros_like(dst) + 1), (F.asnumpy(dst), F.asnumpy(src))),
shape=(n_points, n_points)) shape=(n_samples * n_points, n_samples * n_points))
g = DGLGraph(adj, readonly=True) g = DGLGraph(adj, readonly=True)
return g return g
...@@ -129,7 +129,7 @@ def segmented_knn_graph(x, k, segs): ...@@ -129,7 +129,7 @@ def segmented_knn_graph(x, k, segs):
h_list = F.split(x, segs, 0) h_list = F.split(x, segs, 0)
dst = [ dst = [
F.argtopk(pairwise_squared_distance(h_g), k, 1, descending=False) + F.argtopk(pairwise_squared_distance(h_g), k, 1, descending=False) +
offset[i] int(offset[i])
for i, h_g in enumerate(h_list)] for i, h_g in enumerate(h_list)]
dst = F.cat(dst, 0) dst = F.cat(dst, 0)
src = F.arange(0, n_total_points).unsqueeze(1).expand(n_total_points, k) src = F.arange(0, n_total_points).unsqueeze(1).expand(n_total_points, k)
......
import torch as th import torch as th
import dgl.nn
from dgl.geometry.pytorch import FarthestPointSampler from dgl.geometry.pytorch import FarthestPointSampler
import backend as F import backend as F
import numpy as np import numpy as np
...@@ -17,5 +18,31 @@ def test_fps(): ...@@ -17,5 +18,31 @@ def test_fps():
assert res.shape[1] == sample_points assert res.shape[1] == sample_points
assert res.sum() > 0 assert res.sum() > 0
def test_knn():
x = th.randn(8, 3)
kg = dgl.nn.KNNGraph(3)
d = th.cdist(x, x)
def check_knn(g, x, start, end):
for v in range(start, end):
src, _ = g.in_edges(v)
src = set(src.numpy())
i = v - start
src_ans = set(th.topk(d[start:end, start:end][i], 3, largest=False)[1].numpy() + start)
assert src == src_ans
g = kg(x)
check_knn(g, x, 0, 8)
g = kg(x.view(2, 4, 3))
check_knn(g, x, 0, 4)
check_knn(g, x, 4, 8)
kg = dgl.nn.SegmentedKNNGraph(3)
g = kg(x, [3, 5])
check_knn(g, x, 0, 3)
check_knn(g, x, 3, 8)
if __name__ == '__main__': if __name__ == '__main__':
test_fps() test_fps()
test_knn()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment