Unverified Commit 241b60d4 authored by milesial's avatar milesial Committed by GitHub
Browse files

[Feature][Performance] Accelerate kNN on GPU (#2868)



* topk kNN accelerated on GPU

* Fix src offset

* Fix trailing whitespace

* Update segmented kNN

* Fix src / dst mixup

* Update kNN docstrings

* Fixed lint
Co-authored-by: default avatarTong He <hetong007@gmail.com>
parent 975eb8fc
...@@ -100,7 +100,8 @@ def knn_graph(x, k, algorithm='topk'): ...@@ -100,7 +100,8 @@ def knn_graph(x, k, algorithm='topk'):
DGLGraph DGLGraph
The constructred graph. The node IDs are in the same order as :attr:`x`. The constructred graph. The node IDs are in the same order as :attr:`x`.
The returned graph is on CPU, regardless of the context of input :attr:`x`. If using the 'topk' algorithm, the returned graph is on the same device as input :attr:`x`.
Else, the returned graph is on CPU, regardless of the context of the input :attr:`x`.
Examples Examples
-------- --------
...@@ -170,21 +171,17 @@ def _knn_graph_topk(x, k): ...@@ -170,21 +171,17 @@ def _knn_graph_topk(x, k):
x = F.unsqueeze(x, 0) x = F.unsqueeze(x, 0)
n_samples, n_points, _ = F.shape(x) n_samples, n_points, _ = F.shape(x)
ctx = F.context(x)
dist = pairwise_squared_distance(x) dist = pairwise_squared_distance(x)
k_indices = F.argtopk(dist, k, 2, descending=False) k_indices = F.argtopk(dist, k, 2, descending=False)
dst = F.copy_to(k_indices, F.cpu()) # index offset for each sample
offset = F.arange(0, n_samples, ctx=ctx) * n_points
src = F.zeros_like(dst) + F.reshape(F.arange(0, n_points), (1, -1, 1)) offset = F.unsqueeze(offset, 1)
src = F.reshape(k_indices, (n_samples, n_points * k))
per_sample_offset = F.reshape(F.arange(0, n_samples) * n_points, (-1, 1, 1)) src = F.unsqueeze(src, 0) + offset
dst += per_sample_offset dst = F.repeat(F.arange(0, n_points, ctx=ctx), k, dim=0)
src += per_sample_offset dst = F.unsqueeze(dst, 0) + offset
dst = F.reshape(dst, (-1,)) return convert.graph((F.reshape(src, (-1,)), F.reshape(dst, (-1,))))
src = F.reshape(src, (-1,))
adj = sparse.csr_matrix(
(F.asnumpy(F.zeros_like(dst) + 1), (F.asnumpy(dst), F.asnumpy(src))),
shape=(n_samples * n_points, n_samples * n_points))
return convert.from_scipy(adj)
#pylint: disable=invalid-name #pylint: disable=invalid-name
def segmented_knn_graph(x, k, segs, algorithm='topk'): def segmented_knn_graph(x, k, segs, algorithm='topk'):
...@@ -223,7 +220,8 @@ def segmented_knn_graph(x, k, segs, algorithm='topk'): ...@@ -223,7 +220,8 @@ def segmented_knn_graph(x, k, segs, algorithm='topk'):
DGLGraph DGLGraph
The graph. The node IDs are in the same order as :attr:`x`. The graph. The node IDs are in the same order as :attr:`x`.
The returned graph is on CPU, regardless of the context of input :attr:`x`. If using the 'topk' algorithm, the returned graph is on the same device as input :attr:`x`.
Else, the returned graph is on CPU, regardless of the context of the input :attr:`x`.
Examples Examples
-------- --------
...@@ -277,18 +275,14 @@ def _segmented_knn_graph_topk(x, k, segs): ...@@ -277,18 +275,14 @@ def _segmented_knn_graph_topk(x, k, segs):
offset = np.insert(np.cumsum(segs), 0, 0) offset = np.insert(np.cumsum(segs), 0, 0)
h_list = F.split(x, segs, 0) h_list = F.split(x, segs, 0)
dst = [ src = [
F.argtopk(pairwise_squared_distance(h_g), k, 1, descending=False) + F.argtopk(pairwise_squared_distance(h_g), k, 1, descending=False) +
int(offset[i]) int(offset[i])
for i, h_g in enumerate(h_list)] for i, h_g in enumerate(h_list)]
dst = F.cat(dst, 0) src = F.cat(src, 0)
src = F.arange(0, n_total_points).unsqueeze(1).expand(n_total_points, k) ctx = F.context(x)
dst = F.repeat(F.arange(0, n_total_points, ctx=ctx), k, dim=0)
dst = F.reshape(dst, (-1,)) return convert.graph((F.reshape(src, (-1,)), F.reshape(dst, (-1,))))
src = F.reshape(src, (-1,))
adj = sparse.csr_matrix((F.asnumpy(F.zeros_like(dst) + 1), (F.asnumpy(dst), F.asnumpy(src))))
return convert.from_scipy(adj)
def knn(x, x_segs, y, y_segs, k, algorithm='kd-tree', dist='euclidean'): def knn(x, x_segs, y, y_segs, k, algorithm='kd-tree', dist='euclidean'):
r"""For each element in each segment in :attr:`y`, find :attr:`k` nearest r"""For each element in each segment in :attr:`y`, find :attr:`k` nearest
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment