/**
 *  Copyright (c) 2021 by Contributors
 * @file graph/transform/knn.h
 * @brief k-nearest-neighbor (KNN) implementation
 */

#ifndef DGL_GRAPH_TRANSFORM_KNN_H_
#define DGL_GRAPH_TRANSFORM_KNN_H_

#include <dgl/array.h>

#include <string>

namespace dgl {
namespace transform {

17
/**
18
 * @brief For each point in each segment in \a query_points, find \a k nearest
19
20
21
 *        points in the same segment in \a data_points. \a data_offsets and \a
 *        query_offsets determine the start index of each segment in \a
 *        data_points and \a query_points.
22
 *
23
24
25
26
27
28
 * @param data_points dataset points.
 * @param data_offsets offsets of point index in \a data_points.
 * @param query_points query points.
 * @param query_offsets offsets of point index in \a query_points.
 * @param k the number of nearest points.
 * @param result output array. A 2D tensor indicating the index  relation
29
 *        between \a query_points and \a data_points.
30
 * @param algorithm algorithm used to compute the k-nearest neighbors.
31
 */
32
template <DGLDeviceType XPU, typename FloatType, typename IdType>
33
34
35
36
void KNN(
    const NDArray& data_points, const IdArray& data_offsets,
    const NDArray& query_points, const IdArray& query_offsets, const int k,
    IdArray result, const std::string& algorithm);
37

38
/**
39
 * @brief For each input point, find \a k approximate nearest points in the same
40
 *        segment using NN-descent algorithm.
41
 *
42
43
44
 * @param points input points.
 * @param offsets offsets of point index.
 * @param result output array. A 2D tensor indicating the index relation between
45
 *        points.
46
47
48
 * @param k the number of nearest points.
 * @param num_iters The maximum number of NN-descent iterations to perform.
 * @param num_candidates The maximum number of candidates to be considered
49
 *        during one iteration.
50
 * @param delta A value controls the early abort.
51
 */
52
template <DGLDeviceType XPU, typename FloatType, typename IdType>
53
54
55
void NNDescent(
    const NDArray& points, const IdArray& offsets, IdArray result, const int k,
    const int num_iters, const int num_candidates, const double delta);
56

57
58
59
60
}  // namespace transform
}  // namespace dgl

#endif  // DGL_GRAPH_TRANSFORM_KNN_H_