"git@developer.sourcefind.cn:change/sglang.git" did not exist on "bad7c26fdc7fb87099dc833ebf8ff873cef5170b"
Unverified Commit e83d0a80 authored by Tianqi Zhang (张天启), committed by GitHub

[Feature] Add kd-tree implementation (CPU) for kNN (#2767)



* add submodule nanoflann

* finish python API for knn

* finish ndarray adaptor

* finish cpu-kdtree version of knn

* use openmp

* add endline

* upt

* upt

* fix format and code style

* upt

* add warning for gpu-cpu copy

* avoid contiguous copy
Co-authored-by: Quan (Andy) Gan <coin2028@hotmail.com>
Co-authored-by: Tong He <hetong007@gmail.com>
parent 0fce0907
@@ -29,3 +29,6 @@
[submodule "third_party/tvm"]
path = third_party/tvm
url = https://github.com/apache/incubator-tvm
[submodule "third_party/nanoflann"]
path = third_party/nanoflann
url = https://github.com/jlblancoc/nanoflann
@@ -175,6 +175,7 @@ target_include_directories(dgl PRIVATE "third_party/phmap/")
target_include_directories(dgl PRIVATE "third_party/xbyak/")
target_include_directories(dgl PRIVATE "third_party/METIS/include/")
target_include_directories(dgl PRIVATE "tensoradapter/include")
target_include_directories(dgl PRIVATE "third_party/nanoflann/include")
# For serialization
if (USE_HDFS)
@@ -67,7 +67,7 @@ class KNNGraph(nn.Module):
self.k = k
#pylint: disable=invalid-name
def forward(self, x):
def forward(self, x, algorithm='topk'):
"""
Forward computation.
@@ -78,13 +78,21 @@ class KNNGraph(nn.Module):
:math:`(M, D)` or :math:`(N, M, D)` where :math:`N` means the
number of point sets, :math:`M` means the number of points in
each point set, and :math:`D` means the size of features.
algorithm : str, optional
Algorithm used to compute the k-nearest neighbors.
* 'topk' will use the topk algorithm (quick-select or sorting,
depending on the backend implementation)
* 'kd-tree' will use the kd-tree algorithm (CPU only)
(default: 'topk')
Returns
-------
DGLGraph
A DGLGraph without features.
"""
return knn_graph(x, self.k)
return knn_graph(x, self.k, algorithm)
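A minimal usage sketch of the new argument (PyTorch backend assumed; shapes and values are illustrative and mirror the unit test at the end of this diff):

```python
import torch
import dgl

x = torch.randn(8, 3)               # 8 points with 3-D coordinates
kg = dgl.nn.KNNGraph(3)             # connect each point to its 3 nearest neighbors

g_topk = kg(x)                      # default 'topk' path (CPU or GPU)
g_kd = kg(x, algorithm='kd-tree')   # new CPU kd-tree path
```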
class SegmentedKNNGraph(nn.Module):
@@ -140,7 +148,7 @@ class SegmentedKNNGraph(nn.Module):
self.k = k
#pylint: disable=invalid-name
def forward(self, x, segs):
def forward(self, x, segs, algorithm='topk'):
r"""Forward computation.
Parameters
@@ -152,6 +160,14 @@ class SegmentedKNNGraph(nn.Module):
:math:`(N)` integers where :math:`N` means the number of point
sets. The number of elements must sum up to :math:`M`. And any
:math:`N` should :math:`\ge k`
algorithm : str, optional
Algorithm used to compute the k-nearest neighbors.
* 'topk' will use the topk algorithm (quick-select or sorting,
depending on the backend implementation)
* 'kd-tree' will use the kd-tree algorithm (CPU only)
(default: 'topk')
Returns
-------
@@ -159,4 +175,4 @@ class SegmentedKNNGraph(nn.Module):
A DGLGraph without features.
"""
return segmented_knn_graph(x, self.k, segs)
return segmented_knn_graph(x, self.k, segs, algorithm)
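A corresponding sketch for the segmented module (PyTorch backend assumed; the segment sizes follow the unit test below):

```python
import torch
import dgl

x = torch.randn(8, 3)                      # two point sets stored contiguously
kg = dgl.nn.SegmentedKNNGraph(3)
g = kg(x, [3, 5], algorithm='kd-tree')     # first set has 3 points, second has 5
```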
@@ -59,7 +59,7 @@ def pairwise_squared_distance(x):
return x2s + F.swapaxes(x2s, -1, -2) - 2 * x @ F.swapaxes(x, -1, -2)
#pylint: disable=invalid-name
def knn_graph(x, k):
def knn_graph(x, k, algorithm='topk'):
"""Construct a graph from a set of points according to k-nearest-neighbor (KNN)
and return.
@@ -86,6 +86,14 @@ def knn_graph(x, k):
``x[i][j]`` corresponds to the j-th node in the i-th KNN graph.
k : int
The number of nearest neighbors per node.
algorithm : str, optional
Algorithm used to compute the k-nearest neighbors.
* 'topk' will use the topk algorithm (quick-select or sorting,
depending on the backend implementation)
* 'kd-tree' will use the kd-tree algorithm (CPU only)
(default: 'topk')
Returns
-------
@@ -129,6 +137,35 @@ def knn_graph(x, k):
(tensor([0, 1, 2, 2, 2, 3, 3, 3, 4, 5, 5, 5, 6, 6, 7, 7]),
tensor([0, 1, 1, 2, 3, 0, 2, 3, 4, 5, 6, 7, 4, 6, 5, 7]))
"""
if algorithm == 'topk':
return _knn_graph_topk(x, k)
else:
if F.ndim(x) == 3:
x_size = tuple(F.shape(x))
x = F.reshape(x, (x_size[0] * x_size[1], x_size[2]))
x_seg = x_size[0] * [x_size[1]]
else:
x_seg = [F.shape(x)[0]]
out = knn(x, x_seg, x, x_seg, k, algorithm=algorithm)
row, col = out[1], out[0]
return convert.graph((row, col))
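For the non-'topk' branch above, a 3D batch of shape (N, M, D) is flattened to (N*M, D) with N segments of M points each before calling `knn`, and the two output rows are swapped so that edges point from neighbor points to query points. A shape-only sketch of the flattening (NumPy used purely for illustration):

```python
import numpy as np

x = np.random.randn(2, 4, 3)    # N=2 point sets, M=4 points each, D=3 features
n, m, d = x.shape
x_flat = x.reshape(n * m, d)     # contiguous layout expected by knn()
x_seg = [m] * n                  # segment sizes: [4, 4]
# knn(x_flat, x_seg, x_flat, x_seg, k, algorithm='kd-tree') returns a
# (2, k * n * m) tensor: row 0 indexes query points (y), row 1 indexes
# their neighbors in the data (x).
```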
def _knn_graph_topk(x, k):
"""Construct a graph from a set of points according to k-nearest-neighbor (KNN)
via topk method.
Parameters
----------
x : Tensor
The point coordinates. It can be either on CPU or GPU.
* If is 2D, ``x[i]`` corresponds to the i-th node in the KNN graph.
* If is 3D, ``x[i]`` corresponds to the i-th KNN graph and
``x[i][j]`` corresponds to the j-th node in the i-th KNN graph.
k : int
The number of nearest neighbors per node.
"""
if F.ndim(x) == 2:
x = F.unsqueeze(x, 0)
n_samples, n_points, _ = F.shape(x)
@@ -147,11 +184,10 @@ def knn_graph(x, k):
adj = sparse.csr_matrix(
(F.asnumpy(F.zeros_like(dst) + 1), (F.asnumpy(dst), F.asnumpy(src))),
shape=(n_samples * n_points, n_samples * n_points))
return convert.from_scipy(adj)
#pylint: disable=invalid-name
def segmented_knn_graph(x, k, segs):
def segmented_knn_graph(x, k, segs, algorithm='topk'):
"""Construct multiple graphs from multiple sets of points according to
k-nearest-neighbor (KNN) and return.
@@ -173,6 +209,14 @@ def segmented_knn_graph(x, k, segs):
segs : list[int]
Number of points in each point set. The numbers in :attr:`segs`
must sum up to the number of rows in :attr:`x`.
algorithm : str, optional
Algorithm used to compute the k-nearest neighbors.
* 'topk' will use the topk algorithm (quick-select or sorting,
depending on the backend implementation)
* 'kd-tree' will use the kd-tree algorithm (CPU only)
(default: 'topk')
Returns
-------
@@ -208,6 +252,27 @@ def segmented_knn_graph(x, k, segs):
(tensor([0, 0, 1, 1, 1, 2, 3, 3, 4, 4, 5, 5, 6, 6]),
tensor([0, 1, 0, 1, 2, 2, 3, 5, 4, 6, 3, 5, 4, 6]))
"""
if algorithm == 'topk':
return _segmented_knn_graph_topk(x, k, segs)
else:
out = knn(x, segs, x, segs, k, algorithm=algorithm)
row, col = out[1], out[0]
return convert.graph((row, col))
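A usage sketch of the functional form with the new backend (assuming `segmented_knn_graph` is exposed at the package level, as in the existing docs; values illustrative):

```python
import torch
import dgl

x = torch.randn(8, 3)
# two point sets of sizes 3 and 5; each point is connected to its 2 nearest
# neighbors within its own set
g = dgl.segmented_knn_graph(x, 2, [3, 5], algorithm='kd-tree')
```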
def _segmented_knn_graph_topk(x, k, segs):
"""Construct multiple graphs from multiple sets of points according to
k-nearest-neighbor (KNN) via topk method.
Parameters
----------
x : Tensor
Coordinates/features of points. Must be 2D. It can be either on CPU or GPU.
k : int
The number of nearest neighbors per node.
segs : list[int]
Number of points in each point set. The numbers in :attr:`segs`
must sum up to the number of rows in :attr:`x`.
"""
n_total_points, _ = F.shape(x)
offset = np.insert(np.cumsum(segs), 0, 0)
@@ -225,6 +290,97 @@ def segmented_knn_graph(x, k, segs):
return convert.from_scipy(adj)
def knn(x, x_segs, y, y_segs, k, algorithm='kd-tree', dist='euclidean'):
r"""For each element in each segment in :attr:`y`, find :attr:`k` nearest
points in the same segment in :attr:`x`.
This function allows multiple point sets with different capacities. The points
from different sets are stored contiguously in the :attr:`x` and :attr:`y` tensors.
:attr:`x_segs` and :attr:`y_segs` specify the number of points in each point set.
Parameters
----------
x : Tensor
The point coordinates in x. It can be either on CPU or GPU (must be the
same as :attr:`y`). Must be 2D.
x_segs : Union[List[int], Tensor]
Number of points in each point set in :attr:`x`. The numbers in :attr:`x_segs`
must sum up to the number of rows in :attr:`x`.
y : Tensor
The point coordinates in y. It can be either on CPU or GPU (must be the
same as :attr:`x`). Must be 2D.
y_segs : Union[List[int], Tensor]
Number of points in each point set in :attr:`y`. The numbers in :attr:`y_segs`
must sum up to the number of rows in :attr:`y`.
k : int
The number of nearest neighbors per node.
algorithm : str, optional
Algorithm used to compute the k-nearest neighbors.
Currently only the CPU kd-tree ('kd-tree') is supported.
(default: 'kd-tree')
dist : str, optional
The distance metric used to compute distances between points. It can be one of the
following metrics:
* 'euclidean': Use Euclidean distance (L2 norm) :math:`\sqrt{\sum_{i} (x_{i} - y_{i})^{2}}`.
* 'cosine': Use cosine distance.
(Default: "euclidean")
Returns
-------
Tensor
Tensor with size `(2, k * num_points(y))`.
The first subtensor contains point indices in :attr:`y`. The second subtensor contains
point indices in :attr:`x`.
"""
# currently only cpu implementation is supported.
if (F.context(x) != F.cpu() or F.context(y) != F.cpu()):
dgl_warning("Currently only cpu implementation is supported," \
"copy input tensors to cpu.")
x = F.copy_to(x, F.cpu())
y = F.copy_to(y, F.cpu())
if isinstance(x_segs, (tuple, list)):
x_segs = F.tensor(x_segs)
if isinstance(y_segs, (tuple, list)):
y_segs = F.tensor(y_segs)
x_segs = F.copy_to(x_segs, F.context(x))
y_segs = F.copy_to(y_segs, F.context(y))
# supported algorithms
algorithm_list = ['kd-tree']
if algorithm not in algorithm_list:
raise DGLError("only {} algorithms are supported, get '{}'".format(
algorithm_list, algorithm))
# k must be less than or equal to min(x_segs)
if k > F.min(x_segs, dim=0):
raise DGLError("'k' must be less than or equal to the number of points in 'x', "
"expect k <= {}, got k = {}".format(F.min(x_segs, dim=0), k))
dist = dist.lower()
dist_metric_list = ['euclidean', 'cosine']
if dist not in dist_metric_list:
raise DGLError('Only {} are supported for distance '
'computation, got {}'.format(dist_metric_list, dist))
x_offset = F.zeros((F.shape(x_segs)[0] + 1,), F.dtype(x_segs), F.context(x_segs))
x_offset[1:] = F.cumsum(x_segs, dim=0)
y_offset = F.zeros((F.shape(y_segs)[0] + 1,), F.dtype(y_segs), F.context(y_segs))
y_offset[1:] = F.cumsum(y_segs, dim=0)
out = F.zeros((2, F.shape(y)[0] * k), F.dtype(x_segs), F.context(x_segs))
# if use cosine distance, normalize input points first
# thus we can use euclidean distance to find knn equivalently.
if dist == 'cosine':
l2_norm = lambda v: F.sqrt(F.sum(v * v, dim=1, keepdims=True))
x = x / (l2_norm(x) + 1e-5)
y = y / (l2_norm(y) + 1e-5)
_CAPI_DGLKNN(F.to_dgl_nd(x), F.to_dgl_nd(x_offset),
F.to_dgl_nd(y), F.to_dgl_nd(y_offset),
k, F.zerocopy_to_dgl_ndarray_for_write(out),
algorithm)
return out
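A minimal end-to-end sketch of calling the new `knn` utility directly (PyTorch backend assumed; the top-level export `dgl.knn` is an assumption here):

```python
import torch
import dgl

x = torch.randn(8, 3)   # data points: two segments of 3 and 5 points
y = torch.randn(4, 3)   # query points: two segments of 2 points each
out = dgl.knn(x, [3, 5], y, [2, 2], k=2, algorithm='kd-tree', dist='euclidean')
# out has shape (2, k * y.shape[0]) = (2, 8):
#   out[0] holds indices into y (query points), out[1] indices into x (neighbors).
# With dist='cosine', both x and y are L2-normalized first, so the same
# Euclidean kd-tree search returns cosine nearest neighbors.
```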
def to_bidirected(g, copy_ndata=False, readonly=None):
r"""Convert the graph to a bi-directional simple graph and return.
/*!
* Copyright (c) 2021 by Contributors
* \file graph/transform/kdtree_ndarray_adapter.h
* \brief NDArray adapter for nanoflann, without
* duplicating the storage
*/
#ifndef DGL_GRAPH_TRANSFORM_KDTREE_NDARRAY_ADAPTER_H_
#define DGL_GRAPH_TRANSFORM_KDTREE_NDARRAY_ADAPTER_H_
#include <dgl/array.h>
#include <dmlc/logging.h>
#include <nanoflann.hpp>
#include "../../c_api_common.h"
namespace dgl {
namespace transform {
namespace knn_utils {
/*!
* \brief A simple 2D NDArray adapter for nanoflann, without duplicating the storage.
*
* \tparam FloatType: The type of the point coordinates (typically, double or float).
* \tparam IdType: The type for indices in the KD-tree index (typically, size_t or int)
* \tparam FeatureDim: If set to > 0, it specifies a compile-time fixed dimensionality
* for the points in the data set, allowing more compiler optimizations.
* \tparam Dist: The distance metric to use: nanoflann::metric_L1, nanoflann::metric_L2,
* nanoflann::metric_L2_Simple, etc.
* \note The spelling of dgl's adapter ("adapter") is different from nanoflann ("adaptor")
*/
template <typename FloatType,
typename IdType,
int FeatureDim = -1,
typename Dist = nanoflann::metric_L2>
class KDTreeNDArrayAdapter {
public:
using self_type = KDTreeNDArrayAdapter<FloatType, IdType, FeatureDim, Dist>;
using metric_type = typename Dist::template traits<FloatType, self_type>::distance_t;
using index_type = nanoflann::KDTreeSingleIndexAdaptor<
metric_type, self_type, FeatureDim, IdType>;
KDTreeNDArrayAdapter(const size_t /* dims */,
const NDArray data_points,
const int leaf_max_size = 10)
: data_(data_points) {
CHECK(data_points->shape[0] != 0 && data_points->shape[1] != 0)
<< "Tensor containing input data point set must be 2D.";
const size_t dims = data_points->shape[1];
CHECK(!(FeatureDim > 0 && static_cast<int>(dims) != FeatureDim))
<< "Data set feature dimension does not match the 'FeatureDim' "
<< "template argument.";
index_ = new index_type(
static_cast<int>(dims), *this, nanoflann::KDTreeSingleIndexAdaptorParams(leaf_max_size));
index_->buildIndex();
}
~KDTreeNDArrayAdapter() {
delete index_;
}
index_type* GetIndex() {
return index_;
}
/*!
* \brief Query for the \a num_closest points to a given point
* Note that this is a short-cut method for GetIndex()->findNeighbors().
*/
void query(const FloatType* query_pt, const size_t num_closest,
IdType* out_idxs, FloatType* out_dists) const {
nanoflann::KNNResultSet<FloatType, IdType> resultSet(num_closest);
resultSet.init(out_idxs, out_dists);
index_->findNeighbors(resultSet, query_pt, nanoflann::SearchParams());
}
/*! \brief Interface expected by KDTreeSingleIndexAdaptor */
const self_type& derived() const {
return *this;
}
/*! \brief Interface expected by KDTreeSingleIndexAdaptor */
self_type& derived() {
return *this;
}
/*!
* \brief Interface expected by KDTreeSingleIndexAdaptor,
* return the number of data points
*/
size_t kdtree_get_point_count() const {
return data_->shape[0];
}
/*!
* \brief Interface expected by KDTreeSingleIndexAdaptor,
* return the dim'th component of the idx'th point
*/
FloatType kdtree_get_pt(const size_t idx, const size_t dim) const {
return data_.Ptr<FloatType>()[idx * data_->shape[1] + dim];
}
/*!
* \brief Interface expected by KDTreeSingleIndexAdaptor.
* Optional bounding-box computation: return false to
* default to a standard bbox computation loop.
*
*/
template <typename BBOX>
bool kdtree_get_bbox(BBOX& /* bb */) const {
return false;
}
private:
index_type* index_; // The kd tree index
const NDArray data_; // data points
};
} // namespace knn_utils
} // namespace transform
} // namespace dgl
#endif // DGL_GRAPH_TRANSFORM_KDTREE_NDARRAY_ADAPTER_H_
/*!
* Copyright (c) 2019 by Contributors
* \file graph/transform/knn.cc
* \brief k-nearest-neighbor (KNN) implementation
*/
#include <dgl/runtime/registry.h>
#include <dgl/runtime/packed_func.h>
#include <vector>
#include "kdtree_ndarray_adapter.h"
#include "knn.h"
#include "../../array/check.h"
using namespace dgl::runtime;
using namespace dgl::transform::knn_utils;
namespace dgl {
namespace transform {
namespace impl {
/*! \brief The kd-tree implementation of K-Nearest Neighbors */
template <typename FloatType, typename IdType>
void KdTreeKNN(const NDArray& data_points, const IdArray& data_offsets,
const NDArray& query_points, const IdArray& query_offsets,
const int k, IdArray result) {
int64_t batch_size = data_offsets->shape[0] - 1;
int64_t feature_size = data_points->shape[1];
const IdType* data_offsets_data = data_offsets.Ptr<IdType>();
const IdType* query_offsets_data = query_offsets.Ptr<IdType>();
const FloatType* query_points_data = query_points.Ptr<FloatType>();
IdType* query_out = result.Ptr<IdType>();
IdType* data_out = query_out + k * query_points->shape[0];
for (int64_t b = 0; b < batch_size; ++b) {
auto d_offset = data_offsets_data[b];
auto d_length = data_offsets_data[b + 1] - d_offset;
auto q_offset = query_offsets_data[b];
auto q_length = query_offsets_data[b + 1] - q_offset;
auto out_offset = k * q_offset;
// create view for each segment
const NDArray current_data_points = const_cast<NDArray*>(&data_points)->CreateView(
{d_length, feature_size}, data_points->dtype, d_offset * feature_size * sizeof(FloatType));
const FloatType* current_query_pts_data = query_points_data + q_offset * feature_size;
KDTreeNDArrayAdapter<FloatType, IdType> kdtree(feature_size, current_data_points);
// query
std::vector<IdType> out_buffer(k);
std::vector<FloatType> out_dist_buffer(k);
#pragma omp parallel for firstprivate(out_buffer) firstprivate(out_dist_buffer)
for (int64_t q = 0; q < q_length; ++q) {
auto curr_out_offset = k * q + out_offset;
const FloatType* q_point = current_query_pts_data + q * feature_size;
size_t num_matches = kdtree.GetIndex()->knnSearch(
q_point, k, out_buffer.data(), out_dist_buffer.data());
for (size_t i = 0; i < num_matches; ++i) {
query_out[curr_out_offset] = q + q_offset;
data_out[curr_out_offset] = out_buffer[i] + d_offset;
curr_out_offset++;
}
}
}
}
} // namespace impl
template <DLDeviceType XPU, typename FloatType, typename IdType>
void KNN(const NDArray& data_points, const IdArray& data_offsets,
const NDArray& query_points, const IdArray& query_offsets,
const int k, IdArray result, const std::string& algorithm) {
if (algorithm == std::string("kd-tree")) {
impl::KdTreeKNN<FloatType, IdType>(
data_points, data_offsets, query_points, query_offsets, k, result);
} else {
LOG(FATAL) << "Algorithm " << algorithm << " is not supported on CPU";
}
}
template void KNN<kDLCPU, float, int32_t>(
const NDArray& data_points, const IdArray& data_offsets,
const NDArray& query_points, const IdArray& query_offsets,
const int k, IdArray result, const std::string& algorithm);
template void KNN<kDLCPU, float, int64_t>(
const NDArray& data_points, const IdArray& data_offsets,
const NDArray& query_points, const IdArray& query_offsets,
const int k, IdArray result, const std::string& algorithm);
template void KNN<kDLCPU, double, int32_t>(
const NDArray& data_points, const IdArray& data_offsets,
const NDArray& query_points, const IdArray& query_offsets,
const int k, IdArray result, const std::string& algorithm);
template void KNN<kDLCPU, double, int64_t>(
const NDArray& data_points, const IdArray& data_offsets,
const NDArray& query_points, const IdArray& query_offsets,
const int k, IdArray result, const std::string& algorithm);
DGL_REGISTER_GLOBAL("transform._CAPI_DGLKNN")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
const NDArray data_points = args[0];
const IdArray data_offsets = args[1];
const NDArray query_points = args[2];
const IdArray query_offsets = args[3];
const int k = args[4];
IdArray result = args[5];
const std::string algorithm = args[6];
aten::CheckContiguous(
{data_points, data_offsets, query_points, query_offsets, result},
{"data_points", "data_offsets", "query_points", "query_offsets", "result"});
aten::CheckCtx(
data_points->ctx, {data_offsets, query_points, query_offsets, result},
{"data_offsets", "query_points", "query_offsets", "result"});
ATEN_XPU_SWITCH(data_points->ctx.device_type, XPU, "KNN", {
ATEN_FLOAT_TYPE_SWITCH(data_points->dtype, FloatType, "data_points", {
ATEN_ID_TYPE_SWITCH(result->dtype, IdType, {
KNN<XPU, FloatType, IdType>(
data_points, data_offsets, query_points,
query_offsets, k, result, algorithm);
});
});
});
});
} // namespace transform
} // namespace dgl
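For reference, the per-segment behaviour of `KdTreeKNN` above (one kd-tree per point set, `k` contiguous output slots per query) can be cross-checked against an independent Python sketch; SciPy's `cKDTree` is an assumed extra dependency here and is not what DGL uses internally:

```python
import numpy as np
from scipy.spatial import cKDTree

def kdtree_knn_reference(data, data_offsets, query, query_offsets, k):
    """Per-segment k-NN mirroring the output layout of the C++ KdTreeKNN."""
    query_idx, data_idx = [], []
    for b in range(len(data_offsets) - 1):
        d0, d1 = data_offsets[b], data_offsets[b + 1]
        q0, q1 = query_offsets[b], query_offsets[b + 1]
        tree = cKDTree(data[d0:d1])                # one kd-tree per segment
        _, nbrs = tree.query(query[q0:q1], k=k)    # neighbor indices within the segment
        nbrs = nbrs.reshape(q1 - q0, k)            # uniform shape even when k == 1
        for q, row in enumerate(nbrs):
            query_idx.extend([q + q0] * k)         # k slots per query point
            data_idx.extend(row + d0)              # shift back to global indices
    return np.stack([np.array(query_idx), np.array(data_idx)])
```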
/*!
* Copyright (c) 2021 by Contributors
* \file graph/transform/knn.h
* \brief k-nearest-neighbor (KNN) implementation
*/
#ifndef DGL_GRAPH_TRANSFORM_KNN_H_
#define DGL_GRAPH_TRANSFORM_KNN_H_
#include <dgl/array.h>
#include <string>
namespace dgl {
namespace transform {
/*!
* \brief For each point in each segment in \a query_points, find \a k nearest
* points in the same segment in \a data_points. \a data_offsets and \a query_offsets
* determine the start index of each segment in \a data_points and \a query_points.
*
* \param data_points dataset points
* \param data_offsets offsets of point index in \a data_points
* \param query_points query points
* \param query_offsets offsets of point index in \a query_points
* \param k the number of nearest points
* \param result output array
* \param algorithm algorithm used to compute the k-nearest neighbors
*
* \return A 2D tensor indicating the index relation between \a query_points and \a data_points.
*/
template <DLDeviceType XPU, typename FloatType, typename IdType>
void KNN(const NDArray& data_points, const IdArray& data_offsets,
const NDArray& query_points, const IdArray& query_offsets,
const int k, IdArray result, const std::string & algorithm);
} // namespace transform
} // namespace dgl
#endif // DGL_GRAPH_TRANSFORM_KNN_H_
@@ -24,7 +24,9 @@ def test_fps():
assert res.shape[1] == sample_points
assert res.sum() > 0
def test_knn():
@pytest.mark.parametrize('algorithm', ['topk', 'kd-tree'])
def test_knn(algorithm):
x = th.randn(8, 3)
kg = dgl.nn.KNNGraph(3)
d = th.cdist(x, x)
@@ -37,15 +39,15 @@ def test_knn():
src_ans = set(th.topk(d[start:end, start:end][i], 3, largest=False)[1].numpy() + start)
assert src == src_ans
g = kg(x)
g = kg(x, algorithm)
check_knn(g, x, 0, 8)
g = kg(x.view(2, 4, 3))
g = kg(x.view(2, 4, 3), algorithm)
check_knn(g, x, 0, 4)
check_knn(g, x, 4, 8)
kg = dgl.nn.SegmentedKNNGraph(3)
g = kg(x, [3, 5])
g = kg(x, [3, 5], algorithm)
check_knn(g, x, 0, 3)
check_knn(g, x, 3, 8)
Subproject commit 4c47ca200209550c5628c89803591f8a753c8181