"vscode:/vscode.git/clone" did not exist on "b5511d2cb381435d7dfda27f3fd04b1ff377f974"
knn.cc 23 KB
Newer Older
1
/**
 *  Copyright (c) 2019 by Contributors
 * @file graph/transform/cpu/knn.cc
 * @brief k-nearest-neighbor (KNN) implementation
 */

#include "../knn.h"

#include <dgl/random.h>
#include <dgl/runtime/device_api.h>
#include <dgl/runtime/parallel_for.h>
#include <dmlc/omp.h>

#include <algorithm>
#include <limits>
#include <tuple>
#include <vector>

#include "kdtree_ndarray_adapter.h"

using namespace dgl::runtime;
using namespace dgl::transform::knn_utils;
namespace dgl {
namespace transform {
namespace impl {

// This value is directly from pynndescent
static constexpr int NN_DESCENT_BLOCK_SIZE = 16384;

/**
 * @brief Compute the (squared) Euclidean distance between two vectors; return
 *  a positive infinite value if the partial distance already exceeds the
 *  given worst distance.
 */
template <typename FloatType, typename IdType>
FloatType EuclideanDistWithCheck(
    const FloatType* vec1, const FloatType* vec2, int64_t dim,
    FloatType worst_dist = std::numeric_limits<FloatType>::max()) {
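  // Accumulate the squared difference term by term; once the running sum
  // exceeds `worst_dist` the exact value no longer matters, so stop early and
  // report the candidate as "too far" via max().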
  FloatType dist = 0;
  bool early_stop = false;

  for (IdType idx = 0; idx < dim; ++idx) {
    dist += (vec1[idx] - vec2[idx]) * (vec1[idx] - vec2[idx]);
    if (dist > worst_dist) {
      early_stop = true;
      break;
    }
  }

  if (early_stop) {
    return std::numeric_limits<FloatType>::max();
  } else {
    return dist;
  }
}

/** @brief Compute the (squared) Euclidean distance between two vectors */
template <typename FloatType, typename IdType>
FloatType EuclideanDist(
    const FloatType* vec1, const FloatType* vec2, int64_t dim) {
  FloatType dist = 0;

  for (IdType idx = 0; idx < dim; ++idx) {
    dist += (vec1[idx] - vec2[idx]) * (vec1[idx] - vec2[idx]);
  }

  return dist;
}

/** @brief Insert a new element into a heap */
template <typename FloatType, typename IdType>
void HeapInsert(
    IdType* out, FloatType* dist, IdType new_id, FloatType new_dist, int k,
    bool check_repeat = false) {
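  // `dist` is laid out as a max-heap of size k whose root dist[0] holds the
  // current worst (largest) distance, so a candidate at least that far away
  // can be rejected without touching the heap.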
  if (new_dist > dist[0]) return;

  // skip the candidate if it is already in the heap
  if (check_repeat) {
    for (IdType i = 0; i < k; ++i) {
      if (out[i] == new_id) return;
    }
  }

  IdType left_idx = 0, right_idx = 0, curr_idx = 0, swap_idx = 0;
  dist[0] = new_dist;
  out[0] = new_id;
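  // Sift the new root entry down until the max-heap property is restored.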
  while (true) {
    left_idx = 2 * curr_idx + 1;
    right_idx = left_idx + 1;
    swap_idx = curr_idx;
    if (left_idx < k && dist[left_idx] > dist[swap_idx]) {
      swap_idx = left_idx;
    }
    if (right_idx < k && dist[right_idx] > dist[swap_idx]) {
      swap_idx = right_idx;
    }
    if (swap_idx != curr_idx) {
      std::swap(dist[curr_idx], dist[swap_idx]);
      std::swap(out[curr_idx], out[swap_idx]);
      curr_idx = swap_idx;
    } else {
      break;
    }
  }
}
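
// A minimal usage sketch (illustrative only, not part of this file): keep the
// k smallest distances seen so far. `n`, `cand_ids` and `cand_dists` are
// hypothetical candidate arrays, not names defined in DGL.
//
//   std::vector<int64_t> knn_ids(k, -1);
//   std::vector<float> knn_dists(k, std::numeric_limits<float>::max());
//   for (int64_t i = 0; i < n; ++i)
//     HeapInsert<float, int64_t>(
//         knn_ids.data(), knn_dists.data(), cand_ids[i], cand_dists[i], k);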

/** @brief Insert a new element and its flag into a heap; return 1 if the
 * insertion succeeds, 0 otherwise. */
template <typename FloatType, typename IdType>
int FlaggedHeapInsert(
    IdType* out, FloatType* dist, bool* flag, IdType new_id, FloatType new_dist,
    bool new_flag, int k, bool check_repeat = false) {
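  // Same sift-down max-heap insertion as HeapInsert, except that a per-entry
  // `flag` (used by NN-descent to mark "new" neighbors) travels with every
  // swap and the return value reports whether the candidate was kept.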
  if (new_dist > dist[0]) return 0;

  if (check_repeat) {
    for (IdType i = 0; i < k; ++i) {
      if (out[i] == new_id) return 0;
    }
  }

  IdType left_idx = 0, right_idx = 0, curr_idx = 0, swap_idx = 0;
  dist[0] = new_dist;
  out[0] = new_id;
  flag[0] = new_flag;
  while (true) {
    left_idx = 2 * curr_idx + 1;
    right_idx = left_idx + 1;
    swap_idx = curr_idx;
    if (left_idx < k && dist[left_idx] > dist[swap_idx]) {
      swap_idx = left_idx;
    }
    if (right_idx < k && dist[right_idx] > dist[swap_idx]) {
      swap_idx = right_idx;
    }
    if (swap_idx != curr_idx) {
      std::swap(dist[curr_idx], dist[swap_idx]);
      std::swap(out[curr_idx], out[swap_idx]);
      std::swap(flag[curr_idx], flag[swap_idx]);
      curr_idx = swap_idx;
    } else {
      break;
    }
  }
  return 1;
}

/** @brief Build a max-heap of neighbor distances for a point. Used by NN-descent */
template <typename FloatType, typename IdType>
void BuildHeap(IdType* index, FloatType* dist, int k) {
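  // Bottom-up heapify: sift every internal node down, starting from the last
  // one, so that dist[0] ends up holding the largest (worst) distance.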
  for (int i = k / 2 - 1; i >= 0; --i) {
    IdType idx = i;
    while (true) {
      IdType largest = idx;
      IdType left = idx * 2 + 1;
      IdType right = left + 1;
      if (left < k && dist[left] > dist[largest]) {
        largest = left;
      }
      if (right < k && dist[right] > dist[largest]) {
        largest = right;
      }
      if (largest != idx) {
        std::swap(index[largest], index[idx]);
        std::swap(dist[largest], dist[idx]);
        idx = largest;
      } else {
        break;
      }
    }
  }
}

/**
 * @brief Neighbor update process in NN-descent. The distance between two
 *  points is computed; if it is smaller than the worst distance of either
 *  point, that point's neighborhood is updated.
 */
template <typename FloatType, typename IdType>
int UpdateNeighbors(
    IdType* neighbors, FloatType* dists, const FloatType* points, bool* flags,
    IdType c1, IdType c2, IdType point_start, int64_t feature_size, int k) {
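  // `neighbors` is indexed by global point id, while `dists` and `flags` are
  // indexed relative to the current segment, hence the *_local offsets below.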
  IdType c1_local = c1 - point_start, c2_local = c2 - point_start;
  FloatType worst_c1_dist = dists[c1_local * k];
  FloatType worst_c2_dist = dists[c2_local * k];
  FloatType new_dist = EuclideanDistWithCheck<FloatType, IdType>(
      points + c1 * feature_size, points + c2 * feature_size, feature_size,
      std::max(worst_c1_dist, worst_c2_dist));

  int num_updates = 0;
  if (new_dist < worst_c1_dist) {
    ++num_updates;
#pragma omp critical
    {
      FlaggedHeapInsert<FloatType, IdType>(
          neighbors + c1 * k, dists + c1_local * k, flags + c1_local * k, c2,
          new_dist, true, k, true);
    }
  }
  if (new_dist < worst_c2_dist) {
    ++num_updates;
#pragma omp critical
    {
      FlaggedHeapInsert<FloatType, IdType>(
          neighbors + c2 * k, dists + c2_local * k, flags + c2_local * k, c1,
          new_dist, true, k, true);
    }
  }
  return num_updates;
}

/** @brief The kd-tree implementation of K-Nearest Neighbors */
template <typename FloatType, typename IdType>
void KdTreeKNN(
    const NDArray& data_points, const IdArray& data_offsets,
    const NDArray& query_points, const IdArray& query_offsets, const int k,
    IdArray result) {
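  // `result` stores two concatenated id arrays of length k * num_queries: the
  // first half receives the query ids and the second half the matched
  // data-point ids, so (query_out[i], data_out[i]) describes one knn edge.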
  const int64_t batch_size = data_offsets->shape[0] - 1;
  const int64_t feature_size = data_points->shape[1];
  const IdType* data_offsets_data = data_offsets.Ptr<IdType>();
  const IdType* query_offsets_data = query_offsets.Ptr<IdType>();
  const FloatType* query_points_data = query_points.Ptr<FloatType>();
  IdType* query_out = result.Ptr<IdType>();
  IdType* data_out = query_out + k * query_points->shape[0];

  for (int64_t b = 0; b < batch_size; ++b) {
    auto d_offset = data_offsets_data[b];
    auto d_length = data_offsets_data[b + 1] - d_offset;
    auto q_offset = query_offsets_data[b];
    auto q_length = query_offsets_data[b + 1] - q_offset;
    auto out_offset = k * q_offset;

    // create view for each segment
    const NDArray current_data_points =
        const_cast<NDArray*>(&data_points)
            ->CreateView(
                {d_length, feature_size}, data_points->dtype,
                d_offset * feature_size * sizeof(FloatType));
    const FloatType* current_query_pts_data =
        query_points_data + q_offset * feature_size;

    KDTreeNDArrayAdapter<FloatType, IdType> kdtree(
        feature_size, current_data_points);

    // query
    parallel_for(0, q_length, [&](IdType b, IdType e) {
      for (auto q = b; q < e; ++q) {
        std::vector<IdType> out_buffer(k);
        std::vector<FloatType> out_dist_buffer(k);

        auto curr_out_offset = k * q + out_offset;
        const FloatType* q_point = current_query_pts_data + q * feature_size;
        size_t num_matches = kdtree.GetIndex()->knnSearch(
            q_point, k, out_buffer.data(), out_dist_buffer.data());

        for (size_t i = 0; i < num_matches; ++i) {
          query_out[curr_out_offset] = q + q_offset;
          data_out[curr_out_offset] = out_buffer[i] + d_offset;
          curr_out_offset++;
        }
      }
    });
  }
}

template <typename FloatType, typename IdType>
void BruteForceKNN(
    const NDArray& data_points, const IdArray& data_offsets,
    const NDArray& query_points, const IdArray& query_offsets, const int k,
    IdArray result) {
  const int64_t batch_size = data_offsets->shape[0] - 1;
  const int64_t feature_size = data_points->shape[1];
  const IdType* data_offsets_data = data_offsets.Ptr<IdType>();
  const IdType* query_offsets_data = query_offsets.Ptr<IdType>();
  const FloatType* data_points_data = data_points.Ptr<FloatType>();
  const FloatType* query_points_data = query_points.Ptr<FloatType>();
  IdType* query_out = result.Ptr<IdType>();
  IdType* data_out = query_out + k * query_points->shape[0];

  for (int64_t b = 0; b < batch_size; ++b) {
    IdType d_start = data_offsets_data[b], d_end = data_offsets_data[b + 1];
    IdType q_start = query_offsets_data[b], q_end = query_offsets_data[b + 1];

    parallel_for(q_start, q_end, [&](IdType b, IdType e) {
      for (auto q_idx = b; q_idx < e; ++q_idx) {
        std::vector<FloatType> dist_buffer(k);
        for (IdType k_idx = 0; k_idx < k; ++k_idx) {
          query_out[q_idx * k + k_idx] = q_idx;
          dist_buffer[k_idx] = std::numeric_limits<FloatType>::max();
        }
        FloatType worst_dist = std::numeric_limits<FloatType>::max();
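        // Scan every data point in this segment; `worst_dist` (the heap root)
        // shrinks as closer neighbors are found, which lets the distance
        // computation bail out early on clearly-too-far candidates.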

        for (IdType d_idx = d_start; d_idx < d_end; ++d_idx) {
          FloatType tmp_dist = EuclideanDistWithCheck<FloatType, IdType>(
              query_points_data + q_idx * feature_size,
              data_points_data + d_idx * feature_size, feature_size,
              worst_dist);

          if (tmp_dist == std::numeric_limits<FloatType>::max()) {
            continue;
          }

          IdType out_offset = q_idx * k;
          HeapInsert<FloatType, IdType>(
              data_out + out_offset, dist_buffer.data(), d_idx, tmp_dist, k);
          worst_dist = dist_buffer[0];
        }
      }
    });
  }
}
}  // namespace impl

template <DGLDeviceType XPU, typename FloatType, typename IdType>
void KNN(
    const NDArray& data_points, const IdArray& data_offsets,
    const NDArray& query_points, const IdArray& query_offsets, const int k,
    IdArray result, const std::string& algorithm) {
  if (algorithm == std::string("kd-tree")) {
    impl::KdTreeKNN<FloatType, IdType>(
        data_points, data_offsets, query_points, query_offsets, k, result);
  } else if (algorithm == std::string("bruteforce")) {
    impl::BruteForceKNN<FloatType, IdType>(
        data_points, data_offsets, query_points, query_offsets, k, result);
  } else {
    LOG(FATAL) << "Algorithm " << algorithm << " is not supported on CPU";
  }
}

template <DGLDeviceType XPU, typename FloatType, typename IdType>
void NNDescent(
    const NDArray& points, const IdArray& offsets, IdArray result, const int k,
    const int num_iters, const int num_candidates, const double delta) {
  using nnd_updates_t =
      std::vector<std::vector<std::tuple<IdType, IdType, FloatType>>>;
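
  // NN-descent iteratively refines a randomly initialized k-NN graph: each
  // point samples "new" and "old" candidate neighbors, distances between
  // candidates are evaluated, and closer pairs replace a point's current
  // worst neighbor. A segment stops early once an iteration changes no more
  // than delta * k * segment_size entries.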
  const auto& ctx = points->ctx;
  auto device = runtime::DeviceAPI::Get(ctx);
  const int64_t num_nodes = points->shape[0];
  const int64_t batch_size = offsets->shape[0] - 1;
  const int64_t feature_size = points->shape[1];
  const IdType* offsets_data = offsets.Ptr<IdType>();
  const FloatType* points_data = points.Ptr<FloatType>();

  IdType* central_nodes = result.Ptr<IdType>();
  IdType* neighbors = central_nodes + k * num_nodes;
  int64_t max_segment_size = 0;

  // find max segment
  for (IdType b = 0; b < batch_size; ++b) {
    if (max_segment_size < offsets_data[b + 1] - offsets_data[b])
      max_segment_size = offsets_data[b + 1] - offsets_data[b];
  }

  // allocate workspace for the candidate pools, candidate/neighbor distances,
  // and the per-neighbor "new" flags
  IdType* new_candidates = static_cast<IdType*>(device->AllocWorkspace(
      ctx, max_segment_size * num_candidates * sizeof(IdType)));
  IdType* old_candidates = static_cast<IdType*>(device->AllocWorkspace(
      ctx, max_segment_size * num_candidates * sizeof(IdType)));
  FloatType* new_candidates_dists =
      static_cast<FloatType*>(device->AllocWorkspace(
          ctx, max_segment_size * num_candidates * sizeof(FloatType)));
  FloatType* old_candidates_dists =
      static_cast<FloatType*>(device->AllocWorkspace(
          ctx, max_segment_size * num_candidates * sizeof(FloatType)));
  FloatType* neighbors_dists = static_cast<FloatType*>(
      device->AllocWorkspace(ctx, max_segment_size * k * sizeof(FloatType)));
  bool* flags = static_cast<bool*>(
      device->AllocWorkspace(ctx, max_segment_size * k * sizeof(bool)));

  for (IdType b = 0; b < batch_size; ++b) {
    IdType point_idx_start = offsets_data[b],
           point_idx_end = offsets_data[b + 1];
    IdType segment_size = point_idx_end - point_idx_start;

    // random initialization
    runtime::parallel_for(
        point_idx_start, point_idx_end, [&](size_t b, size_t e) {
          for (auto i = b; i < e; ++i) {
            IdType local_idx = i - point_idx_start;

            dgl::RandomEngine::ThreadLocal()->UniformChoice<IdType>(
                k, segment_size, neighbors + i * k, false);

            for (IdType n = 0; n < k; ++n) {
              central_nodes[i * k + n] = i;
              neighbors[i * k + n] += point_idx_start;
              flags[local_idx * k + n] = true;
              neighbors_dists[local_idx * k + n] =
                  impl::EuclideanDist<FloatType, IdType>(
                      points_data + i * feature_size,
                      points_data + neighbors[i * k + n] * feature_size,
                      feature_size);
            }
            impl::BuildHeap<FloatType, IdType>(
                neighbors + i * k, neighbors_dists + local_idx * k, k);
          }
        });

    size_t num_updates = 0;
    for (int iter = 0; iter < num_iters; ++iter) {
      num_updates = 0;

      // initialize candidate arrays with sentinel values (id = num_nodes,
      // dist = max)
      runtime::parallel_for(
          point_idx_start, point_idx_end, [&](size_t b, size_t e) {
            for (auto i = b; i < e; ++i) {
              IdType local_idx = i - point_idx_start;
              for (IdType c = 0; c < num_candidates; ++c) {
                new_candidates[local_idx * num_candidates + c] = num_nodes;
                old_candidates[local_idx * num_candidates + c] = num_nodes;
                new_candidates_dists[local_idx * num_candidates + c] =
                    std::numeric_limits<FloatType>::max();
                old_candidates_dists[local_idx * num_candidates + c] =
                    std::numeric_limits<FloatType>::max();
              }
            }
          });

      // randomly select neighbors as candidates
      int num_threads = omp_get_max_threads();
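      // Every thread scans all points, but a heap at local index j is only
      // written by the thread with tid == j % num_threads, so two threads
      // never touch the same candidate heap and no locking is needed here.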
      runtime::parallel_for(0, num_threads, [&](IdType b, IdType e) {
        for (auto tid = b; tid < e; ++tid) {
          for (IdType i = point_idx_start; i < point_idx_end; ++i) {
            IdType local_idx = i - point_idx_start;
            for (IdType n = 0; n < k; ++n) {
              IdType neighbor_idx = neighbors[i * k + n];
              bool is_new = flags[local_idx * k + n];
              IdType local_neighbor_idx = neighbor_idx - point_idx_start;
              FloatType random_dist =
                  dgl::RandomEngine::ThreadLocal()->Uniform<FloatType>();

              if (is_new) {
                if (local_idx % num_threads == tid) {
                  impl::HeapInsert<FloatType, IdType>(
                      new_candidates + local_idx * num_candidates,
                      new_candidates_dists + local_idx * num_candidates,
                      neighbor_idx, random_dist, num_candidates, true);
                }
                if (local_neighbor_idx % num_threads == tid) {
                  impl::HeapInsert<FloatType, IdType>(
                      new_candidates + local_neighbor_idx * num_candidates,
                      new_candidates_dists +
                          local_neighbor_idx * num_candidates,
                      i, random_dist, num_candidates, true);
                }
              } else {
                if (local_idx % num_threads == tid) {
                  impl::HeapInsert<FloatType, IdType>(
                      old_candidates + local_idx * num_candidates,
                      old_candidates_dists + local_idx * num_candidates,
                      neighbor_idx, random_dist, num_candidates, true);
                }
                if (local_neighbor_idx % num_threads == tid) {
                  impl::HeapInsert<FloatType, IdType>(
                      old_candidates + local_neighbor_idx * num_candidates,
                      old_candidates_dists +
                          local_neighbor_idx * num_candidates,
                      i, random_dist, num_candidates, true);
                }
              }
            }
          }
        }
      });

      // clear the "new" flag of neighbors that were sampled into new_candidates
      runtime::parallel_for(
          point_idx_start, point_idx_end, [&](size_t b, size_t e) {
            for (auto i = b; i < e; ++i) {
              IdType local_idx = i - point_idx_start;
              for (IdType n = 0; n < k; ++n) {
                IdType n_idx = neighbors[i * k + n];

                for (IdType c = 0; c < num_candidates; ++c) {
                  if (new_candidates[local_idx * num_candidates + c] == n_idx) {
                    flags[local_idx * k + n] = false;
                    break;
                  }
                }
              }
            }
          });

      // update neighbors block by block
      for (IdType block_start = point_idx_start; block_start < point_idx_end;
           block_start += impl::NN_DESCENT_BLOCK_SIZE) {
        IdType block_end =
            std::min(point_idx_end, block_start + impl::NN_DESCENT_BLOCK_SIZE);
        IdType block_size = block_end - block_start;
        nnd_updates_t updates(block_size);
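        // Each point in the block first collects its (p1, p2, dist) update
        // candidates privately; they are applied to the shared neighbor heaps
        // in a second, thread-partitioned pass further below.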

        // generate updates
        runtime::parallel_for(block_start, block_end, [&](size_t b, size_t e) {
          for (auto i = b; i < e; ++i) {
            IdType local_idx = i - point_idx_start;

            for (IdType c1 = 0; c1 < num_candidates; ++c1) {
              IdType new_c1 = new_candidates[local_idx * num_candidates + c1];
              if (new_c1 == num_nodes) continue;
              IdType c1_local = new_c1 - point_idx_start;

              // new-new
              for (IdType c2 = c1; c2 < num_candidates; ++c2) {
                IdType new_c2 = new_candidates[local_idx * num_candidates + c2];
                if (new_c2 == num_nodes) continue;
                IdType c2_local = new_c2 - point_idx_start;

                FloatType worst_c1_dist = neighbors_dists[c1_local * k];
                FloatType worst_c2_dist = neighbors_dists[c2_local * k];
                FloatType new_dist =
                    impl::EuclideanDistWithCheck<FloatType, IdType>(
                        points_data + new_c1 * feature_size,
                        points_data + new_c2 * feature_size, feature_size,
                        std::max(worst_c1_dist, worst_c2_dist));

                if (new_dist < worst_c1_dist || new_dist < worst_c2_dist) {
                  updates[i - block_start].push_back(
                      std::make_tuple(new_c1, new_c2, new_dist));
                }
              }

              // new-old
              for (IdType c2 = 0; c2 < num_candidates; ++c2) {
                IdType old_c2 = old_candidates[local_idx * num_candidates + c2];
                if (old_c2 == num_nodes) continue;
                IdType c2_local = old_c2 - point_idx_start;

                FloatType worst_c1_dist = neighbors_dists[c1_local * k];
                FloatType worst_c2_dist = neighbors_dists[c2_local * k];
                FloatType new_dist =
                    impl::EuclideanDistWithCheck<FloatType, IdType>(
                        points_data + new_c1 * feature_size,
                        points_data + old_c2 * feature_size, feature_size,
                        std::max(worst_c1_dist, worst_c2_dist));

                if (new_dist < worst_c1_dist || new_dist < worst_c2_dist) {
                  updates[i - block_start].push_back(
                      std::make_tuple(new_c1, old_c2, new_dist));
                }
              }
            }
          }
        });

        int tid;
#pragma omp parallel private(tid, num_threads) reduction(+ : num_updates)
        {
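          // Apply the collected updates. As in the sampling phase, a point p
          // is only touched by the thread with tid == p % num_threads, and the
          // per-thread update counts are combined via the reduction clause.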
          tid = omp_get_thread_num();
          num_threads = omp_get_num_threads();
          for (IdType i = 0; i < block_size; ++i) {
            for (const auto& u : updates[i]) {
              IdType p1, p2;
              FloatType d;
              std::tie(p1, p2, d) = u;
              IdType p1_local = p1 - point_idx_start;
              IdType p2_local = p2 - point_idx_start;

              if (p1 % num_threads == tid) {
                num_updates += impl::FlaggedHeapInsert<FloatType, IdType>(
                    neighbors + p1 * k, neighbors_dists + p1_local * k,
                    flags + p1_local * k, p2, d, true, k, true);
              }
              if (p2 % num_threads == tid) {
                num_updates += impl::FlaggedHeapInsert<FloatType, IdType>(
                    neighbors + p2 * k, neighbors_dists + p2_local * k,
                    flags + p2_local * k, p1, d, true, k, true);
              }
            }
          }
        }
      }

      // early abort
      if (num_updates <= static_cast<size_t>(delta * k * segment_size)) {
        break;
      }
    }
  }

  device->FreeWorkspace(ctx, new_candidates);
  device->FreeWorkspace(ctx, old_candidates);
  device->FreeWorkspace(ctx, new_candidates_dists);
  device->FreeWorkspace(ctx, old_candidates_dists);
  device->FreeWorkspace(ctx, neighbors_dists);
  device->FreeWorkspace(ctx, flags);
}

template void KNN<kDGLCPU, float, int32_t>(
    const NDArray& data_points, const IdArray& data_offsets,
    const NDArray& query_points, const IdArray& query_offsets, const int k,
    IdArray result, const std::string& algorithm);
template void KNN<kDGLCPU, float, int64_t>(
    const NDArray& data_points, const IdArray& data_offsets,
    const NDArray& query_points, const IdArray& query_offsets, const int k,
    IdArray result, const std::string& algorithm);
template void KNN<kDGLCPU, double, int32_t>(
    const NDArray& data_points, const IdArray& data_offsets,
    const NDArray& query_points, const IdArray& query_offsets, const int k,
    IdArray result, const std::string& algorithm);
template void KNN<kDGLCPU, double, int64_t>(
    const NDArray& data_points, const IdArray& data_offsets,
    const NDArray& query_points, const IdArray& query_offsets, const int k,
    IdArray result, const std::string& algorithm);

template void NNDescent<kDGLCPU, float, int32_t>(
    const NDArray& points, const IdArray& offsets, IdArray result, const int k,
    const int num_iters, const int num_candidates, const double delta);
template void NNDescent<kDGLCPU, float, int64_t>(
    const NDArray& points, const IdArray& offsets, IdArray result, const int k,
    const int num_iters, const int num_candidates, const double delta);
template void NNDescent<kDGLCPU, double, int32_t>(
    const NDArray& points, const IdArray& offsets, IdArray result, const int k,
    const int num_iters, const int num_candidates, const double delta);
template void NNDescent<kDGLCPU, double, int64_t>(
    const NDArray& points, const IdArray& offsets, IdArray result, const int k,
    const int num_iters, const int num_candidates, const double delta);
}  // namespace transform
}  // namespace dgl