"examples/pytorch/eeg-gcnn/deep_EEGGraphConvNet.py" did not exist on "cd2cf606db954b0ce775833bacdb2994608fc7c1"
knn.cu 34.9 KB
Newer Older
1
/**
 *  Copyright (c) 2020 by Contributors
 * @file graph/transform/cuda/knn.cu
 * @brief k-nearest-neighbor (KNN) implementation (cuda)
 */

7
#include <curand_kernel.h>
8
#include <dgl/array.h>
9
#include <dgl/random.h>
10
#include <dgl/runtime/device_api.h>
11

12
#include <algorithm>
13
#include <limits>
14
15
#include <string>
#include <vector>
16

17
18
#include "../../../array/cuda/dgl_cub.cuh"
#include "../../../array/cuda/utils.h"
19
#include "../../../runtime/cuda/cuda_common.h"
20
21
22
23
24
#include "../knn.h"

namespace dgl {
namespace transform {
namespace impl {
25
/**
 * @brief Typed accessor for the kernel's dynamic (extern, unsized) shared
 *  memory buffer.
 *
 * The extern array is declared once with a fixed element type (int) and then
 * reinterpreted as Type. This avoids the linker errors caused by declaring
 * extern unsized shared arrays with a templated element type.
 */
template <typename Type>
struct SharedMemory {
  /** @brief Mutable view of dynamic shared memory as an array of Type. */
  __device__ inline operator Type*() {
    extern __shared__ int __smem[];
    Type* const view = reinterpret_cast<Type*>(__smem);
    return view;
  }

  /** @brief Read-only view of dynamic shared memory as an array of Type. */
  __device__ inline operator const Type*() const {
    extern __shared__ int __smem[];
    const Type* const view = reinterpret_cast<Type*>(__smem);
    return view;
  }
};

/**
 * @brief SharedMemory specialization for double.
 *
 * The extern buffer is declared with double elements instead of int, which
 * avoids unaligned memory access compile errors for double.
 */
template <>
struct SharedMemory<double> {
  /** @brief Mutable view of dynamic shared memory as an array of double. */
  __device__ inline operator double*() {
    extern __shared__ double __smem_d[];
    double* const view = reinterpret_cast<double*>(__smem_d);
    return view;
  }

  /** @brief Read-only view of dynamic shared memory as an array of double. */
  __device__ inline operator const double*() const {
    extern __shared__ double __smem_d[];
    const double* const view = reinterpret_cast<double*>(__smem_d);
    return view;
  }
};

/**
 * @brief Compute the squared Euclidean distance between two vectors in a cuda
 *  kernel (no square root is taken).
 *
 * Coordinates are accumulated four at a time, then a scalar loop handles the
 * remaining (fewer than four) coordinates.
 *
 * @param vec1 First vector, dim elements.
 * @param vec2 Second vector, dim elements.
 * @param dim Number of coordinates.
 * @return Sum over i of (vec1[i] - vec2[i])^2.
 */
template <typename FloatType, typename IdType>
__device__ FloatType
EuclideanDist(const FloatType* vec1, const FloatType* vec2, const int64_t dim) {
  FloatType acc = 0;
  IdType d = 0;

  // Four coordinates per iteration.
  while (d < dim - 3) {
    const FloatType e0 = vec1[d] - vec2[d];
    const FloatType e1 = vec1[d + 1] - vec2[d + 1];
    const FloatType e2 = vec1[d + 2] - vec2[d + 2];
    const FloatType e3 = vec1[d + 3] - vec2[d + 3];
    acc += e0 * e0 + e1 * e1 + e2 * e2 + e3 * e3;
    d += 4;
  }

  // Tail: whatever coordinates are left over.
  while (d < dim) {
    const FloatType e = vec1[d] - vec2[d];
    acc += e * e;
    ++d;
  }

  return acc;
}

/**
 * @brief Compute the squared Euclidean distance between two vectors in a cuda
 *  kernel, abandoning the computation early once the running sum is greater
 *  than worst_dist.
 *
 * @param vec1 First vector, dim elements.
 * @param vec2 Second vector, dim elements.
 * @param dim Number of coordinates.
 * @param worst_dist Pruning bound; the partial sum is compared against it after
 *  every group of four coordinates and after every tail coordinate.
 * @return The squared distance, or std::numeric_limits<FloatType>::max() if
 *  the partial sum ever exceeded worst_dist.
 */
template <typename FloatType, typename IdType>
__device__ FloatType EuclideanDistWithCheck(
    const FloatType* vec1, const FloatType* vec2, const int64_t dim,
    const FloatType worst_dist) {
  const FloatType kAbandoned = std::numeric_limits<FloatType>::max();
  FloatType acc = 0;
  IdType d = 0;

  // Four coordinates per iteration, checking the bound after each group.
  while (d < dim - 3) {
    const FloatType e0 = vec1[d] - vec2[d];
    const FloatType e1 = vec1[d + 1] - vec2[d + 1];
    const FloatType e2 = vec1[d + 2] - vec2[d + 2];
    const FloatType e3 = vec1[d + 3] - vec2[d + 3];
    acc += e0 * e0 + e1 * e1 + e2 * e2 + e3 * e3;
    if (acc > worst_dist) return kAbandoned;
    d += 4;
  }

  // Tail coordinates, checking the bound after each one.
  while (d < dim) {
    const FloatType e = vec1[d] - vec2[d];
    acc += e * e;
    if (acc > worst_dist) return kAbandoned;
    ++d;
  }

  return acc;
}

/**
 * @brief Rearrange `dists` into a binary max-heap (largest distance at
 *  dists[0]), permuting `indices` in lockstep.
 *
 * @param indices Payload permuted together with dists.
 * @param dists Keys of the heap.
 * @param size Number of entries.
 */
template <typename FloatType, typename IdType>
__device__ void BuildHeap(IdType* indices, FloatType* dists, int size) {
  // Sift down every internal node, from the last one back to the root.
  for (int root = size / 2 - 1; root >= 0; --root) {
    IdType node = root;
    for (;;) {
      const IdType lchild = node * 2 + 1;
      const IdType rchild = lchild + 1;
      IdType top = node;
      if (lchild < size && dists[lchild] > dists[top]) top = lchild;
      if (rchild < size && dists[rchild] > dists[top]) top = rchild;
      if (top == node) break;

      const IdType idx_tmp = indices[top];
      indices[top] = indices[node];
      indices[node] = idx_tmp;

      const FloatType dist_tmp = dists[top];
      dists[top] = dists[node];
      dists[node] = dist_tmp;

      node = top;
    }
  }
}

/**
 * @brief Offer a candidate (new_idx, new_dist) to a max-heap whose root
 *  dist[0] is the current worst (largest) distance.
 *
 * If new_dist is greater than dist[0], the heap is left unchanged. Otherwise
 * (ties included) the root is overwritten by the candidate and sifted down.
 *
 * @param indices Heap payload, permuted together with dist.
 * @param dist Heap keys.
 * @param new_idx Candidate index.
 * @param new_dist Candidate distance.
 * @param size Number of heap entries.
 * @param check_repeat If true, the candidate is rejected when new_idx is
 *  already present in indices (linear scan).
 */
template <typename FloatType, typename IdType>
__device__ void HeapInsert(
    IdType* indices, FloatType* dist, IdType new_idx, FloatType new_dist,
    int size, bool check_repeat = false) {
  if (new_dist > dist[0]) return;

  if (check_repeat) {
    for (IdType pos = 0; pos < size; ++pos) {
      if (indices[pos] == new_idx) return;
    }
  }

  dist[0] = new_dist;
  indices[0] = new_idx;

  IdType node = 0;
  for (;;) {
    const IdType lchild = node * 2 + 1;
    const IdType rchild = lchild + 1;
    IdType top = node;
    if (lchild < size && dist[lchild] > dist[top]) top = lchild;
    if (rchild < size && dist[rchild] > dist[top]) top = rchild;
    if (top == node) break;

    const IdType idx_tmp = indices[node];
    indices[node] = indices[top];
    indices[top] = idx_tmp;

    const FloatType dist_tmp = dist[node];
    dist[node] = dist[top];
    dist[top] = dist_tmp;

    node = top;
  }
}

/**
 * @brief Same as HeapInsert, but also carries a per-entry boolean flag that
 *  is permuted together with the indices and distances.
 *
 * @param flags Per-entry flags, permuted together with dist.
 * @param new_flag Flag stored for the candidate if it is inserted.
 * @return true if the candidate was inserted, false if it was rejected
 *  (new_dist greater than dist[0], or a repeat when check_repeat is set).
 */
template <typename FloatType, typename IdType>
__device__ bool FlaggedHeapInsert(
    IdType* indices, FloatType* dist, bool* flags, IdType new_idx,
    FloatType new_dist, bool new_flag, int size, bool check_repeat = false) {
  if (new_dist > dist[0]) return false;

  if (check_repeat) {
    for (IdType pos = 0; pos < size; ++pos) {
      if (indices[pos] == new_idx) return false;
    }
  }

  dist[0] = new_dist;
  indices[0] = new_idx;
  flags[0] = new_flag;

  IdType node = 0;
  for (;;) {
    const IdType lchild = node * 2 + 1;
    const IdType rchild = lchild + 1;
    IdType top = node;
    if (lchild < size && dist[lchild] > dist[top]) top = lchild;
    if (rchild < size && dist[rchild] > dist[top]) top = rchild;
    if (top == node) break;

    const IdType idx_tmp = indices[node];
    indices[node] = indices[top];
    indices[top] = idx_tmp;

    const FloatType dist_tmp = dist[node];
    dist[node] = dist[top];
    dist[top] = dist_tmp;

    const bool flag_tmp = flags[node];
    flags[node] = flags[top];
    flags[top] = flag_tmp;

    node = top;
  }
  return true;
}

/**
 * @brief Brute force kNN kernel. Compute distance for each pair of input points
 * and get the result directly (without a distance matrix).
 *
 * One thread per query point (the global thread index is the query index);
 * threads at or beyond query_offsets[num_batches] exit immediately. Each query
 * is compared against every data point of its own segment.
 *
 * @param dists k entries per query, used as that query's max-heap of (squared)
 *  distances.
 * @param query_out Written with q, k times, for each query q.
 * @param data_out Written with the selected data point indices per query, in
 *  heap order (not sorted).
 *
 * NOTE(review): data_out is only written through HeapInsert, so if a segment
 * has fewer than k data points some of its entries are left unwritten.
 */
template <typename FloatType, typename IdType>
__global__ void BruteforceKnnKernel(
    const FloatType* data_points, const IdType* data_offsets,
    const FloatType* query_points, const IdType* query_offsets, const int k,
    FloatType* dists, IdType* query_out, IdType* data_out,
    const int64_t num_batches, const int64_t feature_size) {
  const IdType q_idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (q_idx >= query_offsets[num_batches]) return;

  // Segment of this query: one less than the first b whose offset exceeds
  // q_idx. The guard above ensures such a b <= num_batches exists.
  IdType seg = 0;
  while (seg < num_batches + 1 && query_offsets[seg] <= q_idx) ++seg;
  const IdType batch_idx = seg - 1;

  const IdType data_start = data_offsets[batch_idx];
  const IdType data_end = data_offsets[batch_idx + 1];

  const IdType out_offset = q_idx * k;
  FloatType* my_dists = dists + out_offset;
  IdType* my_data = data_out + out_offset;
  const FloatType* query = query_points + q_idx * feature_size;

  // Start with an all-"infinite" heap.
  for (int j = 0; j < k; ++j) {
    query_out[out_offset + j] = q_idx;
    my_dists[j] = std::numeric_limits<FloatType>::max();
  }

  FloatType worst = std::numeric_limits<FloatType>::max();
  for (IdType d_idx = data_start; d_idx < data_end; ++d_idx) {
    const FloatType cand = EuclideanDistWithCheck<FloatType, IdType>(
        query, data_points + d_idx * feature_size, feature_size, worst);
    HeapInsert<FloatType, IdType>(my_data, my_dists, d_idx, cand, k);
    worst = my_dists[0];
  }
}

/**
 * @brief Same as BruteforceKnnKernel, but use shared memory as buffer.
 *  This kernel divides query points and data points into blocks. For each
 *  query block, it will make a loop over all data blocks and compute distances.
 *  This kernel is faster when the dimension of input points is not large.
 *
 * Launch layout: block b processes the local_block_id[b]-th tile of blockDim.x
 * queries inside segment block_batch_id[b]; each thread owns one query.
 *
 * Dynamic shared memory per block, in this order:
 *  - data_buff:  blockDim.x * feature_size FloatType (current data tile,
 *                one row per data point)
 *  - query_buff: blockDim.x * feature_size FloatType (this block's queries,
 *                stored transposed)
 *  - dist_buff:  blockDim.x * k FloatType (per-thread max-heap of distances)
 *  - res_buff:   blockDim.x * k IdType (per-thread heap of data indices)
 *
 * @param dists Written with the k (squared) distances of each query.
 * @param query_out Written with the query index, k times per query.
 * @param data_out Written with the k selected data indices per query.
 */
template <typename FloatType, typename IdType>
__global__ void BruteforceKnnShareKernel(
    const FloatType* data_points, const IdType* data_offsets,
    const FloatType* query_points, const IdType* query_offsets,
    const IdType* block_batch_id, const IdType* local_block_id, const int k,
    FloatType* dists, IdType* query_out, IdType* data_out,
    const int64_t num_batches, const int64_t feature_size) {
  const IdType block_idx = static_cast<IdType>(blockIdx.x);
  const IdType block_size = static_cast<IdType>(blockDim.x);
  const IdType batch_idx = block_batch_id[block_idx];
  const IdType local_bid = local_block_id[block_idx];
  const IdType query_start = query_offsets[batch_idx] + block_size * local_bid;
  const IdType query_end =
      min(query_start + block_size, query_offsets[batch_idx + 1]);
  // This exit depends only on the block, so either every thread of the block
  // reaches the barriers below or none does.
  if (query_start >= query_end) return;

  const IdType query_idx = query_start + threadIdx.x;
  const IdType data_start = data_offsets[batch_idx];
  const IdType data_end = data_offsets[batch_idx + 1];

  // shared memory: points in block + distance buffer + result buffer
  FloatType* data_buff = SharedMemory<FloatType>();
  FloatType* query_buff = data_buff + block_size * feature_size;
  FloatType* dist_buff = query_buff + block_size * feature_size;
  IdType* res_buff = reinterpret_cast<IdType*>(dist_buff + block_size * k);
  FloatType worst_dist = std::numeric_limits<FloatType>::max();

  // initialize dist buff with inf value
  // NOTE(review): res_buff is not initialized; if a segment has fewer than k
  // data points, some res_buff entries copied to data_out are never written.
  for (auto i = 0; i < k; ++i) {
    dist_buff[threadIdx.x * k + i] = std::numeric_limits<FloatType>::max();
  }

  // load query data to shared memory
  if (query_idx < query_end) {
    for (auto i = 0; i < feature_size; ++i) {
      // to avoid bank conflict, we use transpose here
      query_buff[threadIdx.x + i * block_size] =
          query_points[query_idx * feature_size + i];
    }
  }

  // perform computation on each tile
  for (auto tile_start = data_start; tile_start < data_end;
       tile_start += block_size) {
    // each thread load one data point into the shared memory
    IdType load_idx = tile_start + threadIdx.x;
    if (load_idx < data_end) {
      for (auto i = 0; i < feature_size; ++i) {
        data_buff[threadIdx.x * feature_size + i] =
            data_points[load_idx * feature_size + i];
      }
    }
    // The whole tile must be loaded before any thread reads it.
    __syncthreads();

    // compute distance for one tile
    IdType true_block_size = min(data_end - tile_start, block_size);
    if (query_idx < query_end) {
      for (IdType d_idx = 0; d_idx < true_block_size; ++d_idx) {
        FloatType tmp_dist = 0;
        bool early_stop = false;
        IdType dim_idx = 0;

        for (; dim_idx < feature_size - 3; dim_idx += 4) {
          FloatType diff0 = query_buff[threadIdx.x + block_size * (dim_idx)] -
                            data_buff[d_idx * feature_size + dim_idx];
          FloatType diff1 =
              query_buff[threadIdx.x + block_size * (dim_idx + 1)] -
              data_buff[d_idx * feature_size + dim_idx + 1];
          FloatType diff2 =
              query_buff[threadIdx.x + block_size * (dim_idx + 2)] -
              data_buff[d_idx * feature_size + dim_idx + 2];
          FloatType diff3 =
              query_buff[threadIdx.x + block_size * (dim_idx + 3)] -
              data_buff[d_idx * feature_size + dim_idx + 3];

          tmp_dist +=
              diff0 * diff0 + diff1 * diff1 + diff2 * diff2 + diff3 * diff3;

          if (tmp_dist > worst_dist) {
            early_stop = true;
            dim_idx = feature_size;
            break;
          }
        }

        for (; dim_idx < feature_size; ++dim_idx) {
          const FloatType diff =
              query_buff[threadIdx.x + dim_idx * block_size] -
              data_buff[d_idx * feature_size + dim_idx];
          tmp_dist += diff * diff;

          if (tmp_dist > worst_dist) {
            early_stop = true;
            break;
          }
        }

        if (early_stop) continue;

        HeapInsert<FloatType, IdType>(
            res_buff + threadIdx.x * k, dist_buff + threadIdx.x * k,
            d_idx + tile_start, tmp_dist, k);
        worst_dist = dist_buff[threadIdx.x * k];
      }
    }

    // Every thread must be done reading this tile of data_buff before any
    // thread overwrites it with the next tile. Without this barrier a fast
    // thread can start the next load while slower threads are still computing.
    // The loop bounds are block-uniform, so all threads reach this barrier.
    __syncthreads();
  }

  // copy result to global memory
  if (query_idx < query_end) {
    for (auto i = 0; i < k; ++i) {
      dists[query_idx * k + i] = dist_buff[threadIdx.x * k + i];
      data_out[query_idx * k + i] = res_buff[threadIdx.x * k + i];
      query_out[query_idx * k + i] = query_idx;
    }
  }
}

/**
 * @brief Determine the number of blocks needed for each segment.
 *
 * One thread per segment: out[s] = ceil(len_s / block_size), written as
 * (len_s - 1) / block_size + 1. Because integer division truncates toward
 * zero, an empty segment gets 1 when block_size > 1.
 *
 * @param offsets batch_size + 1 segment offsets.
 * @param out Per-segment block count (batch_size entries).
 */
template <typename IdType>
__global__ void GetNumBlockPerSegment(
    const IdType* offsets, IdType* out, const int64_t batch_size,
    const int64_t block_size) {
  const IdType seg = blockIdx.x * blockDim.x + threadIdx.x;
  if (seg >= batch_size) return;
  const IdType seg_len = offsets[seg + 1] - offsets[seg];
  out[seg] = (seg_len - 1) / block_size + 1;
}

/**
 * @brief Get the batch (segment) index and the local index within that segment
 *  for each block.
 *
 * One thread per block id. Block id b is mapped to the last segment s with
 * num_block_prefixsum[s] <= b (the scan stops at the first prefix value
 * greater than b). This skips segments that contribute zero blocks.
 *
 * @param num_block_prefixsum Exclusive prefix sum of per-segment block counts.
 * @param block_batch_id Output: segment of each block.
 * @param local_block_id Output: block index relative to its segment's start.
 */
template <typename IdType>
__global__ void GetBlockInfo(
    const IdType* num_block_prefixsum, IdType* block_batch_id,
    IdType* local_block_id, size_t batch_size, size_t num_blocks) {
  const IdType blk = blockIdx.x * blockDim.x + threadIdx.x;
  if (blk >= num_blocks) return;

  IdType seg = 0;
  while (seg < batch_size && num_block_prefixsum[seg] <= blk) ++seg;
  --seg;

  block_batch_id[blk] = seg;
  local_block_id[blk] = blk - num_block_prefixsum[seg];
}

/**
 * @brief Brute force kNN. Compute distance for each pair of input points and
 * get the result directly (without a distance matrix).
 *
 * @tparam FloatType The type of input points.
 * @tparam IdType The type of id.
 * @param data_points NDArray of dataset points.
 * @param data_offsets offsets of point index in data points.
 * @param query_points NDArray of query points
 * @param query_offsets offsets of point index in query points.
 * @param k the number of nearest points
 * @param result output array. The first k * #queries entries receive the query
 *  index of each pair and the next k * #queries entries the data index.
 *
 * Launches one thread per query point on the current CUDA stream. The k
 * distances per query are kept in a temporary workspace and discarded.
 */
template <typename FloatType, typename IdType>
void BruteForceKNNCuda(
    const NDArray& data_points, const IdArray& data_offsets,
    const NDArray& query_points, const IdArray& query_offsets, const int k,
    IdArray result) {
  const int64_t num_queries = query_points->shape[0];
  // Nothing to compute. Returning here also avoids a possible zero-block
  // launch: with block_size == 1, (0 - 1) / 1 + 1 == 0, and CUDA rejects an
  // empty grid as an invalid configuration.
  if (num_queries == 0) return;

  cudaStream_t stream = runtime::getCurrentCUDAStream();
  const auto& ctx = data_points->ctx;
  auto device = runtime::DeviceAPI::Get(ctx);
  const int64_t batch_size = data_offsets->shape[0] - 1;
  const int64_t feature_size = data_points->shape[1];
  const IdType* data_offsets_data = data_offsets.Ptr<IdType>();
  const IdType* query_offsets_data = query_offsets.Ptr<IdType>();
  const FloatType* data_points_data = data_points.Ptr<FloatType>();
  const FloatType* query_points_data = query_points.Ptr<FloatType>();
  IdType* query_out = result.Ptr<IdType>();
  IdType* data_out = query_out + k * num_queries;

  // Scratch space: the per-query distance heaps.
  FloatType* dists = static_cast<FloatType*>(
      device->AllocWorkspace(ctx, k * num_queries * sizeof(FloatType)));

  const int64_t block_size = cuda::FindNumThreads(num_queries);
  const int64_t num_blocks = (num_queries - 1) / block_size + 1;
  CUDA_KERNEL_CALL(
      BruteforceKnnKernel, num_blocks, block_size, 0, stream, data_points_data,
      data_offsets_data, query_points_data, query_offsets_data, k, dists,
      query_out, data_out, batch_size, feature_size);

  device->FreeWorkspace(ctx, dists);
}

/**
 * @brief Brute force kNN with shared memory.
 *  This function divides query points and data points into blocks. For each
 *  query block, it will make a loop over all data blocks and compute distances.
 *  It will be faster when the dimension of input points is not large.
 *
 * @tparam FloatType The type of input points.
 * @tparam IdType The type of id.
 * @param data_points NDArray of dataset points.
 * @param data_offsets offsets of point index in data points.
 * @param query_points NDArray of query points
 * @param query_offsets offsets of point index in query points.
 * @param k the number of nearest points
 * @param result output array. The first k * #queries entries receive the query
 *  index of each pair and the next k * #queries entries the data index.
 *
 * Steps: pick a block size that fits the per-thread shared-memory footprint,
 * count blocks per segment, prefix-sum the counts, map every block to
 * (segment, local block), then launch BruteforceKnnShareKernel.
 */
template <typename FloatType, typename IdType>
void BruteForceKNNSharedCuda(
    const NDArray& data_points, const IdArray& data_offsets,
    const NDArray& query_points, const IdArray& query_offsets, const int k,
    IdArray result) {
  const int64_t num_queries = query_points->shape[0];
  // Nothing to compute. Returning here also avoids zero-sized workspaces and
  // possible zero-block launches below.
  if (num_queries == 0) return;

  cudaStream_t stream = runtime::getCurrentCUDAStream();
  const auto& ctx = data_points->ctx;
  auto device = runtime::DeviceAPI::Get(ctx);
  const int64_t batch_size = data_offsets->shape[0] - 1;
  const int64_t feature_size = data_points->shape[1];
  const IdType* data_offsets_data = data_offsets.Ptr<IdType>();
  const IdType* query_offsets_data = query_offsets.Ptr<IdType>();
  const FloatType* data_points_data = data_points.Ptr<FloatType>();
  const FloatType* query_points_data = query_points.Ptr<FloatType>();
  IdType* query_out = result.Ptr<IdType>();
  IdType* data_out = query_out + k * num_queries;

  // get max shared memory per block in bytes
  // determine block size according to this value
  // NOTE(review): if single_shared_mem alone exceeds max_sharedmem_per_block,
  // no block size fits and the kernel launch below is expected to fail;
  // consider an explicit up-front check with a clearer message.
  int max_sharedmem_per_block = 0;
  CUDA_CALL(cudaDeviceGetAttribute(
      &max_sharedmem_per_block, cudaDevAttrMaxSharedMemoryPerBlock,
      ctx.device_id));
  const int64_t single_shared_mem =
      (k + 2 * feature_size) * sizeof(FloatType) + k * sizeof(IdType);
  const int64_t block_size =
      cuda::FindNumThreads(max_sharedmem_per_block / single_shared_mem);

  // Determine the number of blocks. We first get the number of blocks for each
  // segment. Then we get the block id offset via prefix sum.
  IdType* num_block_per_segment = static_cast<IdType*>(
      device->AllocWorkspace(ctx, batch_size * sizeof(IdType)));
  IdType* num_block_prefixsum = static_cast<IdType*>(
      device->AllocWorkspace(ctx, batch_size * sizeof(IdType)));

  // block size for GetNumBlockPerSegment computation
  int64_t temp_block_size = cuda::FindNumThreads(batch_size);
  int64_t temp_num_blocks = (batch_size - 1) / temp_block_size + 1;
  CUDA_KERNEL_CALL(
      GetNumBlockPerSegment, temp_num_blocks, temp_block_size, 0, stream,
      query_offsets_data, num_block_per_segment, batch_size, block_size);

  size_t prefix_temp_size = 0;
  CUDA_CALL(cub::DeviceScan::ExclusiveSum(
      nullptr, prefix_temp_size, num_block_per_segment, num_block_prefixsum,
      batch_size, stream));
  void* prefix_temp = device->AllocWorkspace(ctx, prefix_temp_size);
  CUDA_CALL(cub::DeviceScan::ExclusiveSum(
      prefix_temp, prefix_temp_size, num_block_per_segment, num_block_prefixsum,
      batch_size, stream));
  device->FreeWorkspace(ctx, prefix_temp);

  // Total blocks = last exclusive prefix value + last per-segment count.
  // Only sizeof(IdType) bytes are copied into the zero-initialized int64_t
  // values, so a 32-bit IdType lands in the low-order bytes (this relies on a
  // little-endian host).
  int64_t num_blocks = 0, final_elem = 0,
          copyoffset = (batch_size - 1) * sizeof(IdType);
  device->CopyDataFromTo(
      num_block_prefixsum, copyoffset, &num_blocks, 0, sizeof(IdType), ctx,
      DGLContext{kDGLCPU, 0}, query_offsets->dtype);
  device->CopyDataFromTo(
      num_block_per_segment, copyoffset, &final_elem, 0, sizeof(IdType), ctx,
      DGLContext{kDGLCPU, 0}, query_offsets->dtype);
  num_blocks += final_elem;
  device->FreeWorkspace(ctx, num_block_per_segment);

  // get batch id and local id in segment
  temp_block_size = cuda::FindNumThreads(num_blocks);
  temp_num_blocks = (num_blocks - 1) / temp_block_size + 1;
  IdType* block_batch_id = static_cast<IdType*>(
      device->AllocWorkspace(ctx, num_blocks * sizeof(IdType)));
  IdType* local_block_id = static_cast<IdType*>(
      device->AllocWorkspace(ctx, num_blocks * sizeof(IdType)));
  CUDA_KERNEL_CALL(
      GetBlockInfo, temp_num_blocks, temp_block_size, 0, stream,
      num_block_prefixsum, block_batch_id, local_block_id, batch_size,
      num_blocks);
  // GetBlockInfo reads num_block_prefixsum, so it may only be released after
  // that kernel has been launched. Freeing it earlier lets block_batch_id /
  // local_block_id reuse the same memory while it is still being read.
  device->FreeWorkspace(ctx, num_block_prefixsum);

  // Scratch space: the per-query distance heaps.
  FloatType* dists = static_cast<FloatType*>(
      device->AllocWorkspace(ctx, k * num_queries * sizeof(FloatType)));
  CUDA_KERNEL_CALL(
      BruteforceKnnShareKernel, num_blocks, block_size,
      single_shared_mem * block_size, stream, data_points_data,
      data_offsets_data, query_points_data, query_offsets_data, block_batch_id,
      local_block_id, k, dists, query_out, data_out, batch_size, feature_size);

  device->FreeWorkspace(ctx, dists);
  device->FreeWorkspace(ctx, local_block_id);
  device->FreeWorkspace(ctx, block_batch_id);
}
/**
 * @brief Setup rng state for nn-descent.
 *
 * One thread per state: states[i] is initialized with the given seed and
 * subsequence i (offset 0) for every i < n.
 */
__global__ void SetupRngKernel(
    curandState* states, const uint64_t seed, const size_t n) {
  const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid >= n) return;
  curand_init(seed, tid, 0, &states[tid]);
}

/**
 * @brief Randomly initialize neighbors (sampling without replacement) for
 *  each node.
 *
 * One thread per point. Reservoir sampling picks k points of the point's own
 * segment. All flags are set to true, each sampled neighbor's squared distance
 * is stored, and each row is arranged into a max-heap on distance.
 *
 * NOTE(review): the reservoir is seeded with the first k points of the
 * segment, so this assumes k <= segment size (presumably validated by the
 * caller; confirm). The point itself may be sampled as its own neighbor.
 */
template <typename FloatType, typename IdType>
__global__ void RandomInitNeighborsKernel(
    const FloatType* points, const IdType* offsets, IdType* central_nodes,
    IdType* neighbors, FloatType* dists, bool* flags, const int k,
    const int64_t feature_size, const int64_t batch_size, const uint64_t seed) {
  const IdType point_idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (point_idx >= offsets[batch_size]) return;

  curandState rng;
  curand_init(seed, point_idx, 0, &rng);

  // Segment containing this point.
  IdType batch_idx = 0;
  for (IdType b = 0; b < batch_size + 1; ++b) {
    if (offsets[b] > point_idx) {
      batch_idx = b - 1;
      break;
    }
  }

  const IdType seg_begin = offsets[batch_idx];
  const IdType seg_len = offsets[batch_idx + 1] - offsets[batch_idx];
  IdType* my_neighbors = neighbors + point_idx * k;
  IdType* my_central = central_nodes + point_idx * k;
  bool* my_flags = flags + point_idx * k;
  FloatType* my_dists = dists + point_idx * k;

  // Reservoir sampling: seed with the first k points, then let point i
  // replace a random slot with probability k / (i + 1).
  for (IdType i = 0; i < k; ++i) {
    my_neighbors[i] = seg_begin + i;
    my_central[i] = point_idx;
  }
  for (IdType i = k; i < seg_len; ++i) {
    const IdType slot = static_cast<IdType>(curand(&rng) % (i + 1));
    if (slot < k) my_neighbors[slot] = seg_begin + i;
  }

  // Mark every sampled neighbor as new and record its distance.
  const FloatType* anchor = points + point_idx * feature_size;
  for (IdType i = 0; i < k; ++i) {
    my_flags[i] = true;
    my_dists[i] = EuclideanDist<FloatType, IdType>(
        anchor, points + my_neighbors[i] * feature_size, feature_size);
  }

  // Max-heap on distance (flags are all true, so they need no permuting).
  BuildHeap<FloatType, IdType>(my_neighbors, my_dists, k);
}

/**
 * @brief Randomly select candidates from current knn and reverse-knn graph for
 *        nn-descent.
 *
 * One thread per point. Each candidate list is laid out as
 * [count, c_1, ..., c_num_candidates]. Neighbors whose flag is true go to the
 * "new" list and the rest to the "old" list. Both lists are filled by
 * reservoir sampling over:
 *  - forward edges: this point's own k neighbors, and
 *  - reverse edges: every point in the same segment that has this point in
 *    its neighbor row.
 * Afterwards, own neighbors that were placed in the new list get their flag
 * cleared.
 *
 * NOTE(review): the reverse-edge scan reads `flags` rows of other points,
 * while each thread clears entries of its own row at the end of this kernel.
 * Nothing orders those accesses across threads, so a reader may see either the
 * old or the cleared value. Confirm this nondeterminism is acceptable.
 */
template <typename IdType>
__global__ void FindCandidatesKernel(
    const IdType* offsets, IdType* new_candidates, IdType* old_candidates,
    IdType* neighbors, bool* flags, const uint64_t seed,
    const int64_t batch_size, const int num_candidates, const int k) {
  const IdType point_idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (point_idx >= offsets[batch_size]) return;

  curandState rng;
  curand_init(seed, point_idx, 0, &rng);

  // Segment containing this point.
  IdType batch_idx = 0;
  for (IdType b = 0; b < batch_size + 1; ++b) {
    if (offsets[b] > point_idx) {
      batch_idx = b - 1;
      break;
    }
  }
  const IdType segment_start = offsets[batch_idx];
  const IdType segment_end = offsets[batch_idx + 1];

  IdType* my_neighbors = neighbors + point_idx * k;
  bool* my_flags = flags + point_idx * k;

  // Reset this point's lists. During sampling, slot 0 counts every candidate
  // offered; it is clamped to the number actually stored further below.
  IdType* new_list = new_candidates + point_idx * (num_candidates + 1);
  IdType* old_list = old_candidates + point_idx * (num_candidates + 1);
  new_list[0] = 0;
  old_list[0] = 0;

  // Reservoir-sample `candidate` into `list`.
  auto offer = [&](IdType* list, IdType candidate) {
    const IdType seen = list[0];
    IdType* slots = list + 1;
    if (seen < num_candidates) {
      slots[seen] = candidate;
    } else {
      const IdType pos = static_cast<IdType>(curand(&rng) % (seen + 1));
      if (pos < num_candidates) slots[pos] = candidate;
    }
    list[0] = seen + 1;
  };

  // Forward edges: the current knn row of this point.
  for (IdType i = 0; i < k; ++i) {
    const IdType candidate = my_neighbors[i];
    IdType* target = my_flags[i] ? new_list : old_list;
    offer(target, candidate);
  }

  // Reverse edges: rows in this segment that contain this point.
  const IdType index_start = segment_start * k;
  const IdType index_end = segment_end * k;
  for (IdType i = index_start; i < index_end; ++i) {
    if (neighbors[i] != point_idx) continue;
    const IdType reverse_candidate = (i - index_start) / k + segment_start;
    IdType* target = flags[i] ? new_list : old_list;
    offer(target, reverse_candidate);
  }

  // Turn slot 0 back into the stored length.
  if (new_list[0] > num_candidates) new_list[0] = num_candidates;
  if (old_list[0] > num_candidates) old_list[0] = num_candidates;

  // Own neighbors that made it into the new list are no longer "new".
  const IdType num_new = new_list[0];
  for (IdType i = 0; i < k; ++i) {
    if (!my_flags[i]) continue;
    const IdType nb = my_neighbors[i];
    for (IdType j = 1; j <= num_new; ++j) {
      if (new_list[j] == nb) {
        my_flags[i] = false;
        break;
      }
    }
  }
}

/**
 * @brief Update knn graph according to selected candidates for nn-descent.
 *
 * One thread per point. For every new candidate c of this point, the new and
 * old candidates of c are tried. For every old candidate c, only the new
 * candidates of c are tried. Each accepted candidate is inserted into this
 * point's neighbor heap with flag = true, and duplicates are rejected.
 * num_updates[point] receives the number of successful insertions.
 *
 * The pruning bound for the early-exit distance is refreshed from the heap root
 * at the start of each outer candidate and after each successful insertion.
 */
template <typename FloatType, typename IdType>
__global__ void UpdateNeighborsKernel(
    const FloatType* points, const IdType* offsets, IdType* neighbors,
    IdType* new_candidates, IdType* old_candidates, FloatType* distances,
    bool* flags, IdType* num_updates, const int64_t batch_size,
    const int num_candidates, const int k, const int64_t feature_size) {
  const IdType point_idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (point_idx >= offsets[batch_size]) return;

  IdType* my_neighbors = neighbors + point_idx * k;
  bool* my_flags = flags + point_idx * k;
  FloatType* my_dists = distances + point_idx * k;
  const FloatType* anchor = points + point_idx * feature_size;

  // Each candidate list is [count, c_1, ..., c_num_candidates].
  const int stride = num_candidates + 1;
  IdType* new_list = new_candidates + point_idx * stride;
  IdType* old_list = old_candidates + point_idx * stride;
  const IdType num_new = new_list[0];
  const IdType num_old = old_list[0];
  IdType updates = 0;

  // Try every entry of `list` as a neighbor of this point, pruning with
  // (and tightening) `worst_dist`.
  auto try_list = [&](const IdType* list, FloatType& worst_dist) {
    const IdType count = list[0];
    for (IdType j = 1; j <= count; ++j) {
      const IdType cand = list[j];
      const FloatType d = EuclideanDistWithCheck<FloatType, IdType>(
          anchor, points + cand * feature_size, feature_size, worst_dist);
      if (FlaggedHeapInsert<FloatType, IdType>(
              my_neighbors, my_dists, my_flags, cand, d, true, k, true)) {
        ++updates;
        worst_dist = my_dists[0];
      }
    }
  };

  // new - new and new - old
  for (IdType i = 1; i <= num_new; ++i) {
    const IdType c = new_list[i];
    FloatType worst = my_dists[0];
    try_list(new_candidates + c * stride, worst);
    try_list(old_candidates + c * stride, worst);
  }

  // old - new
  for (IdType i = 1; i <= num_old; ++i) {
    const IdType c = old_list[i];
    FloatType worst = my_dists[0];
    try_list(new_candidates + c * stride, worst);
  }

  num_updates[point_idx] = updates;
}

828
829
}  // namespace impl

830
/**
 * @brief Dispatch a CUDA k-nearest-neighbor query to the implementation
 *  selected by @p algorithm.
 *
 * Supported values of @p algorithm are "bruteforce" (forwarded to
 * impl::BruteForceKNNCuda) and "bruteforce-sharemem" (forwarded to
 * impl::BruteForceKNNSharedCuda). Any other value aborts via LOG(FATAL).
 * The XPU template parameter is not used inside this function.
 *
 * @param data_points Points searched for neighbors.
 * @param data_offsets Offsets into @p data_points, forwarded unchanged.
 * @param query_points Points whose neighbors are requested.
 * @param query_offsets Offsets into @p query_points, forwarded unchanged.
 * @param k Number of neighbors requested per query point.
 * @param result Output array filled by the selected implementation.
 * @param algorithm Name of the implementation to run.
 */
template <DGLDeviceType XPU, typename FloatType, typename IdType>
void KNN(
    const NDArray& data_points, const IdArray& data_offsets,
    const NDArray& query_points, const IdArray& query_offsets, const int k,
    IdArray result, const std::string& algorithm) {
  if (algorithm == "bruteforce") {
    impl::BruteForceKNNCuda<FloatType, IdType>(
        data_points, data_offsets, query_points, query_offsets, k, result);
    return;
  }

  if (algorithm == "bruteforce-sharemem") {
    impl::BruteForceKNNSharedCuda<FloatType, IdType>(
        data_points, data_offsets, query_points, query_offsets, k, result);
    return;
  }

  // No CUDA implementation matches the requested algorithm name.
  LOG(FATAL) << "Algorithm " << algorithm << " is not supported on CUDA.";
}

846
/**
 * @brief Build an approximate k-nearest-neighbor graph with NN-descent (CUDA).
 *
 * Neighbors are first initialized randomly (RandomInitNeighborsKernel). Each
 * iteration then samples new/old candidates (FindCandidatesKernel) and refines
 * every point's neighbor list (UpdateNeighborsKernel). The per-point update
 * counts are summed on the device with CUB and copied back to the host. The
 * loop stops early once that total is <= delta * k * num_nodes, and otherwise
 * runs for at most @p num_iters iterations. All kernels are launched on the
 * current DGL CUDA stream.
 *
 * @param points Point features; shape[0] is the number of points and
 *  shape[1] the feature size.
 * @param offsets Array with batch_size + 1 entries, forwarded to the kernels
 *  (presumably per-segment point boundaries -- confirm against the kernels).
 * @param result Output id array. The first k * num_nodes entries are used as
 *  "central nodes" and the following k * num_nodes entries as neighbors, so it
 *  must hold at least 2 * k * num_nodes IdType values.
 * @param k Number of neighbors per point.
 * @param num_iters Maximum number of refinement iterations.
 * @param num_candidates Maximum number of candidates sampled per point.
 * @param delta Early-termination threshold, as a fraction of k * num_nodes.
 */
template <DGLDeviceType XPU, typename FloatType, typename IdType>
void NNDescent(
    const NDArray& points, const IdArray& offsets, IdArray result, const int k,
    const int num_iters, const int num_candidates, const double delta) {
  cudaStream_t stream = runtime::getCurrentCUDAStream();
  const auto& ctx = points->ctx;
  auto device = runtime::DeviceAPI::Get(ctx);
  const int64_t num_nodes = points->shape[0];
  const int64_t feature_size = points->shape[1];
  const int64_t batch_size = offsets->shape[0] - 1;

  // With no points there is nothing to compute. Returning here also avoids
  // zero-byte workspace allocations and a one-block kernel launch that would
  // have no valid point to process.
  if (num_nodes == 0) {
    return;
  }

  const IdType* offsets_data = offsets.Ptr<IdType>();
  const FloatType* points_data = points.Ptr<FloatType>();

  IdType* central_nodes = result.Ptr<IdType>();
  IdType* neighbors = central_nodes + k * num_nodes;

  int warp_size = 0;
  CUDA_CALL(
      cudaDeviceGetAttribute(&warp_size, cudaDevAttrWarpSize, ctx.device_id));
  // We don't need large block sizes, since there's not much inter-thread
  // communication
  const int64_t block_size = warp_size;
  const int64_t num_blocks = (num_nodes - 1) / block_size + 1;

  // Candidate lists: one row of (num_candidates + 1) ids per node. The first
  // element of each row stores the number of valid candidates in that row.
  const size_t candidate_bytes =
      num_nodes * (num_candidates + 1) * sizeof(IdType);
  IdType* new_candidates =
      static_cast<IdType*>(device->AllocWorkspace(ctx, candidate_bytes));
  IdType* old_candidates =
      static_cast<IdType*>(device->AllocWorkspace(ctx, candidate_bytes));
  IdType* num_updates = static_cast<IdType*>(
      device->AllocWorkspace(ctx, num_nodes * sizeof(IdType)));
  // Per-neighbor distances and flags must be sized by their own element
  // types. Sizing distances by sizeof(IdType) under-allocates when FloatType
  // is wider than IdType (e.g. double with int32_t ids).
  FloatType* distances = static_cast<FloatType*>(
      device->AllocWorkspace(ctx, num_nodes * k * sizeof(FloatType)));
  bool* flags = static_cast<bool*>(
      device->AllocWorkspace(ctx, num_nodes * k * sizeof(bool)));

  // Device scalar receiving the per-iteration total number of updates.
  IdType* total_num_updates_d =
      static_cast<IdType*>(device->AllocWorkspace(ctx, sizeof(IdType)));

  // First CUB call with null storage only queries the temp-storage size.
  size_t sum_temp_size = 0;
  CUDA_CALL(cub::DeviceReduce::Sum(
      nullptr, sum_temp_size, num_updates, total_num_updates_d, num_nodes,
      stream));
  void* sum_temp_storage = device->AllocWorkspace(ctx, sum_temp_size);

  // random initialize neighbors
  uint64_t seed = RandomEngine::ThreadLocal()->RandInt<uint64_t>(
      std::numeric_limits<uint64_t>::max());
  CUDA_KERNEL_CALL(
      impl::RandomInitNeighborsKernel, num_blocks, block_size, 0, stream,
      points_data, offsets_data, central_nodes, neighbors, distances, flags, k,
      feature_size, batch_size, seed);

  for (int i = 0; i < num_iters; ++i) {
    // select candidates
    seed = RandomEngine::ThreadLocal()->RandInt<uint64_t>(
        std::numeric_limits<uint64_t>::max());
    CUDA_KERNEL_CALL(
        impl::FindCandidatesKernel, num_blocks, block_size, 0, stream,
        offsets_data, new_candidates, old_candidates, neighbors, flags, seed,
        batch_size, num_candidates, k);

    // update
    CUDA_KERNEL_CALL(
        impl::UpdateNeighborsKernel, num_blocks, block_size, 0, stream,
        points_data, offsets_data, neighbors, new_candidates, old_candidates,
        distances, flags, num_updates, batch_size, num_candidates, k,
        feature_size);

    // Sum the per-point update counts on the device, then bring the total to
    // the host to evaluate the early-termination criterion.
    IdType total_num_updates = 0;
    CUDA_CALL(cub::DeviceReduce::Sum(
        sum_temp_storage, sum_temp_size, num_updates, total_num_updates_d,
        num_nodes, stream));
    device->CopyDataFromTo(
        total_num_updates_d, 0, &total_num_updates, 0, sizeof(IdType), ctx,
        DGLContext{kDGLCPU, 0}, offsets->dtype);

    if (total_num_updates <= static_cast<IdType>(delta * k * num_nodes)) {
      break;
    }
  }

  device->FreeWorkspace(ctx, new_candidates);
  device->FreeWorkspace(ctx, old_candidates);
  device->FreeWorkspace(ctx, num_updates);
  device->FreeWorkspace(ctx, distances);
  device->FreeWorkspace(ctx, flags);
  device->FreeWorkspace(ctx, total_num_updates_d);
  device->FreeWorkspace(ctx, sum_temp_storage);
}

940
// Explicit instantiations of the CUDA KNN and NNDescent entry points for every
// supported (feature type, id type) pair. The helper macros are local to this
// section and are undefined right after use.
#define DGL_KNN_CUDA_INSTANTIATE(FloatT, IdT)                                 \
  template void KNN<kDGLCUDA, FloatT, IdT>(                                   \
      const NDArray& data_points, const IdArray& data_offsets,                \
      const NDArray& query_points, const IdArray& query_offsets, const int k, \
      IdArray result, const std::string& algorithm)

#define DGL_NNDESCENT_CUDA_INSTANTIATE(FloatT, IdT)                  \
  template void NNDescent<kDGLCUDA, FloatT, IdT>(                    \
      const NDArray& points, const IdArray& offsets, IdArray result, \
      const int k, const int num_iters, const int num_candidates,    \
      const double delta)

DGL_KNN_CUDA_INSTANTIATE(float, int32_t);
DGL_KNN_CUDA_INSTANTIATE(float, int64_t);
DGL_KNN_CUDA_INSTANTIATE(double, int32_t);
DGL_KNN_CUDA_INSTANTIATE(double, int64_t);

DGL_NNDESCENT_CUDA_INSTANTIATE(float, int32_t);
DGL_NNDESCENT_CUDA_INSTANTIATE(float, int64_t);
DGL_NNDESCENT_CUDA_INSTANTIATE(double, int32_t);
DGL_NNDESCENT_CUDA_INSTANTIATE(double, int64_t);

#undef DGL_KNN_CUDA_INSTANTIATE
#undef DGL_NNDESCENT_CUDA_INSTANTIATE
969

970
971
}  // namespace transform
}  // namespace dgl