Unverified Commit 64d0f3f3 authored by Tianqi Zhang (张天启)'s avatar Tianqi Zhang (张天启) Committed by GitHub
Browse files

[Feature] Add NN-descent support for the KNN graph function in dgl (#2941)



* add bruteforce impl

* add nn descent implementation

* change doc-string

* remove redundant func

* use local rng for cuda

* fix lint

* fix lint

* fix bug

* fix bug

* wrap nndescent_knn_graph into knn

* fix lint

* change function names

* add comment for dist funcs

* let the compiler do the unrolling

* use better blocksize setting

* remove redundant line

* check the return of the cub calls
Co-authored-by: default avatarTong He <hetong007@gmail.com>
parent a303f078
......@@ -104,6 +104,11 @@ class KNNGraph(nn.Module):
This method is suitable for low-dimensional data (e.g. 3D
point clouds)
* 'nn-descent' is a approximate approach from paper
`Efficient k-nearest neighbor graph construction for generic similarity
measures <https://www.cs.princeton.edu/cass/papers/www11.pdf>`_. This method
will search for nearest neighbor candidates in "neighbors' neighbors".
(default: 'bruteforce-blas')
dist : str, optional
The distance metric used to compute distance between points. It can be the following
......@@ -212,6 +217,11 @@ class SegmentedKNNGraph(nn.Module):
This method is suitable for low-dimensional data (e.g. 3D
point clouds)
* 'nn-descent' is a approximate approach from paper
`Efficient k-nearest neighbor graph construction for generic similarity
measures <https://www.cs.princeton.edu/cass/papers/www11.pdf>`_. This method
will search for nearest neighbor candidates in "neighbors' neighbors".
(default: 'bruteforce-blas')
dist : str, optional
The distance metric used to compute distance between points. It can be the following
......
......@@ -117,6 +117,11 @@ def knn_graph(x, k, algorithm='bruteforce-blas', dist='euclidean'):
This method is suitable for low-dimensional data (e.g. 3D
point clouds)
* 'nn-descent' is a approximate approach from paper
`Efficient k-nearest neighbor graph construction for generic similarity
measures <https://www.cs.princeton.edu/cass/papers/www11.pdf>`_. This method
will search for nearest neighbor candidates in "neighbors' neighbors".
(default: 'bruteforce-blas')
dist : str, optional
The distance metric used to compute distance between points. It can be the following
......@@ -182,7 +187,7 @@ def knn_graph(x, k, algorithm='bruteforce-blas', dist='euclidean'):
x_seg = x_size[0] * [x_size[1]]
else:
x_seg = [F.shape(x)[0]]
out = knn(x, x_seg, x, x_seg, k, algorithm=algorithm, dist=dist)
out = knn(k, x, x_seg, algorithm=algorithm, dist=dist)
row, col = out[1], out[0]
return convert.graph((row, col))
......@@ -287,6 +292,11 @@ def segmented_knn_graph(x, k, segs, algorithm='bruteforce-blas', dist='euclidean
This method is suitable for low-dimensional data (e.g. 3D
point clouds)
* 'nn-descent' is a approximate approach from paper
`Efficient k-nearest neighbor graph construction for generic similarity
measures <https://www.cs.princeton.edu/cass/papers/www11.pdf>`_. This method
will search for nearest neighbor candidates in "neighbors' neighbors".
(default: 'bruteforce-blas')
dist : str, optional
The distance metric used to compute distance between points. It can be the following
......@@ -338,7 +348,7 @@ def segmented_knn_graph(x, k, segs, algorithm='bruteforce-blas', dist='euclidean
if algorithm == 'bruteforce-blas':
return _segmented_knn_graph_blas(x, k, segs, dist=dist)
else:
out = knn(x, segs, x, segs, k, algorithm=algorithm, dist=dist)
out = knn(k, x, segs, algorithm=algorithm, dist=dist)
row, col = out[1], out[0]
return convert.graph((row, col))
......@@ -390,9 +400,92 @@ def _segmented_knn_graph_blas(x, k, segs, dist='euclidean'):
dst = F.repeat(F.arange(0, n_total_points, ctx=ctx), k, dim=0)
return convert.graph((F.reshape(src, (-1,)), F.reshape(dst, (-1,))))
def knn(x, x_segs, y, y_segs, k, algorithm='bruteforce', dist='euclidean'):
def _nndescent_knn_graph(x, k, segs, num_iters=None, max_candidates=None,
delta=0.001, sample_rate=0.5, dist='euclidean'):
r"""Construct multiple graphs from multiple sets of points according to
**approximate** k-nearest-neighbor using NN-descent algorithm from paper
`Efficient k-nearest neighbor graph construction for generic similarity
measures <https://www.cs.princeton.edu/cass/papers/www11.pdf>`_.
Parameters
----------
x : Tensor
Coordinates/features of points. Must be 2D. It can be either on CPU or GPU.
k : int
The number of nearest neighbors per node.
segs : list[int]
Number of points in each point set. The numbers in :attr:`segs`
must sum up to the number of rows in :attr:`x`.
num_iters : int, optional
The maximum number of NN-descent iterations to perform. A value will be
chosen based on the size of input by default.
(Default: None)
max_candidates : int, optional
The maximum number of candidates to be considered during one iteration.
Larger values will provide more accurate search results later, but
potentially at non-negligible computation cost. A value will be chosen
based on the number of neighbors by default.
(Default: None)
delta : float, optional
A value controls the early abort. This function will abort if
:math:`k * N * delta > c`, where :math:`N` is the number of points,
:math:`c` is the number of updates during last iteration.
(Default: 0.001)
sample_rate : float, optional
A value controls how many candidates sampled. It should be a float value
between 0 and 1. Larger values will provide higher accuracy and converge
speed but with higher time cost.
(Default: 0.5)
dist : str, optional
The distance metric used to compute distance between points. It can be the following
metrics:
* 'euclidean': Use Euclidean distance (L2 norm) :math:`\sqrt{\sum_{i} (x_{i} - y_{i})^{2}}`.
* 'cosine': Use cosine distance.
(default: 'euclidean')
Returns
-------
DGLGraph
The graph. The node IDs are in the same order as :attr:`x`.
"""
num_points, _ = F.shape(x)
if isinstance(segs, (tuple, list)):
segs = F.tensor(segs)
segs = F.copy_to(segs, F.context(x))
if max_candidates is None:
max_candidates = min(60, k)
if num_iters is None:
num_iters = max(10, int(round(np.log2(num_points))))
max_candidates = int(sample_rate * max_candidates)
# if use cosine distance, normalize input points first
# thus we can use euclidean distance to find knn equivalently.
if dist == 'cosine':
l2_norm = lambda v: F.sqrt(F.sum(v * v, dim=1, keepdims=True))
x = x / (l2_norm(x) + 1e-5)
# k must less than or equal to min(segs)
if k > F.min(segs, dim=0):
raise DGLError("'k' must be less than or equal to the number of points in 'x'"
"expect 'k' <= {}, got 'k' = {}".format(F.min(segs, dim=0), k))
if delta < 0 or delta > 1:
raise DGLError("'delta' must in [0, 1], got 'delta' = {}".format(delta))
offset = F.zeros((F.shape(segs)[0] + 1,), F.dtype(segs), F.context(segs))
offset[1:] = F.cumsum(segs, dim=0)
out = F.zeros((2, num_points * k), F.dtype(segs), F.context(segs))
# points, offsets, out, k, num_iters, max_candidates, delta
_CAPI_DGLNNDescent(F.to_dgl_nd(x), F.to_dgl_nd(offset),
F.zerocopy_to_dgl_ndarray_for_write(out),
k, num_iters, max_candidates, delta)
return out
def knn(k, x, x_segs, y=None, y_segs=None, algorithm='bruteforce', dist='euclidean'):
r"""For each element in each segment in :attr:`y`, find :attr:`k` nearest
points in the same segment in :attr:`x`.
points in the same segment in :attr:`x`. If :attr:`y` is None, perform a self-query
over :attr:`x`.
This function allows multiple point sets with different capacity. The points
from different sets are stored contiguously in the :attr:`x` and :attr:`y` tensor.
......@@ -400,20 +493,22 @@ def knn(x, x_segs, y, y_segs, k, algorithm='bruteforce', dist='euclidean'):
Parameters
----------
k : int
The number of nearest neighbors per node.
x : Tensor
The point coordinates in x. It can be either on CPU or GPU (must be the
same as :attr:`y`). Must be 2D.
x_segs : Union[List[int], Tensor]
Number of points in each point set in :attr:`x`. The numbers in :attr:`x_segs`
must sum up to the number of rows in :attr:`x`.
y : Tensor
y : Tensor, optional
The point coordinates in y. It can be either on CPU or GPU (must be the
same as :attr:`x`). Must be 2D.
y_segs : Union[List[int], Tensor]
(default: None)
y_segs : Union[List[int], Tensor], optional
Number of points in each point set in :attr:`y`. The numbers in :attr:`y_segs`
must sum up to the number of rows in :attr:`y`.
k : int
The number of nearest neighbors per node.
(default: None)
algorithm : str, optional
Algorithm used to compute the k-nearest neighbors.
......@@ -433,6 +528,12 @@ def knn(x, x_segs, y, y_segs, k, algorithm='bruteforce', dist='euclidean'):
This method is suitable for low-dimensional data (e.g. 3D
point clouds)
* 'nn-descent' is a approximate approach from paper
`Efficient k-nearest neighbor graph construction for generic similarity
measures <https://www.cs.princeton.edu/cass/papers/www11.pdf>`_. This method
will search for nearest neighbor candidates in "neighbors' neighbors".
Note: Currently, 'nn-descent' only supports self-query cases, i.e. :attr:`y` is None.
(default: 'bruteforce')
dist : str, optional
The distance metric used to compute distance between points. It can be the following
......@@ -448,6 +549,17 @@ def knn(x, x_segs, y, y_segs, k, algorithm='bruteforce', dist='euclidean'):
The first subtensor contains point indexs in :attr:`y`. The second subtensor contains
point indexs in :attr:`x`
"""
# TODO(lygztq) add support for querying different point sets using nn-descent.
if algorithm == "nn-descent":
if y is not None or y_segs is not None:
raise DGLError("Currently 'nn-descent' only supports self-query cases.")
return _nndescent_knn_graph(x, k, x_segs, dist=dist)
# self query
if y is None:
y = x
y_segs = x_segs
assert F.context(x) == F.context(y)
if isinstance(x_segs, (tuple, list)):
x_segs = F.tensor(x_segs)
......
This diff is collapsed.
This diff is collapsed.
......@@ -41,5 +41,30 @@ DGL_REGISTER_GLOBAL("transform._CAPI_DGLKNN")
});
});
DGL_REGISTER_GLOBAL("transform._CAPI_DGLNNDescent")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
const NDArray points = args[0];
const IdArray offsets = args[1];
const IdArray result = args[2];
const int k = args[3];
const int num_iters = args[4];
const int num_candidates = args[5];
const double delta = args[6];
aten::CheckContiguous(
{points, offsets, result}, {"points", "offsets", "result"});
aten::CheckCtx(
points->ctx, {points, offsets, result}, {"points", "offsets", "result"});
ATEN_XPU_SWITCH_CUDA(points->ctx.device_type, XPU, "NNDescent", {
ATEN_FLOAT_TYPE_SWITCH(points->dtype, FloatType, "points", {
ATEN_ID_TYPE_SWITCH(result->dtype, IdType, {
NNDescent<XPU, FloatType, IdType>(
points, offsets, result, k, num_iters, num_candidates, delta);
});
});
});
});
} // namespace transform
} // namespace dgl
......@@ -18,21 +18,37 @@ namespace transform {
* points in the same segment in \a data_points. \a data_offsets and \a query_offsets
* determine the start index of each segment in \a data_points and \a query_points.
*
* \param data_points dataset points
* \param data_offsets offsets of point index in \a data_points
* \param query_points query points
* \param query_offsets offsets of point index in \a query_points
* \param k the number of nearest points
* \param result output array
* \param algorithm algorithm used to compute the k-nearest neighbors
*
* \return A 2D tensor indicating the index relation between \a query_points and \a data_points.
* \param data_points dataset points.
* \param data_offsets offsets of point index in \a data_points.
* \param query_points query points.
* \param query_offsets offsets of point index in \a query_points.
* \param k the number of nearest points.
* \param result output array. A 2D tensor indicating the index
* relation between \a query_points and \a data_points.
* \param algorithm algorithm used to compute the k-nearest neighbors.
*/
template <DLDeviceType XPU, typename FloatType, typename IdType>
void KNN(const NDArray& data_points, const IdArray& data_offsets,
const NDArray& query_points, const IdArray& query_offsets,
const int k, IdArray result, const std::string& algorithm);
/*!
* \brief For each input point, find \a k approximate nearest points in the same
* segment using NN-descent algorithm.
*
* \param points input points.
* \param offsets offsets of point index.
* \param result output array. A 2D tensor indicating the index relation between points.
* \param k the number of nearest points.
* \param num_iters The maximum number of NN-descent iterations to perform.
* \param num_candidates The maximum number of candidates to be considered during one iteration.
* \param delta A value controls the early abort.
*/
template <DLDeviceType XPU, typename FloatType, typename IdType>
void NNDescent(const NDArray& points, const IdArray& offsets,
IdArray result, const int k, const int num_iters,
const int num_candidates, const double delta);
} // namespace transform
} // namespace dgl
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment