/*! * Copyright (c) 2021 by Contributors * \file graph/transform/cpu/kdtree_ndarray_adapter.h * \brief NDArray adapter for nanoflann, without * duplicating the storage */ #ifndef DGL_GRAPH_TRANSFORM_CPU_KDTREE_NDARRAY_ADAPTER_H_ #define DGL_GRAPH_TRANSFORM_CPU_KDTREE_NDARRAY_ADAPTER_H_ #include #include #include #include "../../../c_api_common.h" namespace dgl { namespace transform { namespace knn_utils { /*! * \brief A simple 2D NDArray adapter for nanoflann, without duplicating the * storage. * * \tparam FloatType: The type of the point coordinates (typically, double or * float). * \tparam IdType: The type for indices in the KD-tree index (typically, * size_t of int) * \tparam FeatureDim: If set to > 0, it specifies a compile-time fixed * dimensionality for the points in the data set, allowing more compiler * optimizations. * \tparam Dist: The distance metric to use: nanoflann::metric_L1, nanoflann::metric_L2, nanoflann::metric_L2_Simple, etc. * \note The spelling of dgl's adapter ("adapter") is different from naneflann * ("adaptor") */ template < typename FloatType, typename IdType, int FeatureDim = -1, typename Dist = nanoflann::metric_L2> class KDTreeNDArrayAdapter { public: using self_type = KDTreeNDArrayAdapter; using metric_type = typename Dist::template traits::distance_t; using index_type = nanoflann::KDTreeSingleIndexAdaptor< metric_type, self_type, FeatureDim, IdType>; KDTreeNDArrayAdapter( const size_t /* dims */, const NDArray data_points, const int leaf_max_size = 10) : data_(data_points) { CHECK(data_points->shape[0] != 0 && data_points->shape[1] != 0) << "Tensor containing input data point set must be 2D."; const size_t dims = data_points->shape[1]; CHECK(!(FeatureDim > 0 && static_cast(dims) != FeatureDim)) << "Data set feature dimension does not match the 'FeatureDim' " << "template argument."; index_ = new index_type( static_cast(dims), *this, nanoflann::KDTreeSingleIndexAdaptorParams(leaf_max_size)); index_->buildIndex(); } ~KDTreeNDArrayAdapter() { delete index_; } index_type* GetIndex() { return index_; } /*! * \brief Query for the \a num_closest points to a given point * Note that this is a short-cut method for GetIndex()->findNeighbors(). */ void query( const FloatType* query_pt, const size_t num_closest, IdType* out_idxs, FloatType* out_dists) const { nanoflann::KNNResultSet resultSet(num_closest); resultSet.init(out_idxs, out_dists); index_->findNeighbors(resultSet, query_pt, nanoflann::SearchParams()); } /*! \brief Interface expected by KDTreeSingleIndexAdaptor */ const self_type& derived() const { return *this; } /*! \brief Interface expected by KDTreeSingleIndexAdaptor */ self_type& derived() { return *this; } /*! * \brief Interface expected by KDTreeSingleIndexAdaptor, * return the number of data points */ size_t kdtree_get_point_count() const { return data_->shape[0]; } /*! * \brief Interface expected by KDTreeSingleIndexAdaptor, * return the dim'th component of the idx'th point */ FloatType kdtree_get_pt(const size_t idx, const size_t dim) const { return data_.Ptr()[idx * data_->shape[1] + dim]; } /*! * \brief Interface expected by KDTreeSingleIndexAdaptor. * Optional bounding-box computation: return false to * default to a standard bbox computation loop. * */ template bool kdtree_get_bbox(BBOX& /* bb */) const { return false; } private: index_type* index_; // The kd tree index const NDArray data_; // data points }; } // namespace knn_utils } // namespace transform } // namespace dgl #endif // DGL_GRAPH_TRANSFORM_CPU_KDTREE_NDARRAY_ADAPTER_H_