Unverified commit 64d0f3f3 authored by Tianqi Zhang (张天启), committed by GitHub

[Feature] Add NN-descent support for the KNN graph function in dgl (#2941)



* add bruteforce impl

* add nn descent implementation

* change doc-string

* remove redundant func

* use local rng for cuda

* fix lint

* fix lint

* fix bug

* fix bug

* wrap nndescent_knn_graph into knn

* fix lint

* change function names

* add comment for dist funcs

* let the compiler do the unrolling

* use better blocksize setting

* remove redundant line

* check the return of the cub calls
Co-authored-by: Tong He <hetong007@gmail.com>
parent a303f078
......@@ -104,6 +104,11 @@ class KNNGraph(nn.Module):
This method is suitable for low-dimensional data (e.g. 3D
point clouds)
* 'nn-descent' is an approximate approach from the paper
`Efficient k-nearest neighbor graph construction for generic similarity
measures <https://www.cs.princeton.edu/cass/papers/www11.pdf>`_. This method
searches for nearest-neighbor candidates in "neighbors' neighbors".
(default: 'bruteforce-blas')
dist : str, optional
The distance metric used to compute distance between points. It can be the following
......@@ -212,6 +217,11 @@ class SegmentedKNNGraph(nn.Module):
This method is suitable for low-dimensional data (e.g. 3D
point clouds)
* 'nn-descent' is an approximate approach from the paper
`Efficient k-nearest neighbor graph construction for generic similarity
measures <https://www.cs.princeton.edu/cass/papers/www11.pdf>`_. This method
searches for nearest-neighbor candidates in "neighbors' neighbors".
(default: 'bruteforce-blas')
dist : str, optional
The distance metric used to compute distance between points. It can be the following
......
......@@ -117,6 +117,11 @@ def knn_graph(x, k, algorithm='bruteforce-blas', dist='euclidean'):
This method is suitable for low-dimensional data (e.g. 3D
point clouds)
* 'nn-descent' is an approximate approach from the paper
`Efficient k-nearest neighbor graph construction for generic similarity
measures <https://www.cs.princeton.edu/cass/papers/www11.pdf>`_. This method
searches for nearest-neighbor candidates in "neighbors' neighbors".
(default: 'bruteforce-blas')
dist : str, optional
The distance metric used to compute distance between points. It can be the following
......@@ -182,7 +187,7 @@ def knn_graph(x, k, algorithm='bruteforce-blas', dist='euclidean'):
x_seg = x_size[0] * [x_size[1]]
else:
x_seg = [F.shape(x)[0]]
out = knn(k, x, x_seg, algorithm=algorithm, dist=dist)
row, col = out[1], out[0]
return convert.graph((row, col))
......@@ -287,6 +292,11 @@ def segmented_knn_graph(x, k, segs, algorithm='bruteforce-blas', dist='euclidean
This method is suitable for low-dimensional data (e.g. 3D
point clouds)
* 'nn-descent' is an approximate approach from the paper
`Efficient k-nearest neighbor graph construction for generic similarity
measures <https://www.cs.princeton.edu/cass/papers/www11.pdf>`_. This method
searches for nearest-neighbor candidates in "neighbors' neighbors".
(default: 'bruteforce-blas')
dist : str, optional
The distance metric used to compute distance between points. It can be the following
......@@ -338,7 +348,7 @@ def segmented_knn_graph(x, k, segs, algorithm='bruteforce-blas', dist='euclidean
if algorithm == 'bruteforce-blas':
return _segmented_knn_graph_blas(x, k, segs, dist=dist)
else:
out = knn(k, x, segs, algorithm=algorithm, dist=dist)
row, col = out[1], out[0]
return convert.graph((row, col))
......@@ -390,9 +400,92 @@ def _segmented_knn_graph_blas(x, k, segs, dist='euclidean'):
dst = F.repeat(F.arange(0, n_total_points, ctx=ctx), k, dim=0)
return convert.graph((F.reshape(src, (-1,)), F.reshape(dst, (-1,))))
def _nndescent_knn_graph(x, k, segs, num_iters=None, max_candidates=None,
delta=0.001, sample_rate=0.5, dist='euclidean'):
r"""Construct multiple graphs from multiple sets of points according to
**approximate** k-nearest-neighbor using NN-descent algorithm from paper
`Efficient k-nearest neighbor graph construction for generic similarity
measures <https://www.cs.princeton.edu/cass/papers/www11.pdf>`_.
Parameters
----------
x : Tensor
Coordinates/features of points. Must be 2D. It can be either on CPU or GPU.
k : int
The number of nearest neighbors per node.
segs : list[int]
Number of points in each point set. The numbers in :attr:`segs`
must sum up to the number of rows in :attr:`x`.
num_iters : int, optional
The maximum number of NN-descent iterations to perform. A value will be
chosen based on the size of input by default.
(Default: None)
max_candidates : int, optional
The maximum number of candidates considered during one iteration.
Larger values give more accurate search results, but at potentially
non-negligible computation cost. A value will be chosen based on the
number of neighbors by default.
(Default: None)
delta : float, optional
A value that controls early abort. This function will abort if
:math:`k * N * delta > c`, where :math:`N` is the number of points and
:math:`c` is the number of updates during the last iteration.
(Default: 0.001)
sample_rate : float, optional
A value that controls how many candidates are sampled. It should be a float
between 0 and 1. Larger values provide higher accuracy and faster
convergence, but at higher time cost.
(Default: 0.5)
dist : str, optional
The distance metric used to compute distance between points. It can be the following
metrics:
* 'euclidean': Use Euclidean distance (L2 norm) :math:`\sqrt{\sum_{i} (x_{i} - y_{i})^{2}}`.
* 'cosine': Use cosine distance.
(default: 'euclidean')
Returns
-------
DGLGraph
The graph. The node IDs are in the same order as :attr:`x`.
"""
num_points, _ = F.shape(x)
if isinstance(segs, (tuple, list)):
segs = F.tensor(segs)
segs = F.copy_to(segs, F.context(x))
if max_candidates is None:
max_candidates = min(60, k)
if num_iters is None:
num_iters = max(10, int(round(np.log2(num_points))))
max_candidates = int(sample_rate * max_candidates)
# If cosine distance is used, normalize the input points first,
# so that we can equivalently use Euclidean distance to find knn.
if dist == 'cosine':
l2_norm = lambda v: F.sqrt(F.sum(v * v, dim=1, keepdims=True))
x = x / (l2_norm(x) + 1e-5)
# k must be less than or equal to min(segs)
if k > F.min(segs, dim=0):
raise DGLError("'k' must be less than or equal to the number of points in 'x', "
"expected 'k' <= {}, got 'k' = {}".format(F.min(segs, dim=0), k))
if delta < 0 or delta > 1:
raise DGLError("'delta' must be in [0, 1], got 'delta' = {}".format(delta))
offset = F.zeros((F.shape(segs)[0] + 1,), F.dtype(segs), F.context(segs))
offset[1:] = F.cumsum(segs, dim=0)
out = F.zeros((2, num_points * k), F.dtype(segs), F.context(segs))
# points, offsets, out, k, num_iters, max_candidates, delta
_CAPI_DGLNNDescent(F.to_dgl_nd(x), F.to_dgl_nd(offset),
F.zerocopy_to_dgl_ndarray_for_write(out),
k, num_iters, max_candidates, delta)
return out
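The normalization trick in `_nndescent_knn_graph` works because, for unit-length vectors, :math:`\|a - b\|^2 = 2 - 2\cos(a, b)`, a monotone function of cosine distance, so nearest neighbors agree. A small NumPy sketch of this equivalence (illustrative only; `nearest_by` and `nearest_cosine_via_euclidean` are hypothetical helper names):

```python
import numpy as np

def cosine_dist(a, b):
    # 1 - cosine similarity
    return 1.0 - (a @ b) / (np.linalg.norm(a) * np.linalg.norm(b))

def nearest_by(points, query, dist):
    # Index of the point minimizing the given distance to the query.
    return min(range(len(points)), key=lambda i: dist(points[i], query))

def nearest_cosine_via_euclidean(points, query):
    # Normalize to unit length, then rank by Euclidean distance:
    # for unit vectors, ||a - b||^2 = 2 - 2*cos(a, b).
    pn = points / np.linalg.norm(points, axis=1, keepdims=True)
    qn = query / np.linalg.norm(query)
    return nearest_by(pn, qn, lambda a, b: np.linalg.norm(a - b))
```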
def knn(k, x, x_segs, y=None, y_segs=None, algorithm='bruteforce', dist='euclidean'):
r"""For each element in each segment in :attr:`y`, find :attr:`k` nearest
points in the same segment in :attr:`x`.
points in the same segment in :attr:`x`. If :attr:`y` is None, perform a self-query
over :attr:`x`.
This function allows multiple point sets with different capacity. The points
from different sets are stored contiguously in the :attr:`x` and :attr:`y` tensor.
......@@ -400,20 +493,22 @@ def knn(x, x_segs, y, y_segs, k, algorithm='bruteforce', dist='euclidean'):
Parameters
----------
k : int
The number of nearest neighbors per node.
x : Tensor
The point coordinates in x. It can be either on CPU or GPU (must be the
same as :attr:`y`). Must be 2D.
x_segs : Union[List[int], Tensor]
Number of points in each point set in :attr:`x`. The numbers in :attr:`x_segs`
must sum up to the number of rows in :attr:`x`.
y : Tensor, optional
The point coordinates in y. It can be either on CPU or GPU (must be the
same as :attr:`x`). Must be 2D.
(default: None)
y_segs : Union[List[int], Tensor], optional
Number of points in each point set in :attr:`y`. The numbers in :attr:`y_segs`
must sum up to the number of rows in :attr:`y`.
(default: None)
algorithm : str, optional
Algorithm used to compute the k-nearest neighbors.
......@@ -433,6 +528,12 @@ def knn(x, x_segs, y, y_segs, k, algorithm='bruteforce', dist='euclidean'):
This method is suitable for low-dimensional data (e.g. 3D
point clouds)
* 'nn-descent' is an approximate approach from the paper
`Efficient k-nearest neighbor graph construction for generic similarity
measures <https://www.cs.princeton.edu/cass/papers/www11.pdf>`_. This method
searches for nearest-neighbor candidates in "neighbors' neighbors".
Note: Currently, 'nn-descent' only supports self-query cases, i.e. :attr:`y` is None.
(default: 'bruteforce')
dist : str, optional
The distance metric used to compute distance between points. It can be the following
......@@ -448,6 +549,17 @@ def knn(x, x_segs, y, y_segs, k, algorithm='bruteforce', dist='euclidean'):
The first subtensor contains point indices in :attr:`y`. The second subtensor contains
point indices in :attr:`x`.
"""
# TODO(lygztq) add support for querying different point sets using nn-descent.
if algorithm == "nn-descent":
if y is not None or y_segs is not None:
raise DGLError("Currently 'nn-descent' only supports self-query cases.")
return _nndescent_knn_graph(x, k, x_segs, dist=dist)
# self query
if y is None:
y = x
y_segs = x_segs
assert F.context(x) == F.context(y)
if isinstance(x_segs, (tuple, list)):
x_segs = F.tensor(x_segs)
......
......@@ -4,8 +4,13 @@
* \brief k-nearest-neighbor (KNN) implementation
*/
#include <dgl/runtime/device_api.h>
#include <dgl/random.h>
#include <dmlc/omp.h>
#include <vector>
#include <tuple>
#include <limits>
#include <algorithm>
#include "kdtree_ndarray_adapter.h"
#include "../knn.h"
......@@ -15,6 +20,192 @@ namespace dgl {
namespace transform {
namespace impl {
// This value is directly from pynndescent
static constexpr int NN_DESCENT_BLOCK_SIZE = 16384;
/*!
* \brief Compute the Euclidean distance between two vectors; returns positive
* infinity if the intermediate distance is greater than the worst
* distance.
*/
template <typename FloatType, typename IdType>
FloatType EuclideanDistWithCheck(const FloatType* vec1, const FloatType* vec2, int64_t dim,
FloatType worst_dist = std::numeric_limits<FloatType>::max()) {
FloatType dist = 0;
bool early_stop = false;
for (IdType idx = 0; idx < dim; ++idx) {
dist += (vec1[idx] - vec2[idx]) * (vec1[idx] - vec2[idx]);
if (dist > worst_dist) {
early_stop = true;
break;
}
}
if (early_stop) {
return std::numeric_limits<FloatType>::max();
} else {
return dist;
}
}
/*! \brief Compute Euclidean distance between two vectors */
template <typename FloatType, typename IdType>
FloatType EuclideanDist(const FloatType* vec1, const FloatType* vec2, int64_t dim) {
FloatType dist = 0;
for (IdType idx = 0; idx < dim; ++idx) {
dist += (vec1[idx] - vec2[idx]) * (vec1[idx] - vec2[idx]);
}
return dist;
}
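A Python sketch of the early-stopping distance above (`euclidean_dist_with_check` is a hypothetical name mirroring the C++ `EuclideanDistWithCheck`): it accumulates the squared distance and bails out with +inf as soon as the partial sum exceeds the caller's current worst distance, so hopeless candidates never reach the heap.

```python
import math

def euclidean_dist_with_check(v1, v2, worst_dist=math.inf):
    """Squared Euclidean distance with early abort past worst_dist."""
    dist = 0.0
    for a, b in zip(v1, v2):
        dist += (a - b) * (a - b)
        if dist > worst_dist:
            # Partial sum already worse than the heap root: give up.
            return math.inf
    return dist
```

Note that, like the C++ version, this returns the *squared* distance; only the ranking matters to the k-NN search.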
/*! \brief Insert a new element into a heap */
template <typename FloatType, typename IdType>
void HeapInsert(IdType* out, FloatType* dist,
IdType new_id, FloatType new_dist,
int k, bool check_repeat = false) {
if (new_dist > dist[0]) return;
// check if we have it
if (check_repeat) {
for (IdType i = 0; i < k; ++i) {
if (out[i] == new_id) return;
}
}
IdType left_idx = 0, right_idx = 0, curr_idx = 0, swap_idx = 0;
dist[0] = new_dist;
out[0] = new_id;
while (true) {
left_idx = 2 * curr_idx + 1;
right_idx = left_idx + 1;
swap_idx = curr_idx;
if (left_idx < k && dist[left_idx] > dist[swap_idx]) {
swap_idx = left_idx;
}
if (right_idx < k && dist[right_idx] > dist[swap_idx]) {
swap_idx = right_idx;
}
if (swap_idx != curr_idx) {
std::swap(dist[curr_idx], dist[swap_idx]);
std::swap(out[curr_idx], out[swap_idx]);
curr_idx = swap_idx;
} else {
break;
}
}
}
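`HeapInsert` keeps each point's current k neighbors in a max-heap keyed by distance, so the worst neighbor sits at the root and is the one replaced. A Python sketch of the same replace-root-and-sift-down step, assuming `dists` starts filled with +inf (`heap_insert` is a hypothetical name):

```python
def heap_insert(ids, dists, new_id, new_dist, k, check_repeat=False):
    """Insert (new_id, new_dist) into a bounded max-heap of size k."""
    if new_dist > dists[0]:
        return  # worse than the current worst neighbor: reject
    if check_repeat and new_id in ids[:k]:
        return  # already present
    dists[0], ids[0] = new_dist, new_id  # replace the root
    cur = 0
    while True:  # sift down to restore the max-heap property
        left, right = 2 * cur + 1, 2 * cur + 2
        swap = cur
        if left < k and dists[left] > dists[swap]:
            swap = left
        if right < k and dists[right] > dists[swap]:
            swap = right
        if swap == cur:
            break
        dists[cur], dists[swap] = dists[swap], dists[cur]
        ids[cur], ids[swap] = ids[swap], ids[cur]
        cur = swap
```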
/*! \brief Insert a new element and its flag into a heap; returns 1 on successful insertion */
template <typename FloatType, typename IdType>
int FlaggedHeapInsert(IdType* out, FloatType* dist, bool* flag,
IdType new_id, FloatType new_dist, bool new_flag,
int k, bool check_repeat = false) {
if (new_dist > dist[0]) return 0;
if (check_repeat) {
for (IdType i = 0; i < k; ++i) {
if (out[i] == new_id) return 0;
}
}
IdType left_idx = 0, right_idx = 0, curr_idx = 0, swap_idx = 0;
dist[0] = new_dist;
out[0] = new_id;
flag[0] = new_flag;
while (true) {
left_idx = 2 * curr_idx + 1;
right_idx = left_idx + 1;
swap_idx = curr_idx;
if (left_idx < k && dist[left_idx] > dist[swap_idx]) {
swap_idx = left_idx;
}
if (right_idx < k && dist[right_idx] > dist[swap_idx]) {
swap_idx = right_idx;
}
if (swap_idx != curr_idx) {
std::swap(dist[curr_idx], dist[swap_idx]);
std::swap(out[curr_idx], out[swap_idx]);
std::swap(flag[curr_idx], flag[swap_idx]);
curr_idx = swap_idx;
} else {
break;
}
}
return 1;
}
/*! \brief Build heap for each point. Used by NN-descent */
template <typename FloatType, typename IdType>
void BuildHeap(IdType* index, FloatType* dist, int k) {
for (int i = k / 2 - 1; i >= 0; --i) {
IdType idx = i;
while (true) {
IdType largest = idx;
IdType left = idx * 2 + 1;
IdType right = left + 1;
if (left < k && dist[left] > dist[largest]) {
largest = left;
}
if (right < k && dist[right] > dist[largest]) {
largest = right;
}
if (largest != idx) {
std::swap(index[largest], index[idx]);
std::swap(dist[largest], dist[idx]);
idx = largest;
} else {
break;
}
}
}
}
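`BuildHeap` is the classic bottom-up (Floyd) heap construction: sift down every internal node, starting from the last one. A Python sketch of the same procedure on parallel `ids`/`dists` arrays (`build_max_heap` is a hypothetical name):

```python
def build_max_heap(ids, dists, k):
    """Heapify parallel arrays in place: sift down each internal node."""
    for start in range(k // 2 - 1, -1, -1):
        cur = start
        while True:
            left, right = 2 * cur + 1, 2 * cur + 2
            largest = cur
            if left < k and dists[left] > dists[largest]:
                largest = left
            if right < k and dists[right] > dists[largest]:
                largest = right
            if largest == cur:
                break
            dists[cur], dists[largest] = dists[largest], dists[cur]
            ids[cur], ids[largest] = ids[largest], ids[cur]
            cur = largest
```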
/*!
* \brief Neighbor-update step in NN-descent. The distance between
* two points is computed; if this new distance is less than the worst
* distance of either point, we update that point's neighborhood.
*/
template <typename FloatType, typename IdType>
int UpdateNeighbors(IdType* neighbors, FloatType* dists, const FloatType* points,
bool* flags, IdType c1, IdType c2, IdType point_start,
int64_t feature_size, int k) {
IdType c1_local = c1 - point_start, c2_local = c2 - point_start;
FloatType worst_c1_dist = dists[c1_local * k];
FloatType worst_c2_dist = dists[c2_local * k];
FloatType new_dist = EuclideanDistWithCheck<FloatType, IdType>(
points + c1 * feature_size,
points + c2 * feature_size,
feature_size, std::max(worst_c1_dist, worst_c2_dist));
int num_updates = 0;
if (new_dist < worst_c1_dist) {
++num_updates;
#pragma omp critical
{
FlaggedHeapInsert<FloatType, IdType>(
neighbors + c1 * k,
dists + c1_local * k,
flags + c1_local * k,
c2, new_dist, true, k, true);
}
}
if (new_dist < worst_c2_dist) {
++num_updates;
#pragma omp critical
{
FlaggedHeapInsert<FloatType, IdType>(
neighbors + c2 * k,
dists + c2_local * k,
flags + c2_local * k,
c1, new_dist, true, k, true);
}
}
return num_updates;
}
/*! \brief The kd-tree implementation of K-Nearest Neighbors */
template <typename FloatType, typename IdType>
void KdTreeKNN(const NDArray& data_points, const IdArray& data_offsets,
......@@ -61,42 +252,6 @@ void KdTreeKNN(const NDArray& data_points, const IdArray& data_offsets,
}
}
template <typename FloatType, typename IdType>
void BruteForceKNN(const NDArray& data_points, const IdArray& data_offsets,
const NDArray& query_points, const IdArray& query_offsets,
......@@ -125,43 +280,15 @@ void BruteForceKNN(const NDArray& data_points, const IdArray& data_offsets,
FloatType worst_dist = std::numeric_limits<FloatType>::max();
for (IdType d_idx = d_start; d_idx < d_end; ++d_idx) {
FloatType tmp_dist = EuclideanDistWithCheck<FloatType, IdType>(
query_points_data + q_idx * feature_size,
data_points_data + d_idx * feature_size,
feature_size, worst_dist);
if (tmp_dist == std::numeric_limits<FloatType>::max()) {
continue;
}
IdType out_offset = q_idx * k;
HeapInsert<FloatType, IdType>(
data_out + out_offset, dist_buffer.data(), d_idx, tmp_dist, k);
......@@ -170,7 +297,6 @@ void BruteForceKNN(const NDArray& data_points, const IdArray& data_offsets,
}
}
}
} // namespace impl
template <DLDeviceType XPU, typename FloatType, typename IdType>
......@@ -188,6 +314,250 @@ void KNN(const NDArray& data_points, const IdArray& data_offsets,
}
}
template <DLDeviceType XPU, typename FloatType, typename IdType>
void NNDescent(const NDArray& points, const IdArray& offsets,
IdArray result, const int k, const int num_iters,
const int num_candidates, const double delta) {
using nnd_updates_t = std::vector<std::vector<std::tuple<IdType, IdType, FloatType>>>;
const auto& ctx = points->ctx;
auto device = runtime::DeviceAPI::Get(ctx);
const int64_t num_nodes = points->shape[0];
const int64_t batch_size = offsets->shape[0] - 1;
const int64_t feature_size = points->shape[1];
const IdType* offsets_data = offsets.Ptr<IdType>();
const FloatType* points_data = points.Ptr<FloatType>();
IdType* central_nodes = result.Ptr<IdType>();
IdType* neighbors = central_nodes + k * num_nodes;
int64_t max_segment_size = 0;
// find max segment
for (IdType b = 0; b < batch_size; ++b) {
if (max_segment_size < offsets_data[b + 1] - offsets_data[b])
max_segment_size = offsets_data[b + 1] - offsets_data[b];
}
// allocate memory for candidate, sampling pool, distance and flag
IdType* new_candidates = static_cast<IdType*>(
device->AllocWorkspace(ctx, max_segment_size * num_candidates * sizeof(IdType)));
IdType* old_candidates = static_cast<IdType*>(
device->AllocWorkspace(ctx, max_segment_size * num_candidates * sizeof(IdType)));
FloatType* new_candidates_dists = static_cast<FloatType*>(
device->AllocWorkspace(ctx, max_segment_size * num_candidates * sizeof(FloatType)));
FloatType* old_candidates_dists = static_cast<FloatType*>(
device->AllocWorkspace(ctx, max_segment_size * num_candidates * sizeof(FloatType)));
FloatType* neighbors_dists = static_cast<FloatType*>(
device->AllocWorkspace(ctx, max_segment_size * k * sizeof(FloatType)));
bool* flags = static_cast<bool*>(
device->AllocWorkspace(ctx, max_segment_size * k * sizeof(bool)));
for (IdType b = 0; b < batch_size; ++b) {
IdType point_idx_start = offsets_data[b], point_idx_end = offsets_data[b + 1];
IdType segment_size = point_idx_end - point_idx_start;
// random initialization
#pragma omp parallel for
for (IdType i = point_idx_start; i < point_idx_end; ++i) {
IdType local_idx = i - point_idx_start;
dgl::RandomEngine::ThreadLocal()->UniformChoice<IdType>(
k, segment_size, neighbors + i * k, false);
for (IdType n = 0; n < k; ++n) {
central_nodes[i * k + n] = i;
neighbors[i * k + n] += point_idx_start;
flags[local_idx * k + n] = true;
neighbors_dists[local_idx * k + n] = impl::EuclideanDist<FloatType, IdType>(
points_data + i * feature_size,
points_data + neighbors[i * k + n] * feature_size,
feature_size);
}
impl::BuildHeap<FloatType, IdType>(neighbors + i * k, neighbors_dists + local_idx * k, k);
}
size_t num_updates = 0;
for (int iter = 0; iter < num_iters; ++iter) {
num_updates = 0;
// initialize candidate arrays with empty values
#pragma omp parallel for
for (IdType i = point_idx_start; i < point_idx_end; ++i) {
IdType local_idx = i - point_idx_start;
for (IdType c = 0; c < num_candidates; ++c) {
new_candidates[local_idx * num_candidates + c] = num_nodes;
old_candidates[local_idx * num_candidates + c] = num_nodes;
new_candidates_dists[local_idx * num_candidates + c] =
std::numeric_limits<FloatType>::max();
old_candidates_dists[local_idx * num_candidates + c] =
std::numeric_limits<FloatType>::max();
}
}
// randomly select neighbors as candidates
int tid, num_threads;
#pragma omp parallel private(tid, num_threads)
{
tid = omp_get_thread_num();
num_threads = omp_get_num_threads();
for (IdType i = point_idx_start; i < point_idx_end; ++i) {
IdType local_idx = i - point_idx_start;
for (IdType n = 0; n < k; ++n) {
IdType neighbor_idx = neighbors[i * k + n];
bool is_new = flags[local_idx * k + n];
IdType local_neighbor_idx = neighbor_idx - point_idx_start;
FloatType random_dist = dgl::RandomEngine::ThreadLocal()->Uniform<FloatType>();
if (is_new) {
if (local_idx % num_threads == tid) {
impl::HeapInsert<FloatType, IdType>(
new_candidates + local_idx * num_candidates,
new_candidates_dists + local_idx * num_candidates,
neighbor_idx, random_dist, num_candidates, true);
}
if (local_neighbor_idx % num_threads == tid) {
impl::HeapInsert<FloatType, IdType>(
new_candidates + local_neighbor_idx * num_candidates,
new_candidates_dists + local_neighbor_idx * num_candidates,
i, random_dist, num_candidates, true);
}
} else {
if (local_idx % num_threads == tid) {
impl::HeapInsert<FloatType, IdType>(
old_candidates + local_idx * num_candidates,
old_candidates_dists + local_idx * num_candidates,
neighbor_idx, random_dist, num_candidates, true);
}
if (local_neighbor_idx % num_threads == tid) {
impl::HeapInsert<FloatType, IdType>(
old_candidates + local_neighbor_idx * num_candidates,
old_candidates_dists + local_neighbor_idx * num_candidates,
i, random_dist, num_candidates, true);
}
}
}
}
}
// clear the 'new' flag of neighbors that were sampled into new_candidates
#pragma omp parallel for
for (IdType i = point_idx_start; i < point_idx_end; ++i) {
IdType local_idx = i - point_idx_start;
for (IdType n = 0; n < k; ++n) {
IdType n_idx = neighbors[i * k + n];
for (IdType c = 0; c < num_candidates; ++c) {
if (new_candidates[local_idx * num_candidates + c] == n_idx) {
flags[local_idx * k + n] = false;
break;
}
}
}
}
// update neighbors block by block
for (IdType block_start = point_idx_start;
block_start < point_idx_end;
block_start += impl::NN_DESCENT_BLOCK_SIZE) {
IdType block_end = std::min(point_idx_end, block_start + impl::NN_DESCENT_BLOCK_SIZE);
IdType block_size = block_end - block_start;
nnd_updates_t updates(block_size);
// generate updates
#pragma omp parallel for
for (IdType i = block_start; i < block_end; ++i) {
IdType local_idx = i - point_idx_start;
for (IdType c1 = 0; c1 < num_candidates; ++c1) {
IdType new_c1 = new_candidates[local_idx * num_candidates + c1];
if (new_c1 == num_nodes) continue;
IdType c1_local = new_c1 - point_idx_start;
// new-new
for (IdType c2 = c1; c2 < num_candidates; ++c2) {
IdType new_c2 = new_candidates[local_idx * num_candidates + c2];
if (new_c2 == num_nodes) continue;
IdType c2_local = new_c2 - point_idx_start;
FloatType worst_c1_dist = neighbors_dists[c1_local * k];
FloatType worst_c2_dist = neighbors_dists[c2_local * k];
FloatType new_dist = impl::EuclideanDistWithCheck<FloatType, IdType>(
points_data + new_c1 * feature_size,
points_data + new_c2 * feature_size,
feature_size,
std::max(worst_c1_dist, worst_c2_dist));
if (new_dist < worst_c1_dist || new_dist < worst_c2_dist) {
updates[i - block_start].push_back(std::make_tuple(new_c1, new_c2, new_dist));
}
}
// new-old
for (IdType c2 = 0; c2 < num_candidates; ++c2) {
IdType old_c2 = old_candidates[local_idx * num_candidates + c2];
if (old_c2 == num_nodes) continue;
IdType c2_local = old_c2 - point_idx_start;
FloatType worst_c1_dist = neighbors_dists[c1_local * k];
FloatType worst_c2_dist = neighbors_dists[c2_local * k];
FloatType new_dist = impl::EuclideanDistWithCheck<FloatType, IdType>(
points_data + new_c1 * feature_size,
points_data + old_c2 * feature_size,
feature_size,
std::max(worst_c1_dist, worst_c2_dist));
if (new_dist < worst_c1_dist || new_dist < worst_c2_dist) {
updates[i - block_start].push_back(std::make_tuple(new_c1, old_c2, new_dist));
}
}
}
}
#pragma omp parallel private(tid, num_threads) reduction(+:num_updates)
{
tid = omp_get_thread_num();
num_threads = omp_get_num_threads();
for (IdType i = 0; i < block_size; ++i) {
for (const auto & u : updates[i]) {
IdType p1, p2;
FloatType d;
std::tie(p1, p2, d) = u;
IdType p1_local = p1 - point_idx_start;
IdType p2_local = p2 - point_idx_start;
if (p1 % num_threads == tid) {
num_updates += impl::FlaggedHeapInsert<FloatType, IdType>(
neighbors + p1 * k,
neighbors_dists + p1_local * k,
flags + p1_local * k,
p2, d, true, k, true);
}
if (p2 % num_threads == tid) {
num_updates += impl::FlaggedHeapInsert<FloatType, IdType>(
neighbors + p2 * k,
neighbors_dists + p2_local * k,
flags + p2_local * k,
p1, d, true, k, true);
}
}
}
}
}
// early abort
if (num_updates <= static_cast<size_t>(delta * k * segment_size)) {
break;
}
}
}
device->FreeWorkspace(ctx, new_candidates);
device->FreeWorkspace(ctx, old_candidates);
device->FreeWorkspace(ctx, new_candidates_dists);
device->FreeWorkspace(ctx, old_candidates_dists);
device->FreeWorkspace(ctx, neighbors_dists);
device->FreeWorkspace(ctx, flags);
}
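The `NNDescent` routine above can be summarized as: random neighbor initialization, then repeated "local join" rounds that test each point against its neighbors' neighbors, with an early abort once updates fall below `delta * k * n`. A compact Python/NumPy sketch of that loop, simplified to a single segment and without the candidate heaps, new/old flags, blocking, or OpenMP ownership tricks (`nn_descent` is a hypothetical name, not DGL's implementation):

```python
import numpy as np

def nn_descent(x, k, num_iters=10, delta=0.001, seed=0):
    """Keep each point's k closest neighbors seen so far, repeatedly
    joining every point with its (forward and reverse) neighbors'
    neighbors, until fewer than delta*k*n updates occur in a round."""
    rng = np.random.default_rng(seed)
    n = len(x)

    def d2(i, j):
        diff = x[i] - x[j]
        return float(diff @ diff)

    # Random initialization: k distinct neighbors per point (excluding self).
    nbrs = [set(rng.choice([j for j in range(n) if j != i], size=k,
                           replace=False)) for i in range(n)]
    for _ in range(num_iters):
        # Reverse neighbors: who currently lists i as a neighbor.
        rev = [set() for _ in range(n)]
        for i in range(n):
            for j in nbrs[i]:
                rev[j].add(i)
        num_updates = 0
        for i in range(n):
            # Candidates: neighbors and neighbors' neighbors, both directions.
            cand = set()
            for j in nbrs[i] | rev[i]:
                cand |= nbrs[j] | rev[j] | {j}
            cand.discard(i)
            best = set(sorted(nbrs[i] | cand, key=lambda j: d2(i, j))[:k])
            num_updates += len(best - nbrs[i])
            nbrs[i] = best
        if num_updates <= delta * k * n:  # early abort, as in the C++ code
            break
    return nbrs
```

The production code gets the same effect far more cheaply by sampling at most `max_candidates` candidates per point and skipping pairs whose partial distance already exceeds both heaps' worst entries.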
template void KNN<kDLCPU, float, int32_t>(
const NDArray& data_points, const IdArray& data_offsets,
const NDArray& query_points, const IdArray& query_offsets,
......@@ -204,5 +574,22 @@ template void KNN<kDLCPU, double, int64_t>(
const NDArray& data_points, const IdArray& data_offsets,
const NDArray& query_points, const IdArray& query_offsets,
const int k, IdArray result, const std::string& algorithm);
template void NNDescent<kDLCPU, float, int32_t>(
const NDArray& points, const IdArray& offsets,
IdArray result, const int k, const int num_iters,
const int num_candidates, const double delta);
template void NNDescent<kDLCPU, float, int64_t>(
const NDArray& points, const IdArray& offsets,
IdArray result, const int k, const int num_iters,
const int num_candidates, const double delta);
template void NNDescent<kDLCPU, double, int32_t>(
const NDArray& points, const IdArray& offsets,
IdArray result, const int k, const int num_iters,
const int num_candidates, const double delta);
template void NNDescent<kDLCPU, double, int64_t>(
const NDArray& points, const IdArray& offsets,
IdArray result, const int k, const int num_iters,
const int num_candidates, const double delta);
} // namespace transform
} // namespace dgl
......@@ -5,7 +5,9 @@
*/
#include <dgl/array.h>
#include <dgl/random.h>
#include <dgl/runtime/device_api.h>
#include <curand_kernel.h>
#include <algorithm>
#include <string>
#include <vector>
......@@ -50,17 +52,204 @@ struct SharedMemory<double> {
}
};
/*! \brief Compute the Euclidean distance between two vectors in a CUDA kernel */
template <typename FloatType, typename IdType>
__device__ FloatType EuclideanDist(const FloatType* vec1,
const FloatType* vec2,
const int64_t dim) {
FloatType dist = 0;
IdType idx = 0;
for (; idx < dim - 3; idx += 4) {
FloatType diff0 = vec1[idx] - vec2[idx];
FloatType diff1 = vec1[idx + 1] - vec2[idx + 1];
FloatType diff2 = vec1[idx + 2] - vec2[idx + 2];
FloatType diff3 = vec1[idx + 3] - vec2[idx + 3];
dist += diff0 * diff0 + diff1 * diff1 + diff2 * diff2 + diff3 * diff3;
}
for (; idx < dim; ++idx) {
FloatType diff = vec1[idx] - vec2[idx];
dist += diff * diff;
}
return dist;
}
/*!
* \brief Compute the Euclidean distance between two vectors in a CUDA kernel;
* returns positive infinity if the intermediate distance is greater
* than the worst distance.
*/
template <typename FloatType, typename IdType>
__device__ FloatType EuclideanDistWithCheck(const FloatType* vec1,
const FloatType* vec2,
const int64_t dim,
const FloatType worst_dist) {
FloatType dist = 0;
IdType idx = 0;
bool early_stop = false;
for (; idx < dim - 3; idx += 4) {
FloatType diff0 = vec1[idx] - vec2[idx];
FloatType diff1 = vec1[idx + 1] - vec2[idx + 1];
FloatType diff2 = vec1[idx + 2] - vec2[idx + 2];
FloatType diff3 = vec1[idx + 3] - vec2[idx + 3];
dist += diff0 * diff0 + diff1 * diff1 + diff2 * diff2 + diff3 * diff3;
if (dist > worst_dist) {
early_stop = true;
idx = dim;
break;
}
}
for (; idx < dim; ++idx) {
FloatType diff = vec1[idx] - vec2[idx];
dist += diff * diff;
if (dist > worst_dist) {
early_stop = true;
break;
}
}
if (early_stop) {
return std::numeric_limits<FloatType>::max();
} else {
return dist;
}
}
template <typename FloatType, typename IdType>
__device__ void BuildHeap(IdType* indices, FloatType* dists, int size) {
for (int i = size / 2 - 1; i >= 0; --i) {
IdType idx = i;
while (true) {
IdType largest = idx;
IdType left = idx * 2 + 1;
IdType right = left + 1;
if (left < size && dists[left] > dists[largest]) {
largest = left;
}
if (right < size && dists[right] > dists[largest]) {
largest = right;
}
if (largest != idx) {
IdType tmp_idx = indices[largest];
indices[largest] = indices[idx];
indices[idx] = tmp_idx;
FloatType tmp_dist = dists[largest];
dists[largest] = dists[idx];
dists[idx] = tmp_dist;
idx = largest;
} else {
break;
}
}
}
}
template <typename FloatType, typename IdType>
__device__ void HeapInsert(IdType* indices, FloatType* dist,
IdType new_idx, FloatType new_dist,
int size, bool check_repeat = false) {
if (new_dist > dist[0]) return;
// check if we have it
if (check_repeat) {
for (IdType i = 0; i < size; ++i) {
if (indices[i] == new_idx) return;
}
}
IdType left = 0, right = 0, idx = 0, largest = 0;
dist[0] = new_dist;
indices[0] = new_idx;
while (true) {
left = idx * 2 + 1;
right = left + 1;
if (left < size && dist[left] > dist[largest]) {
largest = left;
}
if (right < size && dist[right] > dist[largest]) {
largest = right;
}
if (largest != idx) {
IdType tmp_idx = indices[idx];
indices[idx] = indices[largest];
indices[largest] = tmp_idx;
FloatType tmp_dist = dist[idx];
dist[idx] = dist[largest];
dist[largest] = tmp_dist;
idx = largest;
} else {
break;
}
}
}
template <typename FloatType, typename IdType>
__device__ bool FlaggedHeapInsert(IdType* indices, FloatType* dist, bool* flags,
IdType new_idx, FloatType new_dist, bool new_flag,
int size, bool check_repeat = false) {
if (new_dist > dist[0]) return false;
// check if we have it
if (check_repeat) {
for (IdType i = 0; i < size; ++i) {
if (indices[i] == new_idx) return false;
}
}
IdType left = 0, right = 0, idx = 0, largest = 0;
dist[0] = new_dist;
indices[0] = new_idx;
flags[0] = new_flag;
while (true) {
left = idx * 2 + 1;
right = left + 1;
if (left < size && dist[left] > dist[largest]) {
largest = left;
}
if (right < size && dist[right] > dist[largest]) {
largest = right;
}
if (largest != idx) {
IdType tmp_idx = indices[idx];
indices[idx] = indices[largest];
indices[largest] = tmp_idx;
FloatType tmp_dist = dist[idx];
dist[idx] = dist[largest];
dist[largest] = tmp_dist;
bool tmp_flag = flags[idx];
flags[idx] = flags[largest];
flags[largest] = tmp_flag;
idx = largest;
} else {
break;
}
}
return true;
}
/*!
* \brief Brute-force kNN kernel. Compute the distance for each query/data pair
* and maintain the top-k result directly (without materializing a distance matrix).
*/
template <typename FloatType, typename IdType>
__global__ void BruteforceKnnKernel(const FloatType* data_points, const IdType* data_offsets,
const FloatType* query_points, const IdType* query_offsets,
const int k, FloatType* dists, IdType* query_out,
IdType* data_out, const int64_t num_batches,
const int64_t feature_size) {
const IdType q_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (q_idx >= query_offsets[num_batches]) return;
IdType batch_idx = 0;
for (IdType b = 0; b < num_batches + 1; ++b) {
if (query_offsets[b] > q_idx) { batch_idx = b - 1; break; }
@@ -74,84 +263,41 @@ __global__ void bruteforce_knn_kernel(const FloatType* data_points, const IdType
FloatType worst_dist = std::numeric_limits<FloatType>::max();
for (IdType d_idx = data_start; d_idx < data_end; ++d_idx) {
FloatType tmp_dist = EuclideanDistWithCheck<FloatType, IdType>(
query_points + q_idx * feature_size,
data_points + d_idx * feature_size,
feature_size, worst_dist);
IdType out_offset = q_idx * k;
HeapInsert<FloatType, IdType>(data_out + out_offset, dists + out_offset, d_idx, tmp_dist, k);
worst_dist = dists[q_idx * k];
}
}
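The kernel above relies on `EuclideanDistWithCheck`, which is defined elsewhere in this file. Judging from its call sites, it abandons the accumulation as soon as the partial sum exceeds the current k-th best distance. A host-side sketch of that early-abandon idea (the sentinel return convention here is an assumption, not the verified behavior of the real helper):

```cpp
#include <cassert>
#include <cstdint>
#include <limits>

// Early-abandoning squared Euclidean distance: stop accumulating once the
// partial sum already exceeds the caller's current k-th best distance, and
// return a max() sentinel so the candidate is rejected by the heap insert.
template <typename FloatType>
FloatType EuclideanDistWithCheckSketch(const FloatType* a, const FloatType* b,
                                       int64_t dim, FloatType worst_dist) {
  FloatType sum = 0;
  for (int64_t i = 0; i < dim; ++i) {
    const FloatType diff = a[i] - b[i];
    sum += diff * diff;
    if (sum > worst_dist)  // cannot make the top-k: abandon early
      return std::numeric_limits<FloatType>::max();
  }
  return sum;
}
```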
/*!
* \brief Same as BruteforceKnnKernel, but uses shared memory as a buffer.
* This kernel divides query points and data points into blocks. For each
* query block, it will make a loop over all data blocks and compute distances.
* This kernel is faster when the dimension of input points is not large.
*/
template <typename FloatType, typename IdType>
__global__ void BruteforceKnnShareKernel(const FloatType* data_points,
const IdType* data_offsets,
const FloatType* query_points,
const IdType* query_offsets,
const IdType* block_batch_id,
const IdType* local_block_id,
const int k, FloatType* dists,
IdType* query_out, IdType* data_out,
const int64_t num_batches,
const int64_t feature_size) {
const IdType block_idx = static_cast<IdType>(blockIdx.x);
const IdType block_size = static_cast<IdType>(blockDim.x);
const IdType batch_idx = block_batch_id[block_idx];
const IdType local_bid = local_block_id[block_idx];
const IdType query_start = query_offsets[batch_idx] + block_size * local_bid;
const IdType query_end = min(query_start + block_size, query_offsets[batch_idx + 1]);
if (query_start >= query_end) return;
const IdType query_idx = query_start + threadIdx.x;
const IdType data_start = data_offsets[batch_idx];
const IdType data_end = data_offsets[batch_idx + 1];
@@ -227,18 +373,10 @@ __global__ void bruteforce_knn_share_kernel(const FloatType* data_points,
if (early_stop) continue;
HeapInsert<FloatType, IdType>(
res_buff + threadIdx.x * k, dist_buff + threadIdx.x * k,
d_idx + tile_start, tmp_dist, k);
worst_dist = dist_buff[threadIdx.x * k];
}
}
}
@@ -255,9 +393,9 @@ __global__ void bruteforce_knn_share_kernel(const FloatType* data_points,
/*! \brief Determine the number of blocks for each segment */
template <typename IdType>
__global__ void GetNumBlockPerSegment(const IdType* offsets, IdType* out,
const int64_t batch_size,
const int64_t block_size) {
const IdType idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < batch_size) {
out[idx] = (offsets[idx + 1] - offsets[idx] - 1) / block_size + 1;
@@ -266,9 +404,9 @@ __global__ void get_num_block_per_segment(const IdType* offsets, IdType* out,
/*! \brief Get the batch index and local index in segment for each block */
template <typename IdType>
__global__ void GetBlockInfo(const IdType* num_block_prefixsum,
IdType* block_batch_id, IdType* local_block_id,
size_t batch_size, size_t num_blocks) {
const IdType idx = blockIdx.x * blockDim.x + threadIdx.x;
IdType i = 0;
@@ -316,7 +454,7 @@ void BruteForceKNNCuda(const NDArray& data_points, const IdArray& data_offsets,
const int64_t block_size = cuda::FindNumThreads(query_points->shape[0]);
const int64_t num_blocks = (query_points->shape[0] - 1) / block_size + 1;
CUDA_KERNEL_CALL(BruteforceKnnKernel, num_blocks, block_size, 0, thr_entry->stream,
data_points_data, data_offsets_data, query_points_data, query_offsets_data,
k, dists, query_out, data_out, batch_size, feature_size);
@@ -370,10 +508,10 @@ void BruteForceKNNSharedCuda(const NDArray& data_points, const IdArray& data_off
IdType* num_block_prefixsum = static_cast<IdType*>(
device->AllocWorkspace(ctx, batch_size * sizeof(IdType)));
// block size for GetNumBlockPerSegment computation
int64_t temp_block_size = cuda::FindNumThreads(batch_size);
int64_t temp_num_blocks = (batch_size - 1) / temp_block_size + 1;
CUDA_KERNEL_CALL(GetNumBlockPerSegment, temp_num_blocks,
temp_block_size, 0, thr_entry->stream,
query_offsets_data, num_block_per_segment,
batch_size, block_size);
@@ -408,13 +546,13 @@ void BruteForceKNNSharedCuda(const NDArray& data_points, const IdArray& data_off
IdType* local_block_id = static_cast<IdType*>(device->AllocWorkspace(
ctx, num_blocks * sizeof(IdType)));
CUDA_KERNEL_CALL(
GetBlockInfo, temp_num_blocks, temp_block_size, 0,
thr_entry->stream, num_block_prefixsum, block_batch_id,
local_block_id, batch_size, num_blocks);
FloatType* dists = static_cast<FloatType*>(device->AllocWorkspace(
ctx, k * query_points->shape[0] * sizeof(FloatType)));
CUDA_KERNEL_CALL(BruteforceKnnShareKernel, num_blocks, block_size,
single_shared_mem * block_size, thr_entry->stream, data_points_data,
data_offsets_data, query_points_data, query_offsets_data,
block_batch_id, local_block_id, k, dists, query_out,
@@ -424,6 +562,257 @@ void BruteForceKNNSharedCuda(const NDArray& data_points, const IdArray& data_off
device->FreeWorkspace(ctx, local_block_id);
device->FreeWorkspace(ctx, block_batch_id);
}
/*! \brief Setup rng state for nn-descent */
__global__ void SetupRngKernel(curandState* states,
const uint64_t seed,
const size_t n) {
size_t id = blockIdx.x * blockDim.x + threadIdx.x;
if (id < n) {
curand_init(seed, id, 0, states + id);
}
}
/*!
* \brief Randomly initialize neighbors (sampling without replacement)
* for each node
*/
template <typename FloatType, typename IdType>
__global__ void RandomInitNeighborsKernel(const FloatType* points,
const IdType* offsets,
IdType* central_nodes,
IdType* neighbors,
FloatType* dists,
bool* flags,
const int k,
const int64_t feature_size,
const int64_t batch_size,
const uint64_t seed) {
const IdType point_idx = blockIdx.x * blockDim.x + threadIdx.x;
IdType batch_idx = 0;
if (point_idx >= offsets[batch_size]) return;
curandState state;
curand_init(seed, point_idx, 0, &state);
// find the segment location in the input batch
for (IdType b = 0; b < batch_size + 1; ++b) {
if (offsets[b] > point_idx) {
batch_idx = b - 1;
break;
}
}
const IdType segment_size = offsets[batch_idx + 1] - offsets[batch_idx];
IdType* current_neighbors = neighbors + point_idx * k;
IdType* current_central_nodes = central_nodes + point_idx * k;
bool* current_flags = flags + point_idx * k;
FloatType* current_dists = dists + point_idx * k;
IdType segment_start = offsets[batch_idx];
// reservoir sampling
for (IdType i = 0; i < k; ++i) {
current_neighbors[i] = i + segment_start;
current_central_nodes[i] = point_idx;
}
for (IdType i = k; i < segment_size; ++i) {
const IdType j = static_cast<IdType>(curand(&state) % (i + 1));
if (j < k) current_neighbors[j] = i + segment_start;
}
// compute distances and set flags
for (IdType i = 0; i < k; ++i) {
current_flags[i] = true;
current_dists[i] = EuclideanDist<FloatType, IdType>(
points + point_idx * feature_size,
points + current_neighbors[i] * feature_size,
feature_size);
}
// build heap
BuildHeap<FloatType, IdType>(current_neighbors, current_dists, k);
}
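`RandomInitNeighborsKernel` draws k distinct initial neighbors per point with one-pass reservoir sampling. A host-side sketch of that sampling step (`ReservoirSample` is a hypothetical helper name, using `std::mt19937_64` in place of cuRAND):

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <random>
#include <vector>

// One-pass reservoir sampling: k distinct indices drawn uniformly from [0, n),
// mirroring the initialization loop in RandomInitNeighborsKernel.
std::vector<int64_t> ReservoirSample(int64_t n, int64_t k, uint64_t seed) {
  std::mt19937_64 rng(seed);
  std::vector<int64_t> sample(k);
  for (int64_t i = 0; i < k; ++i) sample[i] = i;  // fill the reservoir
  for (int64_t i = k; i < n; ++i) {
    // element i replaces a random slot with probability k / (i + 1)
    const int64_t j = static_cast<int64_t>(rng() % static_cast<uint64_t>(i + 1));
    if (j < k) sample[j] = i;
  }
  return sample;
}
```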
/*! \brief Randomly select candidates from current knn and reverse-knn graph for nn-descent */
template <typename IdType>
__global__ void FindCandidatesKernel(const IdType* offsets, IdType* new_candidates,
IdType* old_candidates, IdType* neighbors, bool* flags,
const uint64_t seed, const int64_t batch_size,
const int num_candidates, const int k) {
const IdType point_idx = blockIdx.x * blockDim.x + threadIdx.x;
IdType batch_idx = 0;
if (point_idx >= offsets[batch_size]) return;
curandState state;
curand_init(seed, point_idx, 0, &state);
// find the segment location in the input batch
for (IdType b = 0; b < batch_size + 1; ++b) {
if (offsets[b] > point_idx) {
batch_idx = b - 1;
break;
}
}
IdType segment_start = offsets[batch_idx], segment_end = offsets[batch_idx + 1];
IdType* current_neighbors = neighbors + point_idx * k;
bool* current_flags = flags + point_idx * k;
// reset candidates
IdType* new_candidates_ptr = new_candidates + point_idx * (num_candidates + 1);
IdType* old_candidates_ptr = old_candidates + point_idx * (num_candidates + 1);
new_candidates_ptr[0] = 0;
old_candidates_ptr[0] = 0;
// select candidates from current knn graph
// here we use candidate[0] for reservoir sampling temporarily
for (IdType i = 0; i < k; ++i) {
IdType candidate = current_neighbors[i];
IdType* candidate_array = current_flags[i] ? new_candidates_ptr : old_candidates_ptr;
IdType curr_num = candidate_array[0];
IdType* candidate_data = candidate_array + 1;
// reservoir sampling
if (curr_num < num_candidates) {
candidate_data[curr_num] = candidate;
} else {
IdType pos = static_cast<IdType>(curand(&state) % (curr_num + 1));
if (pos < num_candidates) candidate_data[pos] = candidate;
}
++candidate_array[0];
}
// select candidates from current reverse knn graph
// here we use candidate[0] for reservoir sampling temporarily
IdType index_start = segment_start * k, index_end = segment_end * k;
for (IdType i = index_start; i < index_end; ++i) {
if (neighbors[i] == point_idx) {
IdType reverse_candidate = (i - index_start) / k + segment_start;
IdType* candidate_array = flags[i] ? new_candidates_ptr : old_candidates_ptr;
IdType curr_num = candidate_array[0];
IdType* candidate_data = candidate_array + 1;
// reservoir sampling
if (curr_num < num_candidates) {
candidate_data[curr_num] = reverse_candidate;
} else {
IdType pos = static_cast<IdType>(curand(&state) % (curr_num + 1));
if (pos < num_candidates) candidate_data[pos] = reverse_candidate;
}
++candidate_array[0];
}
}
// set candidate[0] back to length
if (new_candidates_ptr[0] > num_candidates) new_candidates_ptr[0] = num_candidates;
if (old_candidates_ptr[0] > num_candidates) old_candidates_ptr[0] = num_candidates;
// mark new_candidates as old
IdType num_new_candidates = new_candidates_ptr[0];
for (IdType i = 0; i < k; ++i) {
IdType neighbor_idx = current_neighbors[i];
if (current_flags[i]) {
for (IdType j = 1; j < num_new_candidates + 1; ++j) {
if (new_candidates_ptr[j] == neighbor_idx) {
current_flags[i] = false;
break;
}
}
}
}
}
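`FindCandidatesKernel` stores each candidate list as a length-prefixed array: slot 0 temporarily counts how many candidates were *offered* (the kernel clamps it to `num_candidates` afterwards), while slots 1..`num_candidates` hold a uniform reservoir sample of the offered items. A host-side sketch of one offer (`OfferCandidate` is a hypothetical helper, not part of the DGL source):

```cpp
#include <cassert>
#include <cstdint>
#include <random>

// Offer one candidate to a length-prefixed reservoir: candidate_array[0] is
// the running count of offers, candidate_array[1..num_candidates] the sample.
void OfferCandidate(int64_t* candidate_array, int64_t num_candidates,
                    int64_t candidate, std::mt19937_64* rng) {
  const int64_t curr_num = candidate_array[0];
  int64_t* data = candidate_array + 1;
  if (curr_num < num_candidates) {
    data[curr_num] = candidate;  // reservoir not full yet: append
  } else {
    const int64_t pos =
        static_cast<int64_t>((*rng)() % static_cast<uint64_t>(curr_num + 1));
    // keep the new item with probability num_candidates / (curr_num + 1)
    if (pos < num_candidates) data[pos] = candidate;
  }
  ++candidate_array[0];  // count of offered candidates so far
}
```

After all offers, the kernel clamps slot 0 back down to `num_candidates`, exactly as in the lines above.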
/*! \brief Update knn graph according to selected candidates for nn-descent */
template <typename FloatType, typename IdType>
__global__ void UpdateNeighborsKernel(const FloatType* points, const IdType* offsets,
IdType* neighbors, IdType* new_candidates,
IdType* old_candidates, FloatType* distances,
bool* flags, IdType* num_updates,
const int64_t batch_size, const int num_candidates,
const int k, const int64_t feature_size) {
const IdType point_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (point_idx >= offsets[batch_size]) return;
IdType* current_neighbors = neighbors + point_idx * k;
bool* current_flags = flags + point_idx * k;
FloatType* current_dists = distances + point_idx * k;
IdType* new_candidates_ptr = new_candidates + point_idx * (num_candidates + 1);
IdType* old_candidates_ptr = old_candidates + point_idx * (num_candidates + 1);
IdType num_new_candidates = new_candidates_ptr[0];
IdType num_old_candidates = old_candidates_ptr[0];
IdType current_num_updates = 0;
// process new candidates
for (IdType i = 1; i <= num_new_candidates; ++i) {
IdType new_c = new_candidates_ptr[i];
// new/old candidates of the current new candidate
IdType* twohop_new_ptr = new_candidates + new_c * (num_candidates + 1);
IdType* twohop_old_ptr = old_candidates + new_c * (num_candidates + 1);
IdType num_twohop_new = twohop_new_ptr[0];
IdType num_twohop_old = twohop_old_ptr[0];
FloatType worst_dist = current_dists[0];
// new - new
for (IdType j = 1; j <= num_twohop_new; ++j) {
IdType twohop_new_c = twohop_new_ptr[j];
FloatType new_dist = EuclideanDistWithCheck<FloatType, IdType>(
points + point_idx * feature_size,
points + twohop_new_c * feature_size,
feature_size, worst_dist);
if (FlaggedHeapInsert<FloatType, IdType>(
current_neighbors, current_dists, current_flags,
twohop_new_c, new_dist, true, k, true)) {
++current_num_updates;
worst_dist = current_dists[0];
}
}
// new - old
for (IdType j = 1; j <= num_twohop_old; ++j) {
IdType twohop_old_c = twohop_old_ptr[j];
FloatType new_dist = EuclideanDistWithCheck<FloatType, IdType>(
points + point_idx * feature_size,
points + twohop_old_c * feature_size,
feature_size, worst_dist);
if (FlaggedHeapInsert<FloatType, IdType>(
current_neighbors, current_dists, current_flags,
twohop_old_c, new_dist, true, k, true)) {
++current_num_updates;
worst_dist = current_dists[0];
}
}
}
// process old candidates
for (IdType i = 1; i <= num_old_candidates; ++i) {
IdType old_c = old_candidates_ptr[i];
// new candidates of the current old candidate
IdType* twohop_new_ptr = new_candidates + old_c * (num_candidates + 1);
IdType num_twohop_new = twohop_new_ptr[0];
FloatType worst_dist = current_dists[0];
// old - new
for (IdType j = 1; j <= num_twohop_new; ++j) {
IdType twohop_new_c = twohop_new_ptr[j];
FloatType new_dist = EuclideanDistWithCheck<FloatType, IdType>(
points + point_idx * feature_size,
points + twohop_new_c * feature_size,
feature_size, worst_dist);
if (FlaggedHeapInsert<FloatType, IdType>(
current_neighbors, current_dists, current_flags,
twohop_new_c, new_dist, true, k, true)) {
++current_num_updates;
worst_dist = current_dists[0];
}
}
}
num_updates[point_idx] = current_num_updates;
}
} // namespace impl
template <DLDeviceType XPU, typename FloatType, typename IdType>
@@ -441,6 +830,97 @@ void KNN(const NDArray& data_points, const IdArray& data_offsets,
}
}
template <DLDeviceType XPU, typename FloatType, typename IdType>
void NNDescent(const NDArray& points, const IdArray& offsets,
IdArray result, const int k, const int num_iters,
const int num_candidates, const double delta) {
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
const auto& ctx = points->ctx;
auto device = runtime::DeviceAPI::Get(ctx);
const int64_t num_nodes = points->shape[0];
const int64_t feature_size = points->shape[1];
const int64_t batch_size = offsets->shape[0] - 1;
const IdType* offsets_data = offsets.Ptr<IdType>();
const FloatType* points_data = points.Ptr<FloatType>();
IdType* central_nodes = result.Ptr<IdType>();
IdType* neighbors = central_nodes + k * num_nodes;
uint64_t seed;
int warp_size = 0;
CUDA_CALL(cudaDeviceGetAttribute(
&warp_size, cudaDevAttrWarpSize, ctx.device_id));
// We don't need large block sizes, since there's not much inter-thread communication
int64_t block_size = warp_size;
int64_t num_blocks = (num_nodes - 1) / block_size + 1;
// allocate space for candidates, distances and flags
// the first element of each candidate array stores its length
IdType* new_candidates = static_cast<IdType*>(
device->AllocWorkspace(ctx, num_nodes * (num_candidates + 1) * sizeof(IdType)));
IdType* old_candidates = static_cast<IdType*>(
device->AllocWorkspace(ctx, num_nodes * (num_candidates + 1) * sizeof(IdType)));
IdType* num_updates = static_cast<IdType*>(
device->AllocWorkspace(ctx, num_nodes * sizeof(IdType)));
FloatType* distances = static_cast<FloatType*>(
device->AllocWorkspace(ctx, num_nodes * k * sizeof(FloatType)));
bool* flags = static_cast<bool*>(
device->AllocWorkspace(ctx, num_nodes * k * sizeof(bool)));
size_t sum_temp_size = 0;
IdType total_num_updates = 0;
IdType* total_num_updates_d = static_cast<IdType*>(
device->AllocWorkspace(ctx, sizeof(IdType)));
CUDA_CALL(cub::DeviceReduce::Sum(
nullptr, sum_temp_size, num_updates, total_num_updates_d, num_nodes));
IdType* sum_temp_storage = static_cast<IdType*>(
device->AllocWorkspace(ctx, sum_temp_size));
// randomly initialize neighbors
seed = RandomEngine::ThreadLocal()->RandInt<uint64_t>(
std::numeric_limits<uint64_t>::max());
CUDA_KERNEL_CALL(
impl::RandomInitNeighborsKernel, num_blocks, block_size, 0, thr_entry->stream,
points_data, offsets_data, central_nodes, neighbors, distances, flags, k,
feature_size, batch_size, seed);
for (int i = 0; i < num_iters; ++i) {
// select candidates
seed = RandomEngine::ThreadLocal()->RandInt<uint64_t>(
std::numeric_limits<uint64_t>::max());
CUDA_KERNEL_CALL(
impl::FindCandidatesKernel, num_blocks, block_size, 0,
thr_entry->stream, offsets_data, new_candidates, old_candidates, neighbors,
flags, seed, batch_size, num_candidates, k);
// update
CUDA_KERNEL_CALL(
impl::UpdateNeighborsKernel, num_blocks, block_size, 0, thr_entry->stream,
points_data, offsets_data, neighbors, new_candidates, old_candidates, distances,
flags, num_updates, batch_size, num_candidates, k, feature_size);
total_num_updates = 0;
CUDA_CALL(cub::DeviceReduce::Sum(
sum_temp_storage, sum_temp_size, num_updates, total_num_updates_d, num_nodes));
device->CopyDataFromTo(
total_num_updates_d, 0, &total_num_updates, 0,
sizeof(IdType), ctx, DLContext{kDLCPU, 0},
offsets->dtype, thr_entry->stream);
if (total_num_updates <= static_cast<IdType>(delta * k * num_nodes)) {
break;
}
}
device->FreeWorkspace(ctx, new_candidates);
device->FreeWorkspace(ctx, old_candidates);
device->FreeWorkspace(ctx, num_updates);
device->FreeWorkspace(ctx, distances);
device->FreeWorkspace(ctx, flags);
device->FreeWorkspace(ctx, total_num_updates_d);
device->FreeWorkspace(ctx, sum_temp_storage);
}
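The iteration loop above implements the standard NN-descent early-termination test: stop once an iteration changes no more than `delta * k * num_nodes` neighbor slots. As a minimal sketch of that criterion:

```cpp
#include <cassert>
#include <cstdint>

// Convergence test mirroring the `total_num_updates <= delta * k * num_nodes`
// check in NNDescent above.
bool NNDescentConverged(int64_t total_num_updates, double delta, int k,
                        int64_t num_nodes) {
  return total_num_updates <= static_cast<int64_t>(delta * k * num_nodes);
}
```

Smaller `delta` therefore means more iterations and a more accurate graph; `delta = 0` disables early termination entirely.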
template void KNN<kDLGPU, float, int32_t>(
const NDArray& data_points, const IdArray& data_offsets,
const NDArray& query_points, const IdArray& query_offsets,
@@ -458,5 +938,22 @@ template void KNN<kDLGPU, double, int64_t>(
const NDArray& query_points, const IdArray& query_offsets,
const int k, IdArray result, const std::string& algorithm);
template void NNDescent<kDLGPU, float, int32_t>(
const NDArray& points, const IdArray& offsets,
IdArray result, const int k, const int num_iters,
const int num_candidates, const double delta);
template void NNDescent<kDLGPU, float, int64_t>(
const NDArray& points, const IdArray& offsets,
IdArray result, const int k, const int num_iters,
const int num_candidates, const double delta);
template void NNDescent<kDLGPU, double, int32_t>(
const NDArray& points, const IdArray& offsets,
IdArray result, const int k, const int num_iters,
const int num_candidates, const double delta);
template void NNDescent<kDLGPU, double, int64_t>(
const NDArray& points, const IdArray& offsets,
IdArray result, const int k, const int num_iters,
const int num_candidates, const double delta);
} // namespace transform
} // namespace dgl
@@ -41,5 +41,30 @@ DGL_REGISTER_GLOBAL("transform._CAPI_DGLKNN")
});
});
DGL_REGISTER_GLOBAL("transform._CAPI_DGLNNDescent")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
const NDArray points = args[0];
const IdArray offsets = args[1];
const IdArray result = args[2];
const int k = args[3];
const int num_iters = args[4];
const int num_candidates = args[5];
const double delta = args[6];
aten::CheckContiguous(
{points, offsets, result}, {"points", "offsets", "result"});
aten::CheckCtx(
points->ctx, {points, offsets, result}, {"points", "offsets", "result"});
ATEN_XPU_SWITCH_CUDA(points->ctx.device_type, XPU, "NNDescent", {
ATEN_FLOAT_TYPE_SWITCH(points->dtype, FloatType, "points", {
ATEN_ID_TYPE_SWITCH(result->dtype, IdType, {
NNDescent<XPU, FloatType, IdType>(
points, offsets, result, k, num_iters, num_candidates, delta);
});
});
});
});
} // namespace transform
} // namespace dgl
@@ -18,21 +18,37 @@ namespace transform {
* points in the same segment in \a data_points. \a data_offsets and \a query_offsets
* determine the start index of each segment in \a data_points and \a query_points.
*
* \param data_points dataset points.
* \param data_offsets offsets of point index in \a data_points.
* \param query_points query points.
* \param query_offsets offsets of point index in \a query_points.
* \param k the number of nearest points.
* \param result output array. A 2D tensor indicating the index
* relation between \a query_points and \a data_points.
* \param algorithm algorithm used to compute the k-nearest neighbors.
*/
template <DLDeviceType XPU, typename FloatType, typename IdType>
void KNN(const NDArray& data_points, const IdArray& data_offsets,
const NDArray& query_points, const IdArray& query_offsets,
const int k, IdArray result, const std::string& algorithm);
/*!
* \brief For each input point, find \a k approximate nearest points in the same
* segment using the NN-descent algorithm.
*
* \param points input points.
* \param offsets offsets of point index.
* \param result output array. A 2D tensor indicating the index relation between points.
* \param k the number of nearest points.
* \param num_iters The maximum number of NN-descent iterations to perform.
* \param num_candidates The maximum number of candidates to be considered during one iteration.
* \param delta A value that controls early termination: iteration stops once the
* number of neighbor updates drops to delta * k * (number of points) or below.
*/
template <DLDeviceType XPU, typename FloatType, typename IdType>
void NNDescent(const NDArray& points, const IdArray& offsets,
IdArray result, const int k, const int num_iters,
const int num_candidates, const double delta);
} // namespace transform
} // namespace dgl