kdtree_ndarray_adapter.h 3.94 KB
Newer Older
1
2
/*!
 *  Copyright (c) 2021 by Contributors
 * \file graph/transform/cpu/kdtree_ndarray_adapter.h
 * \brief NDArray adapter for nanoflann, without
 *        duplicating the storage
 */
7
8
#ifndef DGL_GRAPH_TRANSFORM_CPU_KDTREE_NDARRAY_ADAPTER_H_
#define DGL_GRAPH_TRANSFORM_CPU_KDTREE_NDARRAY_ADAPTER_H_
9
10
11

#include <dgl/array.h>
#include <dmlc/logging.h>

#include <memory>

#include <nanoflann.hpp>

#include "../../../c_api_common.h"
16
17
18
19
20
21

namespace dgl {
namespace transform {
namespace knn_utils {

/*!
22
23
 * \brief A simple 2D NDArray adapter for nanoflann, without duplicating the
 *        storage.
24
 *
25
26
27
28
29
30
31
32
33
34
35
 * \tparam FloatType: The type of the point coordinates (typically, double or
 *         float).
 * \tparam IdType: The type for indices in the KD-tree index (typically,
 *         size_t of int)
 * \tparam FeatureDim: If set to > 0, it specifies a compile-time fixed
 *         dimensionality for the points in the data set, allowing more compiler
 *         optimizations.
 * \tparam Dist: The distance metric to use: nanoflann::metric_L1,
           nanoflann::metric_L2, nanoflann::metric_L2_Simple, etc.
 * \note The spelling of dgl's adapter ("adapter") is different from naneflann
 *       ("adaptor")
36
 */
37
38
39
template <
    typename FloatType, typename IdType, int FeatureDim = -1,
    typename Dist = nanoflann::metric_L2>
40
41
42
class KDTreeNDArrayAdapter {
 public:
  using self_type = KDTreeNDArrayAdapter<FloatType, IdType, FeatureDim, Dist>;
43
44
  using metric_type =
      typename Dist::template traits<FloatType, self_type>::distance_t;
45
  using index_type = nanoflann::KDTreeSingleIndexAdaptor<
46
      metric_type, self_type, FeatureDim, IdType>;
47

48
49
50
  KDTreeNDArrayAdapter(
      const size_t /* dims */, const NDArray data_points,
      const int leaf_max_size = 10)
51
52
      : data_(data_points) {
    CHECK(data_points->shape[0] != 0 && data_points->shape[1] != 0)
53
        << "Tensor containing input data point set must be 2D.";
54
55
    const size_t dims = data_points->shape[1];
    CHECK(!(FeatureDim > 0 && static_cast<int>(dims) != FeatureDim))
56
57
        << "Data set feature dimension does not match the 'FeatureDim' "
        << "template argument.";
58
    index_ = new index_type(
59
60
        static_cast<int>(dims), *this,
        nanoflann::KDTreeSingleIndexAdaptorParams(leaf_max_size));
61
62
63
    index_->buildIndex();
  }

64
  ~KDTreeNDArrayAdapter() { delete index_; }
65

66
  index_type* GetIndex() { return index_; }
67
68
69
70
71

  /*!
   * \brief Query for the \a num_closest points to a given point
   *  Note that this is a short-cut method for GetIndex()->findNeighbors().
   */
72
73
74
  void query(
      const FloatType* query_pt, const size_t num_closest, IdType* out_idxs,
      FloatType* out_dists) const {
75
76
77
78
79
80
    nanoflann::KNNResultSet<FloatType, IdType> resultSet(num_closest);
    resultSet.init(out_idxs, out_dists);
    index_->findNeighbors(resultSet, query_pt, nanoflann::SearchParams());
  }

  /*! \brief Interface expected by KDTreeSingleIndexAdaptor */
81
  const self_type& derived() const { return *this; }
82
83

  /*! \brief Interface expected by KDTreeSingleIndexAdaptor */
84
  self_type& derived() { return *this; }
85
86
87
88
89

  /*!
   * \brief Interface expected by KDTreeSingleIndexAdaptor,
   *  return the number of data points
   */
90
  size_t kdtree_get_point_count() const { return data_->shape[0]; }
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111

  /*!
   * \brief Interface expected by KDTreeSingleIndexAdaptor,
   *  return the dim'th component of the idx'th point
   */
  FloatType kdtree_get_pt(const size_t idx, const size_t dim) const {
    return data_.Ptr<FloatType>()[idx * data_->shape[1] + dim];
  }

  /*!
   * \brief Interface expected by KDTreeSingleIndexAdaptor.
   *  Optional bounding-box computation: return false to
   *  default to a standard bbox computation loop.
   *
   */
  template <typename BBOX>
  bool kdtree_get_bbox(BBOX& /* bb */) const {
    return false;
  }

 private:
112
  index_type* index_;   // The kd tree index
113
114
115
116
117
118
119
  const NDArray data_;  // data points
};

}  // namespace knn_utils
}  // namespace transform
}  // namespace dgl

120
#endif  // DGL_GRAPH_TRANSFORM_CPU_KDTREE_NDARRAY_ADAPTER_H_