"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "ea81a4228d8ff16042c3ccaf61f0e588e60166cd"
Unverified commit 5d7e80f4, authored by Tianqi Zhang (张天启), committed by GitHub
Browse files

[Feature] Add bruteforce implementation for KNN with O(Nk) space complexity (#2892)



* add bruteforce impl

* add support for bruteforce-sharemem

* modify python API

* add tests

* change file path

* change python API

* fix lint

* fix test

* also check worst_dist in the last few dim

* use heap and early-stop on CPU

* fix lint

* fix lint

* add device check

* use cuda function to determine max shared mem

* use cuda to determine block info

* add memory free for tmp var

* update doc-string and add dist option

* fix lint

* add more tests
Co-authored-by: Quan (Andy) Gan <coin2028@hotmail.com>
Co-authored-by: Minjie Wang <wmjlyjemaine@gmail.com>
parent db0fb4ea
......@@ -67,8 +67,8 @@ class KNNGraph(nn.Module):
self.k = k
#pylint: disable=invalid-name
def forward(self, x, algorithm='topk'):
"""
def forward(self, x, algorithm='bruteforce-blas', dist='euclidean'):
r"""
Forward computation.
......@@ -81,18 +81,44 @@ class KNNGraph(nn.Module):
algorithm : str, optional
Algorithm used to compute the k-nearest neighbors.
* 'topk' will use topk algorithm (quick-select or sorting,
depending on backend implementation)
* 'kd-tree' will use kd-tree algorithm (cpu version)
(default: 'topk')
* 'bruteforce-blas' will first compute the distance matrix
using BLAS matrix multiplication operation provided by
backend frameworks. Then use topk algorithm to get
k-nearest neighbors. This method is fast when the point
set is small but has :math:`O(N^2)` memory complexity where
:math:`N` is the number of points.
* 'bruteforce' will compute distances pair by pair and
directly select the k-nearest neighbors during distance
computation. This method is slower than 'bruteforce-blas'
but has less memory overhead (i.e., :math:`O(Nk)` where :math:`N`
is the number of points, :math:`k` is the number of nearest
neighbors per node) since we do not need to store all distances.
* 'bruteforce-sharemem' (CUDA only) is similar to 'bruteforce'
but use shared memory in CUDA devices for buffer. This method is
faster than 'bruteforce' when the dimension of input points
is not large. This method is only available on CUDA device.
* 'kd-tree' will use the kd-tree algorithm (CPU only).
This method is suitable for low-dimensional data (e.g. 3D
point clouds)
(default: 'bruteforce-blas')
dist : str, optional
The distance metric used to compute distance between points. It can be the following
metrics:
* 'euclidean': Use Euclidean distance (L2 norm)
:math:`\sqrt{\sum_{i} (x_{i} - y_{i})^{2}}`.
* 'cosine': Use cosine distance.
(default: 'euclidean')
Returns
-------
DGLGraph
A DGLGraph without features.
"""
return knn_graph(x, self.k, algorithm)
return knn_graph(x, self.k, algorithm=algorithm, dist=dist)
class SegmentedKNNGraph(nn.Module):
......@@ -148,7 +174,7 @@ class SegmentedKNNGraph(nn.Module):
self.k = k
#pylint: disable=invalid-name
def forward(self, x, segs, algorithm='topk'):
def forward(self, x, segs, algorithm='bruteforce-blas', dist='euclidean'):
r"""Forward computation.
Parameters
......@@ -163,11 +189,37 @@ class SegmentedKNNGraph(nn.Module):
algorithm : str, optional
Algorithm used to compute the k-nearest neighbors.
* 'topk' will use topk algorithm (quick-select or sorting,
depending on backend implementation)
* 'kd-tree' will use kd-tree algorithm (cpu version)
(default: 'topk')
* 'bruteforce-blas' will first compute the distance matrix
using BLAS matrix multiplication operation provided by
backend frameworks. Then use topk algorithm to get
k-nearest neighbors. This method is fast when the point
set is small but has :math:`O(N^2)` memory complexity where
:math:`N` is the number of points.
* 'bruteforce' will compute distances pair by pair and
directly select the k-nearest neighbors during distance
computation. This method is slower than 'bruteforce-blas'
but has less memory overhead (i.e., :math:`O(Nk)` where :math:`N`
is the number of points, :math:`k` is the number of nearest
neighbors per node) since we do not need to store all distances.
* 'bruteforce-sharemem' (CUDA only) is similar to 'bruteforce'
but use shared memory in CUDA devices for buffer. This method is
faster than 'bruteforce' when the dimension of input points
is not large. This method is only available on CUDA device.
* 'kd-tree' will use the kd-tree algorithm (CPU only).
This method is suitable for low-dimensional data (e.g. 3D
point clouds)
(default: 'bruteforce-blas')
dist : str, optional
The distance metric used to compute distance between points. It can be the following
metrics:
* 'euclidean': Use Euclidean distance (L2 norm)
:math:`\sqrt{\sum_{i} (x_{i} - y_{i})^{2}}`.
* 'cosine': Use cosine distance.
(default: 'euclidean')
Returns
-------
......@@ -175,4 +227,4 @@ class SegmentedKNNGraph(nn.Module):
A DGLGraph without features.
"""
return segmented_knn_graph(x, self.k, segs, algorithm)
return segmented_knn_graph(x, self.k, segs, algorithm=algorithm, dist=dist)
......@@ -62,8 +62,8 @@ def pairwise_squared_distance(x):
return x2s + F.swapaxes(x2s, -1, -2) - 2 * x @ F.swapaxes(x, -1, -2)
#pylint: disable=invalid-name
def knn_graph(x, k, algorithm='topk'):
"""Construct a graph from a set of points according to k-nearest-neighbor (KNN)
def knn_graph(x, k, algorithm='bruteforce-blas', dist='euclidean'):
r"""Construct a graph from a set of points according to k-nearest-neighbor (KNN)
and return.
The function transforms the coordinates/features of a point set
......@@ -92,20 +92,42 @@ def knn_graph(x, k, algorithm='topk'):
algorithm : str, optional
Algorithm used to compute the k-nearest neighbors.
* 'topk' will use topk algorithm (quick-select or sorting,
depending on backend implementation)
* 'kd-tree' will use kd-tree algorithm (only on cpu)
(default: 'topk')
* 'bruteforce-blas' will first compute the distance matrix
using BLAS matrix multiplication operation provided by
backend frameworks. Then use topk algorithm to get
k-nearest neighbors. This method is fast when the point
set is small but has :math:`O(N^2)` memory complexity where
:math:`N` is the number of points.
* 'bruteforce' will compute distances pair by pair and
directly select the k-nearest neighbors during distance
computation. This method is slower than 'bruteforce-blas'
but has less memory overhead (i.e., :math:`O(Nk)` where :math:`N`
is the number of points, :math:`k` is the number of nearest
neighbors per node) since we do not need to store all distances.
* 'bruteforce-sharemem' (CUDA only) is similar to 'bruteforce'
but use shared memory in CUDA devices for buffer. This method is
faster than 'bruteforce' when the dimension of input points
is not large. This method is only available on CUDA device.
* 'kd-tree' will use the kd-tree algorithm (CPU only).
This method is suitable for low-dimensional data (e.g. 3D
point clouds)
(default: 'bruteforce-blas')
dist : str, optional
The distance metric used to compute distance between points. It can be the following
metrics:
* 'euclidean': Use Euclidean distance (L2 norm) :math:`\sqrt{\sum_{i} (x_{i} - y_{i})^{2}}`.
* 'cosine': Use cosine distance.
(default: 'euclidean')
Returns
-------
DGLGraph
The constructed graph. The node IDs are in the same order as :attr:`x`.
If using the 'topk' algorithm, the returned graph is on the same device as input :attr:`x`.
Else, the returned graph is on CPU, regardless of the context of the input :attr:`x`.
Examples
--------
......@@ -141,8 +163,16 @@ def knn_graph(x, k, algorithm='topk'):
(tensor([0, 1, 2, 2, 2, 3, 3, 3, 4, 5, 5, 5, 6, 6, 7, 7]),
tensor([0, 1, 1, 2, 3, 0, 2, 3, 4, 5, 6, 7, 4, 6, 5, 7]))
"""
if algorithm == 'topk':
return _knn_graph_topk(x, k)
# check invalid k
if k <= 0:
raise DGLError("Invalid k value. expect k > 0, got k = {}".format(k))
# check empty point set
if F.shape(x)[0] == 0:
raise DGLError("Find empty point set")
if algorithm == 'bruteforce-blas':
return _knn_graph_blas(x, k, dist=dist)
else:
if F.ndim(x) == 3:
x_size = tuple(F.shape(x))
......@@ -150,13 +180,16 @@ def knn_graph(x, k, algorithm='topk'):
x_seg = x_size[0] * [x_size[1]]
else:
x_seg = [F.shape(x)[0]]
out = knn(x, x_seg, x, x_seg, k, algorithm=algorithm)
out = knn(x, x_seg, x, x_seg, k, algorithm=algorithm, dist=dist)
row, col = out[1], out[0]
return convert.graph((row, col))
def _knn_graph_topk(x, k):
"""Construct a graph from a set of points according to k-nearest-neighbor (KNN)
via topk method.
def _knn_graph_blas(x, k, dist='euclidean'):
r"""Construct a graph from a set of points according to k-nearest-neighbor (KNN).
This function first compute the distance matrix using BLAS matrix multiplication
operation provided by backend frameworks. Then use topk algorithm to get
k-nearest neighbors.
Parameters
----------
......@@ -169,11 +202,28 @@ def _knn_graph_topk(x, k):
``x[i][j]`` corresponds to the j-th node in the i-th KNN graph.
k : int
The number of nearest neighbors per node.
dist : str, optional
The distance metric used to compute distance between points. It can be the following
metrics:
* 'euclidean': Use Euclidean distance (L2 norm) :math:`\sqrt{\sum_{i} (x_{i} - y_{i})^{2}}`.
* 'cosine': Use cosine distance.
(default: 'euclidean')
"""
if F.ndim(x) == 2:
x = F.unsqueeze(x, 0)
n_samples, n_points, _ = F.shape(x)
if k > n_points:
dgl_warning("'k' should be less than or equal to the number of points in 'x'" \
"expect k <= {0}, got k = {1}, use k = {0}".format(n_points, k))
k = n_points
# if use cosine distance, normalize input points first
# thus we can use euclidean distance to find knn equivalently.
if dist == 'cosine':
l2_norm = lambda v: F.sqrt(F.sum(v * v, dim=2, keepdims=True))
x = x / (l2_norm(x) + 1e-5)
ctx = F.context(x)
dist = pairwise_squared_distance(x)
k_indices = F.argtopk(dist, k, 2, descending=False)
......@@ -187,8 +237,8 @@ def _knn_graph_topk(x, k):
return convert.graph((F.reshape(src, (-1,)), F.reshape(dst, (-1,))))
#pylint: disable=invalid-name
def segmented_knn_graph(x, k, segs, algorithm='topk'):
"""Construct multiple graphs from multiple sets of points according to
def segmented_knn_graph(x, k, segs, algorithm='bruteforce-blas', dist='euclidean'):
r"""Construct multiple graphs from multiple sets of points according to
k-nearest-neighbor (KNN) and return.
Compared with :func:`dgl.knn_graph`, this allows multiple point sets with
......@@ -212,20 +262,42 @@ def segmented_knn_graph(x, k, segs, algorithm='topk'):
algorithm : str, optional
Algorithm used to compute the k-nearest neighbors.
* 'topk' will use topk algorithm (quick-select or sorting,
depending on backend implementation)
* 'kd-tree' will use kd-tree algorithm (only on cpu)
(default: 'topk')
* 'bruteforce-blas' will first compute the distance matrix
using BLAS matrix multiplication operation provided by
backend frameworks. Then use topk algorithm to get
k-nearest neighbors. This method is fast when the point
set is small but has :math:`O(N^2)` memory complexity where
:math:`N` is the number of points.
* 'bruteforce' will compute distances pair by pair and
directly select the k-nearest neighbors during distance
computation. This method is slower than 'bruteforce-blas'
but has less memory overhead (i.e., :math:`O(Nk)` where :math:`N`
is the number of points, :math:`k` is the number of nearest
neighbors per node) since we do not need to store all distances.
* 'bruteforce-sharemem' (CUDA only) is similar to 'bruteforce'
but use shared memory in CUDA devices for buffer. This method is
faster than 'bruteforce' when the dimension of input points
is not large. This method is only available on CUDA device.
* 'kd-tree' will use the kd-tree algorithm (CPU only).
This method is suitable for low-dimensional data (e.g. 3D
point clouds)
(default: 'bruteforce-blas')
dist : str, optional
The distance metric used to compute distance between points. It can be the following
metrics:
* 'euclidean': Use Euclidean distance (L2 norm) :math:`\sqrt{\sum_{i} (x_{i} - y_{i})^{2}}`.
* 'cosine': Use cosine distance.
(default: 'euclidean')
Returns
-------
DGLGraph
The graph. The node IDs are in the same order as :attr:`x`.
If using the 'topk' algorithm, the returned graph is on the same device as input :attr:`x`.
Else, the returned graph is on CPU, regardless of the context of the input :attr:`x`.
Examples
--------
......@@ -253,16 +325,28 @@ def segmented_knn_graph(x, k, segs, algorithm='topk'):
(tensor([0, 0, 1, 1, 1, 2, 3, 3, 4, 4, 5, 5, 6, 6]),
tensor([0, 1, 0, 1, 2, 2, 3, 5, 4, 6, 3, 5, 4, 6]))
"""
if algorithm == 'topk':
return _segmented_knn_graph_topk(x, k, segs)
# check invalid k
if k <= 0:
raise DGLError("Invalid k value. expect k > 0, got k = {}".format(k))
# check empty point set
if F.shape(x)[0] == 0:
raise DGLError("Find empty point set")
if algorithm == 'bruteforce-blas':
return _segmented_knn_graph_blas(x, k, segs, dist=dist)
else:
out = knn(x, segs, x, segs, k, algorithm=algorithm)
out = knn(x, segs, x, segs, k, algorithm=algorithm, dist=dist)
row, col = out[1], out[0]
return convert.graph((row, col))
def _segmented_knn_graph_topk(x, k, segs):
"""Construct multiple graphs from multiple sets of points according to
k-nearest-neighbor (KNN) via topk method.
def _segmented_knn_graph_blas(x, k, segs, dist='euclidean'):
r"""Construct multiple graphs from multiple sets of points according to
k-nearest-neighbor (KNN).
This function first compute the distance matrix using BLAS matrix multiplication
operation provided by backend frameworks. Then use topk algorithm to get
k-nearest neighbors.
Parameters
----------
......@@ -273,9 +357,26 @@ def _segmented_knn_graph_topk(x, k, segs):
segs : list[int]
Number of points in each point set. The numbers in :attr:`segs`
must sum up to the number of rows in :attr:`x`.
dist : str, optional
The distance metric used to compute distance between points. It can be the following
metrics:
* 'euclidean': Use Euclidean distance (L2 norm) :math:`\sqrt{\sum_{i} (x_{i} - y_{i})^{2}}`.
* 'cosine': Use cosine distance.
(default: 'euclidean')
"""
# if use cosine distance, normalize input points first
# thus we can use euclidean distance to find knn equivalently.
if dist == 'cosine':
l2_norm = lambda v: F.sqrt(F.sum(v * v, dim=1, keepdims=True))
x = x / (l2_norm(x) + 1e-5)
n_total_points, _ = F.shape(x)
offset = np.insert(np.cumsum(segs), 0, 0)
min_seg_size = np.min(segs)
if k > min_seg_size:
dgl_warning("'k' should be less than or equal to the number of points in 'x'" \
"expect k <= {0}, got k = {1}, use k = {0}".format(min_seg_size, k))
k = min_seg_size
h_list = F.split(x, segs, 0)
src = [
......@@ -287,7 +388,7 @@ def _segmented_knn_graph_topk(x, k, segs):
dst = F.repeat(F.arange(0, n_total_points, ctx=ctx), k, dim=0)
return convert.graph((F.reshape(src, (-1,)), F.reshape(dst, (-1,))))
def knn(x, x_segs, y, y_segs, k, algorithm='kd-tree', dist='euclidean'):
def knn(x, x_segs, y, y_segs, k, algorithm='bruteforce', dist='euclidean'):
r"""For each element in each segment in :attr:`y`, find :attr:`k` nearest
points in the same segment in :attr:`x`.
......@@ -313,14 +414,30 @@ def knn(x, x_segs, y, y_segs, k, algorithm='kd-tree', dist='euclidean'):
The number of nearest neighbors per node.
algorithm : str, optional
Algorithm used to compute the k-nearest neighbors.
Currently only cpu version kdtree is supported.
(default: 'kd-tree')
* 'bruteforce' will compute distances pair by pair and
directly select the k-nearest neighbors during distance
computation. This method is slower than 'bruteforce-blas'
but has less memory overhead (i.e., :math:`O(Nk)` where :math:`N`
is the number of points, :math:`k` is the number of nearest
neighbors per node) since we do not need to store all distances.
* 'bruteforce-sharemem' (CUDA only) is similar to 'bruteforce'
but use shared memory in CUDA devices for buffer. This method is
faster than 'bruteforce' when the dimension of input points
is not large. This method is only available on CUDA device.
* 'kd-tree' will use the kd-tree algorithm (CPU only).
This method is suitable for low-dimensional data (e.g. 3D
point clouds)
(default: 'bruteforce')
dist : str, optional
The distance metric used to compute distance between points. It can be the following
metrics:
* 'euclidean': Use Euclidean distance (L2 norm) :math:`\sqrt{\sum_{i} (x_{i} - y_{i})^{2}}`.
* 'cosine': Use cosine distance.
(Default: "euclidean")
(default: 'euclidean')
Returns
-------
......@@ -329,12 +446,7 @@ def knn(x, x_segs, y, y_segs, k, algorithm='kd-tree', dist='euclidean'):
The first subtensor contains point indexs in :attr:`y`. The second subtensor contains
point indexs in :attr:`x`
"""
# currently only cpu implementation is supported.
if (F.context(x) != F.cpu() or F.context(y) != F.cpu()):
dgl_warning("Currently only cpu implementation is supported," \
"copy input tensors to cpu.")
x = F.copy_to(x, F.cpu())
y = F.copy_to(y, F.cpu())
assert F.context(x) == F.context(y)
if isinstance(x_segs, (tuple, list)):
x_segs = F.tensor(x_segs)
if isinstance(y_segs, (tuple, list)):
......@@ -342,16 +454,21 @@ def knn(x, x_segs, y, y_segs, k, algorithm='kd-tree', dist='euclidean'):
x_segs = F.copy_to(x_segs, F.context(x))
y_segs = F.copy_to(y_segs, F.context(y))
# supported algorithms
algorithm_list = ['kd-tree']
if algorithm not in algorithm_list:
raise DGLError("only {} algorithms are supported, get '{}'".format(
algorithm_list, algorithm))
# k should be less than or equal to min(x_segs)
min_num_points = F.min(x_segs, dim=0)
if k > min_num_points:
dgl_warning("'k' should be less than or equal to the number of points in 'x'" \
"expect k <= {0}, got k = {1}, use k = {0}".format(min_num_points, k))
k = F.as_scalar(min_num_points)
# invalid k
if k <= 0:
raise DGLError("Invalid k value. expect k > 0, got k = {}".format(k))
# empty point set
if F.shape(x)[0] == 0 or F.shape(y)[0] == 0:
raise DGLError("Find empty point set")
# k must less than or equal to min(x_segs)
if k > F.min(x_segs, dim=0):
raise DGLError("'k' must be less than or equal to the number of points in 'x'"
"expect k <= {}, got k = {}".format(F.min(x_segs, dim=0), k))
dist = dist.lower()
dist_metric_list = ['euclidean', 'cosine']
if dist not in dist_metric_list:
......
/*!
* Copyright (c) 2021 by Contributors
* \file graph/transform/kdtree_ndarray_adapter.h
* \file graph/transform/cpu/kdtree_ndarray_adapter.h
* \brief NDArray adapter for nanoflann, without
* duplicating the storage
*/
#ifndef DGL_GRAPH_TRANSFORM_KDTREE_NDARRAY_ADAPTER_H_
#define DGL_GRAPH_TRANSFORM_KDTREE_NDARRAY_ADAPTER_H_
#ifndef DGL_GRAPH_TRANSFORM_CPU_KDTREE_NDARRAY_ADAPTER_H_
#define DGL_GRAPH_TRANSFORM_CPU_KDTREE_NDARRAY_ADAPTER_H_
#include <dgl/array.h>
#include <dmlc/logging.h>
#include <nanoflann.hpp>
#include "../../c_api_common.h"
#include "../../../c_api_common.h"
namespace dgl {
namespace transform {
......@@ -118,4 +118,4 @@ class KDTreeNDArrayAdapter {
} // namespace transform
} // namespace dgl
#endif // DGL_GRAPH_TRANSFORM_KDTREE_NDARRAY_ADAPTER_H_
#endif // DGL_GRAPH_TRANSFORM_CPU_KDTREE_NDARRAY_ADAPTER_H_
/*!
* Copyright (c) 2019 by Contributors
* \file graph/transform/cpu/knn.cc
* \brief k-nearest-neighbor (KNN) implementation
*/
#include <vector>
#include <limits>
#include "kdtree_ndarray_adapter.h"
#include "../knn.h"
using namespace dgl::runtime;
using namespace dgl::transform::knn_utils;
namespace dgl {
namespace transform {
namespace impl {
/*! \brief The kd-tree implementation of K-Nearest Neighbors.
 *
 * Builds one kd-tree per segment of \a data_points and, for every query
 * point of the matching segment, writes its k nearest data points into
 * \a result.
 */
template <typename FloatType, typename IdType>
void KdTreeKNN(const NDArray& data_points, const IdArray& data_offsets,
               const NDArray& query_points, const IdArray& query_offsets,
               const int k, IdArray result) {
  // one segment (batch element) per pair of consecutive offsets
  const int64_t batch_size = data_offsets->shape[0] - 1;
  const int64_t feature_size = data_points->shape[1];
  const IdType* data_offsets_data = data_offsets.Ptr<IdType>();
  const IdType* query_offsets_data = query_offsets.Ptr<IdType>();
  const FloatType* query_points_data = query_points.Ptr<FloatType>();
  // Result layout: the first k * num_query entries hold query indices,
  // the next k * num_query entries hold the matched data indices.
  IdType* query_out = result.Ptr<IdType>();
  IdType* data_out = query_out + k * query_points->shape[0];

  for (int64_t b = 0; b < batch_size; ++b) {
    auto d_offset = data_offsets_data[b];
    auto d_length = data_offsets_data[b + 1] - d_offset;
    auto q_offset = query_offsets_data[b];
    auto q_length = query_offsets_data[b + 1] - q_offset;
    auto out_offset = k * q_offset;

    // create view for each segment (no copy of the underlying storage)
    const NDArray current_data_points = const_cast<NDArray*>(&data_points)->CreateView(
      {d_length, feature_size}, data_points->dtype, d_offset * feature_size * sizeof(FloatType));
    const FloatType* current_query_pts_data = query_points_data + q_offset * feature_size;
    // build a kd-tree over this segment's data points only
    KDTreeNDArrayAdapter<FloatType, IdType> kdtree(feature_size, current_data_points);

    // query: each OpenMP thread gets a private copy of the scratch buffers
    std::vector<IdType> out_buffer(k);
    std::vector<FloatType> out_dist_buffer(k);
#pragma omp parallel for firstprivate(out_buffer) firstprivate(out_dist_buffer)
    for (IdType q = 0; q < q_length; ++q) {
      auto curr_out_offset = k * q + out_offset;
      const FloatType* q_point = current_query_pts_data + q * feature_size;
      // knnSearch may return fewer than k matches (e.g. when the segment
      // has fewer than k points); only those entries are written.
      size_t num_matches = kdtree.GetIndex()->knnSearch(
        q_point, k, out_buffer.data(), out_dist_buffer.data());
      for (size_t i = 0; i < num_matches; ++i) {
        // shift segment-local indices back to global row indices
        query_out[curr_out_offset] = q + q_offset;
        data_out[curr_out_offset] = out_buffer[i] + d_offset;
        curr_out_offset++;
      }
    }
  }
}
/*!
 * \brief Push a new (id, distance) candidate into a k-element max-heap
 *  stored in \a out / \a dist, evicting the current worst entry.
 *
 * The heap root (index 0) holds the largest distance. We assume the
 * caller only calls this when new_dist <= dist[0]; the root is
 * overwritten and the new entry is sifted down until the heap
 * property is restored.
 */
template <typename FloatType, typename IdType>
void HeapInsert(IdType* out, FloatType* dist,
                IdType new_id, FloatType new_dist,
                int k) {
  IdType hole = 0;  // slot currently available for the new entry
  for (;;) {
    IdType child = 2 * hole + 1;  // left child
    if (child >= k) break;  // reached a leaf
    // pick the child holding the larger distance
    // (right child wins ties, matching the original traversal order)
    const IdType right = child + 1;
    if (right < k && dist[right] >= dist[child]) child = right;
    // both children already <= new_dist: heap property holds here
    if (dist[child] <= new_dist) break;
    // pull the larger child up and continue sifting down
    dist[hole] = dist[child];
    out[hole] = out[child];
    hole = child;
  }
  dist[hole] = new_dist;
  out[hole] = new_id;
}
/*! \brief The brute-force implementation of K-Nearest Neighbors.
 *
 * For every query point of a segment, compares against every data point
 * of the same segment and keeps the k best candidates in a per-query
 * max-heap, using O(k) scratch memory per thread (no distance matrix).
 */
template <typename FloatType, typename IdType>
void BruteForceKNN(const NDArray& data_points, const IdArray& data_offsets,
                   const NDArray& query_points, const IdArray& query_offsets,
                   const int k, IdArray result) {
  const int64_t batch_size = data_offsets->shape[0] - 1;
  const int64_t feature_size = data_points->shape[1];
  const IdType* data_offsets_data = data_offsets.Ptr<IdType>();
  const IdType* query_offsets_data = query_offsets.Ptr<IdType>();
  const FloatType* data_points_data = data_points.Ptr<FloatType>();
  const FloatType* query_points_data = query_points.Ptr<FloatType>();
  // Result layout: the first k * num_query entries hold query indices,
  // the next k * num_query entries hold the matched data indices.
  IdType* query_out = result.Ptr<IdType>();
  IdType* data_out = query_out + k * query_points->shape[0];

  for (int64_t b = 0; b < batch_size; ++b) {
    IdType d_start = data_offsets_data[b], d_end = data_offsets_data[b + 1];
    IdType q_start = query_offsets_data[b], q_end = query_offsets_data[b + 1];
    // per-thread scratch heap of the k best squared distances so far
    std::vector<FloatType> dist_buffer(k);
#pragma omp parallel for firstprivate(dist_buffer)
    for (IdType q_idx = q_start; q_idx < q_end; ++q_idx) {
      for (IdType k_idx = 0; k_idx < k; ++k_idx) {
        query_out[q_idx * k + k_idx] = q_idx;
        dist_buffer[k_idx] = std::numeric_limits<FloatType>::max();
      }
      // dist_buffer is a max-heap; its root is the current worst kept
      // distance, used to early-terminate the accumulation below
      FloatType worst_dist = std::numeric_limits<FloatType>::max();
      for (IdType d_idx = d_start; d_idx < d_end; ++d_idx) {
        FloatType tmp_dist = 0;
        bool early_stop = false;

        // expand loop (x4): manually unrolled squared-euclidean accumulation
        IdType dim_idx = 0;
        while (dim_idx < feature_size - 3) {
          const FloatType diff0 = query_points_data[q_idx * feature_size + dim_idx]
            - data_points_data[d_idx * feature_size + dim_idx];
          const FloatType diff1 = query_points_data[q_idx * feature_size + dim_idx + 1]
            - data_points_data[d_idx * feature_size + dim_idx + 1];
          const FloatType diff2 = query_points_data[q_idx * feature_size + dim_idx + 2]
            - data_points_data[d_idx * feature_size + dim_idx + 2];
          const FloatType diff3 = query_points_data[q_idx * feature_size + dim_idx + 3]
            - data_points_data[d_idx * feature_size + dim_idx + 3];
          tmp_dist += diff0 * diff0 + diff1 * diff1 + diff2 * diff2 + diff3 * diff3;
          dim_idx += 4;
          // partial sum already worse than all k kept neighbors: give up
          if (tmp_dist > worst_dist) {
            early_stop = true;
            dim_idx = feature_size;
            break;
          }
        }

        // last 3 elements (when feature_size is not a multiple of 4)
        while (dim_idx < feature_size) {
          const FloatType diff = query_points_data[q_idx * feature_size + dim_idx]
            - data_points_data[d_idx * feature_size + dim_idx];
          tmp_dist += diff * diff;
          ++dim_idx;
          if (tmp_dist > worst_dist) {
            early_stop = true;
            break;
          }
        }

        if (early_stop) continue;

        // candidate beats the current worst neighbor: replace the heap root
        IdType out_offset = q_idx * k;
        HeapInsert<FloatType, IdType>(
          data_out + out_offset, dist_buffer.data(), d_idx, tmp_dist, k);
        worst_dist = dist_buffer[0];
      }
    }
  }
}
} // namespace impl
template <DLDeviceType XPU, typename FloatType, typename IdType>
void KNN(const NDArray& data_points, const IdArray& data_offsets,
const NDArray& query_points, const IdArray& query_offsets,
const int k, IdArray result, const std::string& algorithm) {
if (algorithm == std::string("kd-tree")) {
impl::KdTreeKNN<FloatType, IdType>(
data_points, data_offsets, query_points, query_offsets, k, result);
} else if (algorithm == std::string("bruteforce")) {
impl::BruteForceKNN<FloatType, IdType>(
data_points, data_offsets, query_points, query_offsets, k, result);
} else {
LOG(FATAL) << "Algorithm " << algorithm << " is not supported on CPU";
}
}
// Explicit template instantiations for the CPU backend: every supported
// combination of floating point type (float/double) and index type
// (int32/int64).
template void KNN<kDLCPU, float, int32_t>(
    const NDArray& data_points, const IdArray& data_offsets,
    const NDArray& query_points, const IdArray& query_offsets,
    const int k, IdArray result, const std::string& algorithm);
template void KNN<kDLCPU, float, int64_t>(
    const NDArray& data_points, const IdArray& data_offsets,
    const NDArray& query_points, const IdArray& query_offsets,
    const int k, IdArray result, const std::string& algorithm);
template void KNN<kDLCPU, double, int32_t>(
    const NDArray& data_points, const IdArray& data_offsets,
    const NDArray& query_points, const IdArray& query_offsets,
    const int k, IdArray result, const std::string& algorithm);
template void KNN<kDLCPU, double, int64_t>(
    const NDArray& data_points, const IdArray& data_offsets,
    const NDArray& query_points, const IdArray& query_offsets,
    const int k, IdArray result, const std::string& algorithm);
} // namespace transform
} // namespace dgl
/*!
* Copyright (c) 2020 by Contributors
* \file graph/transform/cuda/knn.cu
* \brief k-nearest-neighbor (KNN) implementation (cuda)
*/
#include <dgl/array.h>
#include <dgl/runtime/device_api.h>
#include <algorithm>
#include <string>
#include <vector>
#include <limits>
#include "../../../array/cuda/dgl_cub.cuh"
#include "../../../runtime/cuda/cuda_common.h"
#include "../../../array/cuda/utils.h"
#include "../knn.h"
namespace dgl {
namespace transform {
namespace impl {
/*!
 * \brief Utility class used to avoid linker errors with extern
 *  unsized shared memory arrays with templated type
 */
template <typename Type>
struct SharedMemory {
  // Implicit conversion yields a pointer to the kernel's dynamically
  // allocated (extern __shared__) memory, reinterpreted as Type*.
  __device__ inline operator Type* () {
    extern __shared__ int __smem[];
    return reinterpret_cast<Type*>(__smem);
  }

  __device__ inline operator const Type* () const {
    extern __shared__ int __smem[];
    return reinterpret_cast<Type*>(__smem);
  }
};
// specialize for double to avoid unaligned memory
// access compile errors
template <>
struct SharedMemory<double> {
  // Same trick as the generic version, but the extern array is declared
  // as double so the shared-memory base pointer is 8-byte aligned.
  __device__ inline operator double* () {
    extern __shared__ double __smem_d[];
    return reinterpret_cast<double*>(__smem_d);
  }

  __device__ inline operator const double* () const {
    extern __shared__ double __smem_d[];
    return reinterpret_cast<double*>(__smem_d);
  }
};
/*!
 * \brief Brute force kNN kernel. Compute distance for each pair of input points and get
 *  the result directly (without a distance matrix).
 *
 * One thread serves one query point. Result layout: query_out and
 * data_out each hold k entries per query; dists mirrors data_out with
 * the corresponding squared distances.
 */
template <typename FloatType, typename IdType>
__global__ void bruteforce_knn_kernel(const FloatType* data_points, const IdType* data_offsets,
                                      const FloatType* query_points, const IdType* query_offsets,
                                      const int k, FloatType* dists, IdType* query_out,
                                      IdType* data_out, const int64_t num_batches,
                                      const int64_t feature_size) {
  const IdType q_idx = blockIdx.x * blockDim.x + threadIdx.x;
  // Bounds guard (fix): the grid is launched in whole blocks, so the last
  // block may contain padding threads with q_idx past the final query
  // offset. Without this check such a thread would fall through the batch
  // search below with batch_idx == 0 and write out of bounds.
  if (q_idx >= query_offsets[num_batches]) return;
  // locate the segment (batch element) this query point belongs to
  IdType batch_idx = 0;
  for (IdType b = 0; b < num_batches + 1; ++b) {
    if (query_offsets[b] > q_idx) { batch_idx = b - 1; break; }
  }
  const IdType data_start = data_offsets[batch_idx], data_end = data_offsets[batch_idx + 1];

  // initialize the k result slots of this query with +inf distances
  for (IdType k_idx = 0; k_idx < k; ++k_idx) {
    query_out[q_idx * k + k_idx] = q_idx;
    dists[q_idx * k + k_idx] = std::numeric_limits<FloatType>::max();
  }
  // dists[q_idx*k .. q_idx*k+k-1] is kept sorted ascending, so the last
  // entry is the current worst distance, used for early termination
  FloatType worst_dist = std::numeric_limits<FloatType>::max();

  for (IdType d_idx = data_start; d_idx < data_end; ++d_idx) {
    FloatType tmp_dist = 0;
    IdType dim_idx = 0;
    bool early_stop = false;

    // expand loop (x4), #pragma unroll has poor performance here
    for (; dim_idx < feature_size - 3; dim_idx += 4) {
      FloatType diff0 = query_points[q_idx * feature_size + dim_idx]
        - data_points[d_idx * feature_size + dim_idx];
      FloatType diff1 = query_points[q_idx * feature_size + dim_idx + 1]
        - data_points[d_idx * feature_size + dim_idx + 1];
      FloatType diff2 = query_points[q_idx * feature_size + dim_idx + 2]
        - data_points[d_idx * feature_size + dim_idx + 2];
      FloatType diff3 = query_points[q_idx * feature_size + dim_idx + 3]
        - data_points[d_idx * feature_size + dim_idx + 3];
      tmp_dist += diff0 * diff0 + diff1 * diff1 + diff2 * diff2 + diff3 * diff3;
      // stop if current distance > all top-k distances.
      if (tmp_dist > worst_dist) {
        early_stop = true;
        dim_idx = feature_size;
        break;
      }
    }

    // last few elements (when feature_size is not a multiple of 4)
    for (; dim_idx < feature_size; ++dim_idx) {
      FloatType diff = query_points[q_idx * feature_size + dim_idx]
        - data_points[d_idx * feature_size + dim_idx];
      tmp_dist += diff * diff;
      if (tmp_dist > worst_dist) {
        early_stop = true;
        break;
      }
    }

    if (early_stop) continue;

    // maintain a monotonic array by "insert sort": shift worse entries
    // right and drop the old worst
    IdType out_offset = q_idx * k;
    for (IdType k1 = 0; k1 < k; ++k1) {
      if (dists[out_offset + k1] > tmp_dist) {
        for (IdType k2 = k - 1; k2 > k1; --k2) {
          dists[out_offset + k2] = dists[out_offset + k2 - 1];
          data_out[out_offset + k2] = data_out[out_offset + k2 - 1];
        }
        dists[out_offset + k1] = tmp_dist;
        data_out[out_offset + k1] = d_idx;
        worst_dist = dists[out_offset + k - 1];
        break;
      }
    }
  }
}
/*!
* \brief Same as bruteforce_knn_kernel, but use shared memory as buffer.
* This kernel divides query points and data points into blocks. For each
* query block, it will make a loop over all data blocks and compute distances.
* This kernel is faster when the dimension of input points is not large.
*/
/*!
 * \brief Tiled brute-force kNN kernel (see comment above): each CUDA block owns
 *  one tile of query points of one batch segment, loads data points tile by
 *  tile into shared memory, and each thread maintains a sorted k-best list for
 *  its query point in shared memory.
 */
template <typename FloatType, typename IdType>
__global__ void bruteforce_knn_share_kernel(const FloatType* data_points,
                                            const IdType* data_offsets,
                                            const FloatType* query_points,
                                            const IdType* query_offsets,
                                            const IdType* block_batch_id,
                                            const IdType* local_block_id,
                                            const int k, FloatType* dists,
                                            IdType* query_out, IdType* data_out,
                                            const int64_t num_batches,
                                            const int64_t feature_size) {
  // NOTE(review): num_batches is currently unused in this kernel body.
  // Map this CUDA block to (segment, tile-within-segment) via the lookup
  // tables built by get_block_info.
  const IdType block_idx = static_cast<IdType>(blockIdx.x);
  const IdType block_size = static_cast<IdType>(blockDim.x);
  const IdType batch_idx = block_batch_id[block_idx];
  const IdType local_bid = local_block_id[block_idx];
  // One query point per thread; query_idx may fall past query_end for the
  // last (partial) tile of a segment.
  const IdType query_start = query_offsets[batch_idx] + block_size * local_bid;
  const IdType query_end = min(query_start + block_size, query_offsets[batch_idx + 1]);
  const IdType query_idx = query_start + threadIdx.x;
  const IdType data_start = data_offsets[batch_idx];
  const IdType data_end = data_offsets[batch_idx + 1];
  // shared memory: points in block + distance buffer + result buffer
  FloatType* data_buff = SharedMemory<FloatType>();
  FloatType* query_buff = data_buff + block_size * feature_size;
  FloatType* dist_buff = query_buff + block_size * feature_size;
  IdType* res_buff = reinterpret_cast<IdType*>(dist_buff + block_size * k);
  // Current k-th best (i.e. worst kept) distance for this thread's query;
  // used to early-terminate distance accumulation.
  FloatType worst_dist = std::numeric_limits<FloatType>::max();
  // initialize dist buff with inf value
  for (auto i = 0; i < k; ++i) {
    dist_buff[threadIdx.x * k + i] = std::numeric_limits<FloatType>::max();
  }
  // load query data to shared memory
  if (query_idx < query_end) {
    for (auto i = 0; i < feature_size; ++i) {
      // to avoid bank conflict, we use transpose here
      query_buff[threadIdx.x + i * block_size] = query_points[query_idx * feature_size + i];
    }
  }
  // perform computation on each tile
  for (auto tile_start = data_start; tile_start < data_end; tile_start += block_size) {
    // each thread load one data point into the shared memory
    IdType load_idx = tile_start + threadIdx.x;
    if (load_idx < data_end) {
      for (auto i = 0; i < feature_size; ++i) {
        data_buff[threadIdx.x * feature_size + i] = data_points[load_idx * feature_size + i];
      }
    }
    __syncthreads();
    // compute distance for one tile
    IdType true_block_size = min(data_end - tile_start, block_size);
    if (query_idx < query_end) {
      for (IdType d_idx = 0; d_idx < true_block_size; ++d_idx) {
        FloatType tmp_dist = 0;
        bool early_stop = false;
        IdType dim_idx = 0;
        // Unrolled-by-4 squared-distance accumulation with early stop once
        // the partial sum already exceeds the current k-th best distance.
        for (; dim_idx < feature_size - 3; dim_idx += 4) {
          FloatType diff0 = query_buff[threadIdx.x + block_size * (dim_idx)]
            - data_buff[d_idx * feature_size + dim_idx];
          FloatType diff1 = query_buff[threadIdx.x + block_size * (dim_idx + 1)]
            - data_buff[d_idx * feature_size + dim_idx + 1];
          FloatType diff2 = query_buff[threadIdx.x + block_size * (dim_idx + 2)]
            - data_buff[d_idx * feature_size + dim_idx + 2];
          FloatType diff3 = query_buff[threadIdx.x + block_size * (dim_idx + 3)]
            - data_buff[d_idx * feature_size + dim_idx + 3];
          tmp_dist += diff0 * diff0 + diff1 * diff1 + diff2 * diff2 + diff3 * diff3;
          if (tmp_dist > worst_dist) {
            early_stop = true;
            dim_idx = feature_size;  // also skips the tail loop below
            break;
          }
        }
        // remaining (feature_size % 4) dimensions
        for (; dim_idx < feature_size; ++dim_idx) {
          const FloatType diff = query_buff[threadIdx.x + dim_idx * block_size]
            - data_buff[d_idx * feature_size + dim_idx];
          tmp_dist += diff * diff;
          if (tmp_dist > worst_dist) {
            early_stop = true;
            break;
          }
        }
        if (early_stop) continue;
        // Insert tmp_dist into this thread's ascending k-best list.
        for (IdType k1 = 0; k1 < k; ++k1) {
          if (dist_buff[threadIdx.x * k + k1] > tmp_dist) {
            for (IdType k2 = k - 1; k2 > k1; --k2) {
              dist_buff[threadIdx.x * k + k2] = dist_buff[threadIdx.x * k + k2 - 1];
              res_buff[threadIdx.x * k + k2] = res_buff[threadIdx.x * k + k2 - 1];
            }
            dist_buff[threadIdx.x * k + k1] = tmp_dist;
            res_buff[threadIdx.x * k + k1] = d_idx + tile_start;
            worst_dist = dist_buff[threadIdx.x * k + k - 1];
            break;
          }
        }
      }
    }
    // BUGFIX: barrier before the next iteration overwrites data_buff.
    // Without it, a thread that finishes its tile early may load the next
    // tile's data points while slower threads are still reading the current
    // tile from shared memory (data race).
    __syncthreads();
  }
  // copy result to global memory
  if (query_idx < query_end) {
    for (auto i = 0; i < k; ++i) {
      dists[query_idx * k + i] = dist_buff[threadIdx.x * k + i];
      data_out[query_idx * k + i] = res_buff[threadIdx.x * k + i];
      query_out[query_idx * k + i] = query_idx;
    }
  }
}
/*! \brief determine the number of blocks for each segment */
template <typename IdType>
__global__ void get_num_block_per_segment(const IdType* offsets, IdType* out,
                                          const int64_t batch_size,
                                          const int64_t block_size) {
  // One thread per segment: out[seg] = ceil(segment_length / block_size).
  const IdType seg = blockIdx.x * blockDim.x + threadIdx.x;
  if (seg >= batch_size) return;
  const IdType seg_len = offsets[seg + 1] - offsets[seg];
  // NOTE(review): an empty segment still yields 1 block here; its kernel
  // launch is a no-op — confirm downstream block counts tolerate this.
  out[seg] = (seg_len - 1) / block_size + 1;
}
/*! \brief Get the batch index and local index in segment for each block */
template <typename IdType>
__global__ void get_block_info(const IdType* num_block_prefixsum,
                               IdType* block_batch_id, IdType* local_block_id,
                               size_t batch_size, size_t num_blocks) {
  // One thread per block id: find the segment whose exclusive prefix-sum
  // range contains this block id, then derive the local tile index in it.
  const IdType bid = blockIdx.x * blockDim.x + threadIdx.x;
  if (bid >= num_blocks) return;
  IdType seg = 0;
  while (seg < batch_size && num_block_prefixsum[seg] <= bid) {
    ++seg;
  }
  --seg;
  block_batch_id[bid] = seg;
  local_block_id[bid] = bid - num_block_prefixsum[seg];
}
/*!
* \brief Brute force kNN. Compute distance for each pair of input points and get
* the result directly (without a distance matrix).
*
* \tparam FloatType The type of input points.
* \tparam IdType The type of id.
* \param data_points NDArray of dataset points.
* \param data_offsets offsets of point index in data points.
* \param query_points NDArray of query points
* \param query_offsets offsets of point index in query points.
* \param k the number of nearest points
* \param result output array
*/
template <typename FloatType, typename IdType>
void BruteForceKNNCuda(const NDArray& data_points, const IdArray& data_offsets,
const NDArray& query_points, const IdArray& query_offsets,
const int k, IdArray result) {
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
const auto& ctx = data_points->ctx;
auto device = runtime::DeviceAPI::Get(ctx);
const int64_t batch_size = data_offsets->shape[0] - 1;
const int64_t feature_size = data_points->shape[1];
const IdType* data_offsets_data = data_offsets.Ptr<IdType>();
const IdType* query_offsets_data = query_offsets.Ptr<IdType>();
const FloatType* data_points_data = data_points.Ptr<FloatType>();
const FloatType* query_points_data = query_points.Ptr<FloatType>();
IdType* query_out = result.Ptr<IdType>();
IdType* data_out = query_out + k * query_points->shape[0];
FloatType* dists = static_cast<FloatType*>(device->AllocWorkspace(
ctx, k * query_points->shape[0] * sizeof(FloatType)));
const int64_t block_size = cuda::FindNumThreads(query_points->shape[0]);
const int64_t num_blocks = (query_points->shape[0] - 1) / block_size + 1;
CUDA_KERNEL_CALL(bruteforce_knn_kernel, num_blocks, block_size, 0, thr_entry->stream,
data_points_data, data_offsets_data, query_points_data, query_offsets_data,
k, dists, query_out, data_out, batch_size, feature_size);
device->FreeWorkspace(ctx, dists);
}
/*!
* \brief Brute force kNN with shared memory.
* This function divides query points and data points into blocks. For each
* query block, it will make a loop over all data blocks and compute distances.
* It will be faster when the dimension of input points is not large.
*
* \tparam FloatType The type of input points.
* \tparam IdType The type of id.
* \param data_points NDArray of dataset points.
* \param data_offsets offsets of point index in data points.
* \param query_points NDArray of query points
* \param query_offsets offsets of point index in query points.
* \param k the number of nearest points
* \param result output array
*/
template <typename FloatType, typename IdType>
void BruteForceKNNSharedCuda(const NDArray& data_points, const IdArray& data_offsets,
                             const NDArray& query_points, const IdArray& query_offsets,
                             const int k, IdArray result) {
  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
  const auto& ctx = data_points->ctx;
  auto device = runtime::DeviceAPI::Get(ctx);
  const int64_t batch_size = data_offsets->shape[0] - 1;
  const int64_t feature_size = data_points->shape[1];
  const IdType* data_offsets_data = data_offsets.Ptr<IdType>();
  const IdType* query_offsets_data = query_offsets.Ptr<IdType>();
  const FloatType* data_points_data = data_points.Ptr<FloatType>();
  const FloatType* query_points_data = query_points.Ptr<FloatType>();
  // result holds [query ids | data ids], each of length k * num_queries.
  IdType* query_out = result.Ptr<IdType>();
  IdType* data_out = query_out + k * query_points->shape[0];

  // get max shared memory per block in bytes
  // determine block size according to this value
  int max_sharedmem_per_block = 0;
  CUDA_CALL(cudaDeviceGetAttribute(
    &max_sharedmem_per_block, cudaDevAttrMaxSharedMemoryPerBlock, ctx.device_id));
  // Per-thread shared footprint: one data point + one (transposed) query
  // point + k distances + k indices.
  const int64_t single_shared_mem = (k + 2 * feature_size) * sizeof(FloatType) +
    k * sizeof(IdType);
  const int64_t block_size = cuda::FindNumThreads(max_sharedmem_per_block / single_shared_mem);

  // Determine the number of blocks. We first get the number of blocks for each
  // segment. Then we get the block id offset via prefix sum.
  IdType* num_block_per_segment = static_cast<IdType*>(
    device->AllocWorkspace(ctx, batch_size * sizeof(IdType)));
  IdType* num_block_prefixsum = static_cast<IdType*>(
    device->AllocWorkspace(ctx, batch_size * sizeof(IdType)));

  // block size for get_num_block_per_segment computation
  int64_t temp_block_size = cuda::FindNumThreads(batch_size);
  int64_t temp_num_blocks = (batch_size - 1) / temp_block_size + 1;
  CUDA_KERNEL_CALL(get_num_block_per_segment, temp_num_blocks,
                   temp_block_size, 0, thr_entry->stream,
                   query_offsets_data, num_block_per_segment,
                   batch_size, block_size);
  // Two-phase cub scan: first call only queries temp storage size.
  size_t prefix_temp_size = 0;
  CUDA_CALL(cub::DeviceScan::ExclusiveSum(
    nullptr, prefix_temp_size, num_block_per_segment,
    num_block_prefixsum, batch_size));
  void* prefix_temp = device->AllocWorkspace(ctx, prefix_temp_size);
  CUDA_CALL(cub::DeviceScan::ExclusiveSum(
    prefix_temp, prefix_temp_size, num_block_per_segment,
    num_block_prefixsum, batch_size, thr_entry->stream));
  device->FreeWorkspace(ctx, prefix_temp);

  // total #blocks = exclusive prefix sum of the last segment + its own count.
  // BUGFIX: copy into IdType-sized temporaries; the previous code copied
  // sizeof(IdType) bytes into int64_t variables, which is only correct on
  // little-endian hosts when IdType happens to be 64-bit or non-negative.
  IdType last_prefix = 0, last_count = 0;
  const int64_t copyoffset = (batch_size - 1) * sizeof(IdType);
  device->CopyDataFromTo(
    num_block_prefixsum, copyoffset, &last_prefix, 0,
    sizeof(IdType), ctx, DLContext{kDLCPU, 0},
    query_offsets->dtype, thr_entry->stream);
  device->CopyDataFromTo(
    num_block_per_segment, copyoffset, &last_count, 0,
    sizeof(IdType), ctx, DLContext{kDLCPU, 0},
    query_offsets->dtype, thr_entry->stream);
  const int64_t num_blocks = static_cast<int64_t>(last_prefix) + last_count;

  // get batch id and local id in segment
  temp_block_size = cuda::FindNumThreads(num_blocks);
  temp_num_blocks = (num_blocks - 1) / temp_block_size + 1;
  IdType* block_batch_id = static_cast<IdType*>(device->AllocWorkspace(
    ctx, num_blocks * sizeof(IdType)));
  IdType* local_block_id = static_cast<IdType*>(device->AllocWorkspace(
    ctx, num_blocks * sizeof(IdType)));
  CUDA_KERNEL_CALL(
    get_block_info, temp_num_blocks, temp_block_size, 0,
    thr_entry->stream, num_block_prefixsum, block_batch_id,
    local_block_id, batch_size, num_blocks);

  FloatType* dists = static_cast<FloatType*>(device->AllocWorkspace(
    ctx, k * query_points->shape[0] * sizeof(FloatType)));
  CUDA_KERNEL_CALL(bruteforce_knn_share_kernel, num_blocks, block_size,
                   single_shared_mem * block_size, thr_entry->stream, data_points_data,
                   data_offsets_data, query_points_data, query_offsets_data,
                   block_batch_id, local_block_id, k, dists, query_out,
                   data_out, batch_size, feature_size);
  // BUGFIX: num_block_per_segment / num_block_prefixsum were previously freed
  // *before* the get_block_info launch above, which still reads
  // num_block_prefixsum on the stream (use-after-free). Free every workspace
  // buffer only after all kernels that use it have been issued.
  device->FreeWorkspace(ctx, dists);
  device->FreeWorkspace(ctx, local_block_id);
  device->FreeWorkspace(ctx, block_batch_id);
  device->FreeWorkspace(ctx, num_block_prefixsum);
  device->FreeWorkspace(ctx, num_block_per_segment);
}
} // namespace impl
template <DLDeviceType XPU, typename FloatType, typename IdType>
void KNN(const NDArray& data_points, const IdArray& data_offsets,
const NDArray& query_points, const IdArray& query_offsets,
const int k, IdArray result, const std::string& algorithm) {
if (algorithm == std::string("bruteforce")) {
impl::BruteForceKNNCuda<FloatType, IdType>(
data_points, data_offsets, query_points, query_offsets, k, result);
} else if (algorithm == std::string("bruteforce-sharemem")) {
impl::BruteForceKNNSharedCuda<FloatType, IdType>(
data_points, data_offsets, query_points, query_offsets, k, result);
} else {
LOG(FATAL) << "Algorithm " << algorithm << " is not supported on CUDA.";
}
}
// Explicit instantiations of the CUDA KNN entry point for every supported
// (float type, id type) combination.
template void KNN<kDLGPU, float, int32_t>(
  const NDArray& data_points, const IdArray& data_offsets,
  const NDArray& query_points, const IdArray& query_offsets,
  const int k, IdArray result, const std::string& algorithm);
template void KNN<kDLGPU, float, int64_t>(
  const NDArray& data_points, const IdArray& data_offsets,
  const NDArray& query_points, const IdArray& query_offsets,
  const int k, IdArray result, const std::string& algorithm);
template void KNN<kDLGPU, double, int32_t>(
  const NDArray& data_points, const IdArray& data_offsets,
  const NDArray& query_points, const IdArray& query_offsets,
  const int k, IdArray result, const std::string& algorithm);
template void KNN<kDLGPU, double, int64_t>(
  const NDArray& data_points, const IdArray& data_offsets,
  const NDArray& query_points, const IdArray& query_offsets,
  const int k, IdArray result, const std::string& algorithm);
} // namespace transform
} // namespace dgl
/*!
* Copyright (c) 2019 by Contributors
* \file graph/transform/knn.cc
 * \brief k-nearest-neighbor (KNN) interface
*/
#include <dgl/runtime/registry.h>
#include <dgl/runtime/packed_func.h>
#include <vector>
#include "kdtree_ndarray_adapter.h"
#include "knn.h"
#include "../../array/check.h"
using namespace dgl::runtime;
using namespace dgl::transform::knn_utils;
namespace dgl {
namespace transform {
namespace impl {
/*! \brief The kd-tree implementation of K-Nearest Neighbors */
template <typename FloatType, typename IdType>
void KdTreeKNN(const NDArray& data_points, const IdArray& data_offsets,
               const NDArray& query_points, const IdArray& query_offsets,
               const int k, IdArray result) {
  // For each batch segment: build a kd-tree over the segment's data points,
  // then query it (in parallel) for every query point of the same segment.
  int64_t batch_size = data_offsets->shape[0] - 1;
  int64_t feature_size = data_points->shape[1];
  const IdType* data_offsets_data = data_offsets.Ptr<IdType>();
  const IdType* query_offsets_data = query_offsets.Ptr<IdType>();
  const FloatType* query_points_data = query_points.Ptr<FloatType>();
  // result holds [query ids | data ids], each of length k * num_queries.
  IdType* query_out = result.Ptr<IdType>();
  IdType* data_out = query_out + k * query_points->shape[0];
  for (int64_t b = 0; b < batch_size; ++b) {
    auto d_offset = data_offsets_data[b];
    auto d_length = data_offsets_data[b + 1] - d_offset;
    auto q_offset = query_offsets_data[b];
    auto q_length = query_offsets_data[b + 1] - q_offset;
    auto out_offset = k * q_offset;
    // create view for each segment (no copy; offset is in bytes)
    const NDArray current_data_points = const_cast<NDArray*>(&data_points)->CreateView(
      {d_length, feature_size}, data_points->dtype, d_offset * feature_size * sizeof(FloatType));
    const FloatType* current_query_pts_data = query_points_data + q_offset * feature_size;
    KDTreeNDArrayAdapter<FloatType, IdType> kdtree(feature_size, current_data_points);
    // query: each OpenMP thread gets private copies of the result buffers
    std::vector<IdType> out_buffer(k);
    std::vector<FloatType> out_dist_buffer(k);
    #pragma omp parallel for firstprivate(out_buffer) firstprivate(out_dist_buffer)
    for (int64_t q = 0; q < q_length; ++q) {
      auto curr_out_offset = k * q + out_offset;
      const FloatType* q_point = current_query_pts_data + q * feature_size;
      size_t num_matches = kdtree.GetIndex()->knnSearch(
        q_point, k, out_buffer.data(), out_dist_buffer.data());
      // indices are segment-local; shift back to global point ids.
      // NOTE(review): if num_matches < k, the remaining output slots for this
      // query are left untouched — confirm callers guarantee k <= d_length.
      for (size_t i = 0; i < num_matches; ++i) {
        query_out[curr_out_offset] = q + q_offset;
        data_out[curr_out_offset] = out_buffer[i] + d_offset;
        curr_out_offset++;
      }
    }
  }
}
} // namespace impl
template <DLDeviceType XPU, typename FloatType, typename IdType>
void KNN(const NDArray& data_points, const IdArray& data_offsets,
const NDArray& query_points, const IdArray& query_offsets,
const int k, IdArray result, const std::string& algorithm) {
if (algorithm == std::string("kd-tree")) {
impl::KdTreeKNN<FloatType, IdType>(
data_points, data_offsets, query_points, query_offsets, k, result);
} else {
LOG(FATAL) << "Algorithm " << algorithm << " is not supported on CPU";
}
}
// Explicit instantiations of the CPU KNN entry point for every supported
// (float type, id type) combination.
template void KNN<kDLCPU, float, int32_t>(
  const NDArray& data_points, const IdArray& data_offsets,
  const NDArray& query_points, const IdArray& query_offsets,
  const int k, IdArray result, const std::string& algorithm);
template void KNN<kDLCPU, float, int64_t>(
  const NDArray& data_points, const IdArray& data_offsets,
  const NDArray& query_points, const IdArray& query_offsets,
  const int k, IdArray result, const std::string& algorithm);
template void KNN<kDLCPU, double, int32_t>(
  const NDArray& data_points, const IdArray& data_offsets,
  const NDArray& query_points, const IdArray& query_offsets,
  const int k, IdArray result, const std::string& algorithm);
template void KNN<kDLCPU, double, int64_t>(
  const NDArray& data_points, const IdArray& data_offsets,
  const NDArray& query_points, const IdArray& query_offsets,
  const int k, IdArray result, const std::string& algorithm);
DGL_REGISTER_GLOBAL("transform._CAPI_DGLKNN")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
......@@ -110,7 +30,7 @@ DGL_REGISTER_GLOBAL("transform._CAPI_DGLKNN")
data_points->ctx, {data_offsets, query_points, query_offsets, result},
{"data_offsets", "query_points", "query_offsets", "result"});
ATEN_XPU_SWITCH(data_points->ctx.device_type, XPU, "KNN", {
ATEN_XPU_SWITCH_CUDA(data_points->ctx.device_type, XPU, "KNN", {
ATEN_FLOAT_TYPE_SWITCH(data_points->dtype, FloatType, "data_points", {
ATEN_ID_TYPE_SWITCH(result->dtype, IdType, {
KNN<XPU, FloatType, IdType>(
......
......@@ -31,7 +31,7 @@ namespace transform {
template <DLDeviceType XPU, typename FloatType, typename IdType>
void KNN(const NDArray& data_points, const IdArray& data_offsets,
const NDArray& query_points, const IdArray& query_offsets,
const int k, IdArray result, const std::string & algorithm);
const int k, IdArray result, const std::string& algorithm);
} // namespace transform
} // namespace dgl
......
......@@ -4,6 +4,8 @@ import dgl
import numpy as np
import pytest
import torch as th
from dgl import DGLError
from dgl.base import DGLWarning
from dgl.geometry.pytorch import FarthestPointSampler
from dgl.geometry import neighbor_matching
from test_utils import parametrize_dtype
......@@ -25,31 +27,148 @@ def test_fps():
assert res.sum() > 0
@pytest.mark.parametrize('algorithm', ['topk', 'kd-tree'])
def test_knn(algorithm):
x = th.randn(8, 3)
@pytest.mark.parametrize('algorithm', ['bruteforce-blas', 'bruteforce', 'kd-tree'])
@pytest.mark.parametrize('dist', ['euclidean', 'cosine'])
def test_knn_cpu(algorithm, dist):
    """Check CPU kNN graph construction against a dense distance matrix.

    Covers 2d and 3d inputs, segmented kNN, k > num_points (warns),
    k == 0 (raises) and empty input (raises).

    BUGFIX: a stale pre-diff line ``def check_knn(g, x, start, end):`` was
    left inside this function by a bad merge, shadowing the real helper and
    breaking the test; it is removed here.
    """
    x = th.randn(8, 3).to(F.cpu())
    kg = dgl.nn.KNNGraph(3)
    if dist == 'euclidean':
        d = th.cdist(x, x).to(F.cpu())
    else:
        # shift points so cosine distance is non-degenerate, then compare
        # against 1 - normalized inner product
        x = x + th.randn(1).item()
        tmp_x = x / (1e-5 + F.sqrt(F.sum(x * x, dim=1, keepdims=True)))
        d = 1 - F.matmul(tmp_x, tmp_x.T).to(F.cpu())

    def check_knn(g, x, start, end, k):
        # every node's in-edges must come exactly from its k nearest neighbors
        assert g.device == x.device
        for v in range(start, end):
            src, _ = g.in_edges(v)
            src = set(src.numpy())
            i = v - start
            src_ans = set(th.topk(d[start:end, start:end][i], k, largest=False)[1].numpy() + start)
            assert src == src_ans

    # check knn with 2d input
    g = kg(x, algorithm, dist)
    check_knn(g, x, 0, 8, 3)

    # check knn with 3d input
    g = kg(x.view(2, 4, 3), algorithm, dist)
    check_knn(g, x, 0, 4, 3)
    check_knn(g, x, 4, 8, 3)

    # check segmented knn
    kg = dgl.nn.SegmentedKNNGraph(3)
    g = kg(x, [3, 5], algorithm, dist)
    check_knn(g, x, 0, 3, 3)
    check_knn(g, x, 3, 8, 3)

    # check k > num_points
    kg = dgl.nn.KNNGraph(10)
    with pytest.warns(DGLWarning):
        g = kg(x, algorithm, dist)
    check_knn(g, x, 0, 8, 8)
    with pytest.warns(DGLWarning):
        g = kg(x.view(2, 4, 3), algorithm, dist)
    check_knn(g, x, 0, 4, 4)
    check_knn(g, x, 4, 8, 4)
    kg = dgl.nn.SegmentedKNNGraph(5)
    with pytest.warns(DGLWarning):
        g = kg(x, [3, 5], algorithm, dist)
    check_knn(g, x, 0, 3, 3)
    check_knn(g, x, 3, 8, 3)

    # check k == 0
    kg = dgl.nn.KNNGraph(0)
    with pytest.raises(DGLError):
        g = kg(x, algorithm, dist)
    kg = dgl.nn.SegmentedKNNGraph(0)
    with pytest.raises(DGLError):
        g = kg(x, [3, 5], algorithm, dist)

    # check empty input
    x_empty = th.tensor([])
    kg = dgl.nn.KNNGraph(3)
    with pytest.raises(DGLError):
        g = kg(x_empty, algorithm, dist)
    kg = dgl.nn.SegmentedKNNGraph(3)
    with pytest.raises(DGLError):
        g = kg(x_empty, [3, 5], algorithm, dist)
@pytest.mark.parametrize('algorithm', ['bruteforce-blas', 'bruteforce', 'bruteforce-sharemem'])
@pytest.mark.parametrize('dist', ['euclidean', 'cosine'])
def test_knn_cuda(algorithm, dist):
    """Check CUDA kNN graph construction against a dense distance matrix.

    Mirrors test_knn_cpu but runs the graph construction on GPU (results are
    copied back to CPU for verification); skipped when CUDA is unavailable.

    BUGFIX: stale pre-diff lines (an old ``src_ans`` computation with a
    hard-coded k=3, old un-parametrized ``kg(...)`` calls and old segmented
    checks) were left interleaved with the new code by a bad merge; they are
    removed here.
    """
    if not th.cuda.is_available():
        return
    x = th.randn(8, 3).to(F.cuda())
    kg = dgl.nn.KNNGraph(3)
    if dist == 'euclidean':
        d = th.cdist(x, x).to(F.cpu())
    else:
        # shift points so cosine distance is non-degenerate, then compare
        # against 1 - normalized inner product
        x = x + th.randn(1).item()
        tmp_x = x / (1e-5 + F.sqrt(F.sum(x * x, dim=1, keepdims=True)))
        d = 1 - F.matmul(tmp_x, tmp_x.T).to(F.cpu())

    def check_knn(g, x, start, end, k):
        # graph must stay on the input device; verify edges on CPU
        assert g.device == x.device
        g = g.to(F.cpu())
        for v in range(start, end):
            src, _ = g.in_edges(v)
            src = set(src.numpy())
            i = v - start
            src_ans = set(th.topk(d[start:end, start:end][i], k, largest=False)[1].numpy() + start)
            assert src == src_ans

    # check knn with 2d input
    g = kg(x, algorithm, dist)
    check_knn(g, x, 0, 8, 3)

    # check knn with 3d input
    g = kg(x.view(2, 4, 3), algorithm, dist)
    check_knn(g, x, 0, 4, 3)
    check_knn(g, x, 4, 8, 3)

    # check segmented knn
    kg = dgl.nn.SegmentedKNNGraph(3)
    g = kg(x, [3, 5], algorithm, dist)
    check_knn(g, x, 0, 3, 3)
    check_knn(g, x, 3, 8, 3)

    # check k > num_points
    kg = dgl.nn.KNNGraph(10)
    with pytest.warns(DGLWarning):
        g = kg(x, algorithm, dist)
    check_knn(g, x, 0, 8, 8)
    with pytest.warns(DGLWarning):
        g = kg(x.view(2, 4, 3), algorithm, dist)
    check_knn(g, x, 0, 4, 4)
    check_knn(g, x, 4, 8, 4)
    kg = dgl.nn.SegmentedKNNGraph(5)
    with pytest.warns(DGLWarning):
        g = kg(x, [3, 5], algorithm, dist)
    check_knn(g, x, 0, 3, 3)
    check_knn(g, x, 3, 8, 3)

    # check k == 0
    kg = dgl.nn.KNNGraph(0)
    with pytest.raises(DGLError):
        g = kg(x, algorithm, dist)
    kg = dgl.nn.SegmentedKNNGraph(0)
    with pytest.raises(DGLError):
        g = kg(x, [3, 5], algorithm, dist)

    # check empty input
    x_empty = th.tensor([])
    kg = dgl.nn.KNNGraph(3)
    with pytest.raises(DGLError):
        g = kg(x_empty, algorithm, dist)
    kg = dgl.nn.SegmentedKNNGraph(3)
    with pytest.raises(DGLError):
        g = kg(x_empty, [3, 5], algorithm, dist)
@parametrize_dtype
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment