/**
 *  Copyright (c) 2021 by Contributors
 * @file partition_op.cu
 * @brief Operations on partition implemented in CUDA.
 */

#include <dgl/runtime/device_api.h>

#include <cub/cub.cuh>

#include "../../runtime/cuda/cuda_common.h"
#include "../../runtime/workspace.h"
#include "../partition_op.h"

using namespace dgl::runtime;

namespace dgl {
namespace partition {
namespace impl {

namespace {

/**
 * @brief Kernel to map global element IDs to partition IDs by remainder.
 *
 * @tparam IdType The type of ID.
 * @param global The global element IDs.
 * @param num_elements The number of element IDs.
 * @param num_parts The number of partitions.
 * @param part_id The mapped partition ID (output).
 */
template <typename IdType>
__global__ void _MapProcByRemainderKernel(
    const IdType* const global, const int64_t num_elements,
    const int64_t num_parts, IdType* const part_id) {
  assert(num_elements <= gridDim.x * blockDim.x);
  const int64_t idx =
      blockDim.x * static_cast<int64_t>(blockIdx.x) + threadIdx.x;

  if (idx < num_elements) {
    part_id[idx] = global[idx] % num_parts;
  }
}

/**
 * @brief Kernel to map global element IDs to partition IDs, using a bit-mask.
 * The number of partitions must be a power of two.
 *
 * @tparam IdType The type of ID.
 * @param global The global element IDs.
 * @param num_elements The number of element IDs.
 * @param mask The bit-mask with 1's for each bit to keep from the element ID to
 * extract the partition ID (e.g., an 8 partition mask would be 0x07).
 * @param part_id The mapped partition ID (output).
 */
template <typename IdType>
__global__ void _MapProcByMaskRemainderKernel(
    const IdType* const global, const int64_t num_elements, const IdType mask,
    IdType* const part_id) {
  assert(num_elements <= gridDim.x * blockDim.x);
  const int64_t idx =
      blockDim.x * static_cast<int64_t>(blockIdx.x) + threadIdx.x;

  if (idx < num_elements) {
    part_id[idx] = global[idx] & mask;
  }
}
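// For example, with 8 partitions the mask is 0x7, so global ID 10
// (binary 1010) maps to partition 10 & 0x7 = 2.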

/**
 * @brief Kernel to map global element IDs to local element IDs.
 *
 * @tparam IdType The type of ID.
 * @param global The global element IDs.
 * @param num_elements The number of IDs.
 * @param num_parts The number of partitions.
 * @param local The local element IDs (output).
 */
template <typename IdType>
__global__ void _MapLocalIndexByRemainderKernel(
    const IdType* const global, const int64_t num_elements, const int num_parts,
    IdType* const local) {
  assert(num_elements <= gridDim.x * blockDim.x);
  const int64_t idx = threadIdx.x + blockDim.x * blockIdx.x;

  if (idx < num_elements) {
    local[idx] = global[idx] / num_parts;
  }
}

/**
 * @brief Kernel to map local element IDs within a partition to their global
 * IDs, using the remainder over the number of partitions.
 *
 * @tparam IdType The type of ID.
 * @param local The local element IDs.
 * @param part_id The partition to map local elements from.
 * @param num_elements The number of elements to map.
 * @param num_parts The number of partitions.
 * @param global The global element IDs (output).
 */
template <typename IdType>
__global__ void _MapGlobalIndexByRemainderKernel(
    const IdType* const local, const int part_id, const int64_t num_elements,
    const int num_parts, IdType* const global) {
  assert(num_elements <= gridDim.x * blockDim.x);
  const int64_t idx = threadIdx.x + blockDim.x * blockIdx.x;

  assert(part_id < num_parts);

  if (idx < num_elements) {
    global[idx] = (local[idx] * num_parts) + part_id;
  }
}
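// Together, the three remainder kernels form a round trip: with num_parts = 4,
// global ID 10 maps to partition 10 % 4 = 2 and local ID 10 / 4 = 2, and
// _MapGlobalIndexByRemainderKernel recovers the global ID as 2 * 4 + 2 = 10.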

/**
 * @brief Device function to perform a binary search to find to which partition
 * a given ID belongs.
 *
 * @tparam RangeType The type of range.
 * @param range The prefix-sum of IDs assigned to partitions.
 * @param num_parts The number of partitions.
 * @param target The element ID to find the partition of.
 *
 * @return The partition.
 */
template <typename RangeType>
__device__ RangeType _SearchRange(
    const RangeType* const range, const int num_parts, const RangeType target) {
  int start = 0;
  int end = num_parts;
  int cur = (end + start) / 2;

  assert(range[0] == 0);
  assert(target < range[num_parts]);

  while (start + 1 < end) {
    if (target < range[cur]) {
      end = cur;
    } else {
      start = cur;
    }
    cur = (start + end) / 2;
  }

  return cur;
}
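// For example, with range = {0, 4, 9, 15} and num_parts = 3, _SearchRange
// returns 0 for targets 0-3, 1 for targets 4-8, and 2 for targets 9-14.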

/**
 * @brief Kernel to map element IDs to partition IDs.
 *
 * @tparam IdType The type of element ID.
 * @tparam RangeType The type of the range.
 * @param range The prefix-sum of IDs assigned to partitions.
 * @param global The global element IDs.
 * @param num_elements The number of element IDs.
 * @param num_parts The number of partitions.
 * @param part_id The partition ID assigned to each element (output).
 */
template <typename IdType, typename RangeType>
__global__ void _MapProcByRangeKernel(
    const RangeType* const range, const IdType* const global,
    const int64_t num_elements, const int64_t num_parts,
    IdType* const part_id) {
  assert(num_elements <= gridDim.x * blockDim.x);
  const int64_t idx =
      blockDim.x * static_cast<int64_t>(blockIdx.x) + threadIdx.x;

  // rely on caching to load the range into L1 cache
  if (idx < num_elements) {
    part_id[idx] = static_cast<IdType>(_SearchRange(
        range, static_cast<int>(num_parts),
        static_cast<RangeType>(global[idx])));
  }
}

/**
 * @brief Kernel to map global element IDs to their ID within their respective
 * partition.
 *
 * @tparam IdType The type of element ID.
 * @tparam RangeType The type of the range.
 * @param range The prefix-sum of IDs assigned to partitions.
 * @param global The global element IDs.
 * @param num_elements The number of elements.
 * @param num_parts The number of partitions.
 * @param local The local element IDs (output).
 */
template <typename IdType, typename RangeType>
__global__ void _MapLocalIndexByRangeKernel(
    const RangeType* const range, const IdType* const global,
    const int64_t num_elements, const int num_parts, IdType* const local) {
  assert(num_elements <= gridDim.x * blockDim.x);
  const int64_t idx = threadIdx.x + blockDim.x * blockIdx.x;

  // rely on caching to load the range into L1 cache
  if (idx < num_elements) {
    const int proc = _SearchRange(
        range, static_cast<int>(num_parts),
        static_cast<RangeType>(global[idx]));
    local[idx] = global[idx] - range[proc];
  }
}

/**
 * @brief Kernel to map local element IDs within a partition to their global
 * IDs.
 *
 * @tparam IdType The type of ID.
 * @tparam RangeType The type of the range.
 * @param range The prefix-sum of IDs assigned to partitions.
 * @param local The local element IDs.
 * @param part_id The partition to map local elements from.
 * @param num_elements The number of elements to map.
 * @param num_parts The number of partitions.
 * @param global The global element IDs (output).
 */
template <typename IdType, typename RangeType>
__global__ void _MapGlobalIndexByRangeKernel(
    const RangeType* const range, const IdType* const local, const int part_id,
    const int64_t num_elements, const int num_parts, IdType* const global) {
  assert(num_elements <= gridDim.x * blockDim.x);
  const int64_t idx = threadIdx.x + blockDim.x * blockIdx.x;

  assert(part_id < num_parts);

  // rely on caching to load the range into L1 cache
  if (idx < num_elements) {
    global[idx] = local[idx] + range[part_id];
  }
}
}  // namespace

// Remainder Based Partition Operations
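//
// For example, with num_parts = 2 and in_idx = [3, 0, 5, 2],
// GeneratePermutationFromRemainder assigns the elements to partitions
// [1, 0, 1, 0], returns the permutation [1, 3, 0, 2] (grouping partition 0's
// elements ahead of partition 1's) and per-partition counts of [2, 2].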

template <DGLDeviceType XPU, typename IdType>
std::pair<IdArray, NDArray> GeneratePermutationFromRemainder(
    int64_t array_size, int num_parts, IdArray in_idx) {
  std::pair<IdArray, NDArray> result;

  const auto& ctx = in_idx->ctx;
  auto device = DeviceAPI::Get(ctx);
  cudaStream_t stream = runtime::getCurrentCUDAStream();

  const int64_t num_in = in_idx->shape[0];

  CHECK_GE(num_parts, 1) << "The number of partitions (" << num_parts
                         << ") must be at least 1.";
  if (num_parts == 1) {
    // no permutation
    result.first = aten::Range(0, num_in, sizeof(IdType) * 8, ctx);
    result.second = aten::Full(num_in, num_parts, sizeof(int64_t) * 8, ctx);

    return result;
  }

  result.first = aten::NewIdArray(num_in, ctx, sizeof(IdType) * 8);
  result.second = aten::Full(0, num_parts, sizeof(int64_t) * 8, ctx);
  int64_t* out_counts = static_cast<int64_t*>(result.second->data);
  if (num_in == 0) {
    // now that we've zero'd out_counts, nothing left to do for an empty
    // mapping
    return result;
  }

  const int64_t part_bits =
      static_cast<int64_t>(std::ceil(std::log2(num_parts)));
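  // e.g., num_parts = 6 gives part_bits = 3, so the radix sort below only
  // needs to examine the 3 least-significant bits of each partition ID.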

  // First, generate a mapping of indexes to processors
  Workspace<IdType> proc_id_in(device, ctx, num_in);
  {
    const dim3 block(256);
    const dim3 grid((num_in + block.x - 1) / block.x);

    if (num_parts < (1 << part_bits)) {
      // num_parts is not a power of 2
      CUDA_KERNEL_CALL(
          _MapProcByRemainderKernel, grid, block, 0, stream,
          static_cast<const IdType*>(in_idx->data), num_in, num_parts,
          proc_id_in.get());
    } else {
      // num_parts is a power of 2
      CUDA_KERNEL_CALL(
          _MapProcByMaskRemainderKernel, grid, block, 0, stream,
          static_cast<const IdType*>(in_idx->data), num_in,
          static_cast<IdType>(num_parts - 1),  // bit mask
          proc_id_in.get());
    }
  }

  // then create a permutation array that groups processors together by
  // performing a radix sort
  Workspace<IdType> proc_id_out(device, ctx, num_in);
  IdType* perm_out = static_cast<IdType*>(result.first->data);
  {
    IdArray perm_in = aten::Range(0, num_in, sizeof(IdType) * 8, ctx);

    size_t sort_workspace_size;
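    // The first SortPairs call passes a null workspace pointer, so cub only
    // reports the required temporary storage size in sort_workspace_size.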
    CUDA_CALL(cub::DeviceRadixSort::SortPairs(
        nullptr, sort_workspace_size, proc_id_in.get(), proc_id_out.get(),
        static_cast<IdType*>(perm_in->data), perm_out, num_in, 0, part_bits,
        stream));

    Workspace<void> sort_workspace(device, ctx, sort_workspace_size);
    CUDA_CALL(cub::DeviceRadixSort::SortPairs(
        sort_workspace.get(), sort_workspace_size, proc_id_in.get(),
        proc_id_out.get(), static_cast<IdType*>(perm_in->data), perm_out,
        num_in, 0, part_bits, stream));
  }
  // explicitly free so workspace can be re-used
  proc_id_in.free();

  // perform a histogram and then prefixsum on the sorted proc_id vector

  // Count the number of values to be sent to each processor
  {
    using AtomicCount = unsigned long long;  // NOLINT
    static_assert(
        sizeof(AtomicCount) == sizeof(*out_counts),
        "AtomicCount must be the same width as int64_t for atomicAdd "
        "in cub::DeviceHistogram::HistogramEven() to work");

    // TODO(dlasalle): Once https://github.com/NVIDIA/cub/pull/287 is merged,
    // add a compile time check against the cub version to allow
    // num_in > (2 << 31).
    CHECK(num_in < static_cast<int64_t>(std::numeric_limits<int>::max()))
        << "number of values to insert into histogram must be less than max "
           "value of int.";

    size_t hist_workspace_size;
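    // num_parts + 1 histogram levels over [0, num_parts) give num_parts
    // unit-width bins, so bin i counts the elements assigned to partition i.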
    CUDA_CALL(cub::DeviceHistogram::HistogramEven(
        nullptr, hist_workspace_size, proc_id_out.get(),
        reinterpret_cast<AtomicCount*>(out_counts), num_parts + 1,
        static_cast<IdType>(0), static_cast<IdType>(num_parts),
        static_cast<int>(num_in), stream));

    Workspace<void> hist_workspace(device, ctx, hist_workspace_size);
    CUDA_CALL(cub::DeviceHistogram::HistogramEven(
        hist_workspace.get(), hist_workspace_size, proc_id_out.get(),
        reinterpret_cast<AtomicCount*>(out_counts), num_parts + 1,
        static_cast<IdType>(0), static_cast<IdType>(num_parts),
        static_cast<int>(num_in), stream));
  }

  return result;
}

template std::pair<IdArray, IdArray> GeneratePermutationFromRemainder<
    kDGLCUDA, int32_t>(int64_t array_size, int num_parts, IdArray in_idx);
template std::pair<IdArray, IdArray> GeneratePermutationFromRemainder<
    kDGLCUDA, int64_t>(int64_t array_size, int num_parts, IdArray in_idx);

template <DGLDeviceType XPU, typename IdType>
IdArray MapToLocalFromRemainder(const int num_parts, IdArray global_idx) {
  const auto& ctx = global_idx->ctx;
  cudaStream_t stream = runtime::getCurrentCUDAStream();

  if (num_parts > 1) {
    IdArray local_idx =
        aten::NewIdArray(global_idx->shape[0], ctx, sizeof(IdType) * 8);

    const dim3 block(128);
    const dim3 grid((global_idx->shape[0] + block.x - 1) / block.x);

    CUDA_KERNEL_CALL(
        _MapLocalIndexByRemainderKernel, grid, block, 0, stream,
        static_cast<const IdType*>(global_idx->data), global_idx->shape[0],
        num_parts, static_cast<IdType*>(local_idx->data));

    return local_idx;
  } else {
    // no mapping to be done
    return global_idx;
  }
}

template IdArray MapToLocalFromRemainder<kDGLCUDA, int32_t>(
    int num_parts, IdArray in_idx);
template IdArray MapToLocalFromRemainder<kDGLCUDA, int64_t>(
    int num_parts, IdArray in_idx);

template <DGLDeviceType XPU, typename IdType>
IdArray MapToGlobalFromRemainder(
    const int num_parts, IdArray local_idx, const int part_id) {
  CHECK_LT(part_id, num_parts)
      << "Invalid partition id " << part_id << "/" << num_parts;
  CHECK_GE(part_id, 0) << "Invalid partition id " << part_id << "/"
                       << num_parts;

  const auto& ctx = local_idx->ctx;
  cudaStream_t stream = runtime::getCurrentCUDAStream();

  if (num_parts > 1) {
    IdArray global_idx =
        aten::NewIdArray(local_idx->shape[0], ctx, sizeof(IdType) * 8);

    const dim3 block(128);
    const dim3 grid((local_idx->shape[0] + block.x - 1) / block.x);

    CUDA_KERNEL_CALL(
        _MapGlobalIndexByRemainderKernel, grid, block, 0, stream,
        static_cast<const IdType*>(local_idx->data), part_id,
        global_idx->shape[0], num_parts,
        static_cast<IdType*>(global_idx->data));

    return global_idx;
  } else {
    // no mapping to be done
    return local_idx;
  }
}

template IdArray MapToGlobalFromRemainder<kDGLCUDA, int32_t>(
    int num_parts, IdArray in_idx, int part_id);
template IdArray MapToGlobalFromRemainder<kDGLCUDA, int64_t>(
    int num_parts, IdArray in_idx, int part_id);

// Range Based Partition Operations
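//
// For example, with range = [0, 4, 9, 15] and num_parts = 3, global IDs 0-3
// belong to partition 0, 4-8 to partition 1, and 9-14 to partition 2; global
// ID 6 maps to local ID 6 - range[1] = 2 within partition 1, and that local ID
// maps back to global ID 2 + range[1] = 6.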

template <DGLDeviceType XPU, typename IdType, typename RangeType>
std::pair<IdArray, NDArray> GeneratePermutationFromRange(
    int64_t array_size, int num_parts, IdArray range, IdArray in_idx) {
  std::pair<IdArray, NDArray> result;

  const auto& ctx = in_idx->ctx;
  auto device = DeviceAPI::Get(ctx);
  cudaStream_t stream = runtime::getCurrentCUDAStream();

  const int64_t num_in = in_idx->shape[0];

  CHECK_GE(num_parts, 1) << "The number of partitions (" << num_parts
                         << ") must be at least 1.";
  if (num_parts == 1) {
    // no permutation
    result.first = aten::Range(0, num_in, sizeof(IdType) * 8, ctx);
    result.second = aten::Full(num_in, num_parts, sizeof(int64_t) * 8, ctx);

    return result;
  }

  result.first = aten::NewIdArray(num_in, ctx, sizeof(IdType) * 8);
  result.second = aten::Full(0, num_parts, sizeof(int64_t) * 8, ctx);
  int64_t* out_counts = static_cast<int64_t*>(result.second->data);
  if (num_in == 0) {
    // now that we've zero'd out_counts, nothing left to do for an empty
    // mapping
    return result;
  }

  const int64_t part_bits =
      static_cast<int64_t>(std::ceil(std::log2(num_parts)));

  // First, generate a mapping of indexes to processors
  Workspace<IdType> proc_id_in(device, ctx, num_in);
  {
    const dim3 block(256);
    const dim3 grid((num_in + block.x - 1) / block.x);

    CUDA_KERNEL_CALL(
        _MapProcByRangeKernel, grid, block, 0, stream,
        static_cast<const RangeType*>(range->data),
        static_cast<const IdType*>(in_idx->data), num_in, num_parts,
        proc_id_in.get());
  }

  // then create a permutation array that groups processors together by
  // performing a radix sort
  Workspace<IdType> proc_id_out(device, ctx, num_in);
  IdType* perm_out = static_cast<IdType*>(result.first->data);
  {
    IdArray perm_in = aten::Range(0, num_in, sizeof(IdType) * 8, ctx);

    size_t sort_workspace_size;
    CUDA_CALL(cub::DeviceRadixSort::SortPairs(
        nullptr, sort_workspace_size, proc_id_in.get(), proc_id_out.get(),
        static_cast<IdType*>(perm_in->data), perm_out, num_in, 0, part_bits,
        stream));

    Workspace<void> sort_workspace(device, ctx, sort_workspace_size);
    CUDA_CALL(cub::DeviceRadixSort::SortPairs(
        sort_workspace.get(), sort_workspace_size, proc_id_in.get(),
        proc_id_out.get(), static_cast<IdType*>(perm_in->data), perm_out,
        num_in, 0, part_bits, stream));
  }
  // explicitly free so workspace can be re-used
  proc_id_in.free();

  // perform a histogram and then prefixsum on the sorted proc_id vector

  // Count the number of values to be sent to each processor
  {
    using AtomicCount = unsigned long long;  // NOLINT
    static_assert(
        sizeof(AtomicCount) == sizeof(*out_counts),
        "AtomicCount must be the same width as int64_t for atomicAdd "
        "in cub::DeviceHistogram::HistogramEven() to work");

    // TODO(dlasalle): Once https://github.com/NVIDIA/cub/pull/287 is merged,
    // add a compile time check against the cub version to allow
    // num_in > (2 << 31).
    CHECK(num_in < static_cast<int64_t>(std::numeric_limits<int>::max()))
        << "number of values to insert into histogram must be less than max "
           "value of int.";

    size_t hist_workspace_size;
    CUDA_CALL(cub::DeviceHistogram::HistogramEven(
        nullptr, hist_workspace_size, proc_id_out.get(),
        reinterpret_cast<AtomicCount*>(out_counts), num_parts + 1,
        static_cast<IdType>(0), static_cast<IdType>(num_parts),
        static_cast<int>(num_in), stream));

    Workspace<void> hist_workspace(device, ctx, hist_workspace_size);
    CUDA_CALL(cub::DeviceHistogram::HistogramEven(
        hist_workspace.get(), hist_workspace_size, proc_id_out.get(),
        reinterpret_cast<AtomicCount*>(out_counts), num_parts + 1,
        static_cast<IdType>(0), static_cast<IdType>(num_parts),
        static_cast<int>(num_in), stream));
  }

  return result;
}

template std::pair<IdArray, IdArray>
GeneratePermutationFromRange<kDGLCUDA, int32_t, int32_t>(
    int64_t array_size, int num_parts, IdArray range, IdArray in_idx);
template std::pair<IdArray, IdArray>
GeneratePermutationFromRange<kDGLCUDA, int64_t, int32_t>(
    int64_t array_size, int num_parts, IdArray range, IdArray in_idx);
template std::pair<IdArray, IdArray>
GeneratePermutationFromRange<kDGLCUDA, int32_t, int64_t>(
    int64_t array_size, int num_parts, IdArray range, IdArray in_idx);
template std::pair<IdArray, IdArray>
GeneratePermutationFromRange<kDGLCUDA, int64_t, int64_t>(
    int64_t array_size, int num_parts, IdArray range, IdArray in_idx);

template <DGLDeviceType XPU, typename IdType, typename RangeType>
IdArray MapToLocalFromRange(
    const int num_parts, IdArray range, IdArray global_idx) {
  const auto& ctx = global_idx->ctx;
  cudaStream_t stream = runtime::getCurrentCUDAStream();

  if (num_parts > 1 && global_idx->shape[0] > 0) {
    IdArray local_idx =
        aten::NewIdArray(global_idx->shape[0], ctx, sizeof(IdType) * 8);

    const dim3 block(128);
    const dim3 grid((global_idx->shape[0] + block.x - 1) / block.x);

    CUDA_KERNEL_CALL(
        _MapLocalIndexByRangeKernel, grid, block, 0, stream,
        static_cast<const RangeType*>(range->data),
        static_cast<const IdType*>(global_idx->data), global_idx->shape[0],
        num_parts, static_cast<IdType*>(local_idx->data));

    return local_idx;
  } else {
    // no mapping to be done
    return global_idx;
  }
}

template IdArray MapToLocalFromRange<kDGLCUDA, int32_t, int32_t>(
    int num_parts, IdArray range, IdArray in_idx);
template IdArray MapToLocalFromRange<kDGLCUDA, int64_t, int32_t>(
    int num_parts, IdArray range, IdArray in_idx);
template IdArray MapToLocalFromRange<kDGLCUDA, int32_t, int64_t>(
    int num_parts, IdArray range, IdArray in_idx);
template IdArray MapToLocalFromRange<kDGLCUDA, int64_t, int64_t>(
    int num_parts, IdArray range, IdArray in_idx);

template <DGLDeviceType XPU, typename IdType, typename RangeType>
IdArray MapToGlobalFromRange(
    const int num_parts, IdArray range, IdArray local_idx, const int part_id) {
  CHECK_LT(part_id, num_parts)
      << "Invalid partition id " << part_id << "/" << num_parts;
  CHECK_GE(part_id, 0) << "Invalid partition id " << part_id << "/"
                       << num_parts;

  const auto& ctx = local_idx->ctx;
  cudaStream_t stream = runtime::getCurrentCUDAStream();

  if (num_parts > 1 && local_idx->shape[0] > 0) {
    IdArray global_idx =
        aten::NewIdArray(local_idx->shape[0], ctx, sizeof(IdType) * 8);

    const dim3 block(128);
    const dim3 grid((local_idx->shape[0] + block.x - 1) / block.x);

    CUDA_KERNEL_CALL(
        _MapGlobalIndexByRangeKernel, grid, block, 0, stream,
        static_cast<const RangeType*>(range->data),
        static_cast<const IdType*>(local_idx->data), part_id,
        global_idx->shape[0], num_parts,
        static_cast<IdType*>(global_idx->data));

    return global_idx;
  } else {
    // no mapping to be done
    return local_idx;
  }
}

template IdArray MapToGlobalFromRange<kDGLCUDA, int32_t, int32_t>(
    int num_parts, IdArray range, IdArray in_idx, int part_id);
template IdArray MapToGlobalFromRange<kDGLCUDA, int64_t, int32_t>(
    int num_parts, IdArray range, IdArray in_idx, int part_id);
template IdArray MapToGlobalFromRange<kDGLCUDA, int32_t, int64_t>(
    int num_parts, IdArray range, IdArray in_idx, int part_id);
template IdArray MapToGlobalFromRange<kDGLCUDA, int64_t, int64_t>(
    int num_parts, IdArray range, IdArray in_idx, int part_id);

}  // namespace impl
}  // namespace partition
}  // namespace dgl