/*!
 *  Copyright (c) 2021 by Contributors
 * \file partition_op.cu
 * \brief Operations on partition implemented in CUDA.
 */

#include <dgl/runtime/device_api.h>

#include "../../array/cuda/dgl_cub.cuh"
#include "../../runtime/cuda/cuda_common.h"
#include "../../runtime/workspace.h"
#include "../partition_op.h"

using namespace dgl::runtime;

namespace dgl {
namespace partition {
namespace impl {

namespace {

/**
 * @brief Kernel to map global element IDs to partition IDs by remainder.
 *
 * @tparam IdType The type of ID.
 * @param global The global element IDs.
 * @param num_elements The number of element IDs.
 * @param num_parts The number of partitions.
 * @param part_id The mapped partition ID (output).
 */
template <typename IdType>
__global__ void _MapProcByRemainderKernel(
    const IdType* const global, const int64_t num_elements,
    const int64_t num_parts, IdType* const part_id) {
  assert(num_elements <= gridDim.x * blockDim.x);
  const int64_t idx =
      blockDim.x * static_cast<int64_t>(blockIdx.x) + threadIdx.x;

  if (idx < num_elements) {
    part_id[idx] = global[idx] % num_parts;
  }
}

/**
 * @brief Kernel to map global element IDs to partition IDs, using a bit-mask.
 * The number of partitions must be a power of two.
 *
 * @tparam IdType The type of ID.
 * @param global The global element IDs.
 * @param num_elements The number of element IDs.
 * @param mask The bit-mask with 1's for each bit to keep from the element ID to
 * extract the partition ID (e.g., an 8 partition mask would be 0x07).
 * @param part_id The mapped partition ID (output).
 */
template <typename IdType>
__global__ void _MapProcByMaskRemainderKernel(
    const IdType* const global, const int64_t num_elements, const IdType mask,
    IdType* const part_id) {
  assert(num_elements <= gridDim.x * blockDim.x);
  const int64_t idx =
      blockDim.x * static_cast<int64_t>(blockIdx.x) + threadIdx.x;

  if (idx < num_elements) {
    part_id[idx] = global[idx] & mask;
  }
}

/**
 * @brief Kernel to map global element IDs to local element IDs.
 *
 * @tparam IdType The type of ID.
 * @param global The global element IDs.
 * @param num_elements The number of IDs.
 * @param num_parts The number of partitions.
 * @param local The local element IDs (output).
 */
template <typename IdType>
__global__ void _MapLocalIndexByRemainderKernel(
    const IdType* const global, const int64_t num_elements, const int num_parts,
    IdType* const local) {
  assert(num_elements <= gridDim.x * blockDim.x);
  const int64_t idx = threadIdx.x + blockDim.x * blockIdx.x;

  if (idx < num_elements) {
    local[idx] = global[idx] / num_parts;
  }
}

/**
 * @brief Kernel to map local element IDs within a partition to their global
 * IDs, using the remainder over the number of partitions.
 *
 * @tparam IdType The type of ID.
 * @param local The local element IDs.
 * @param part_id The partition to map local elements from.
 * @param num_elements The number of elements to map.
 * @param num_parts The number of partitions.
 * @param global The global element IDs (output).
 */
template <typename IdType>
__global__ void _MapGlobalIndexByRemainderKernel(
    const IdType* const local, const int part_id, const int64_t num_elements,
    const int num_parts, IdType* const global) {
  assert(num_elements <= gridDim.x * blockDim.x);
  const int64_t idx = threadIdx.x + blockDim.x * blockIdx.x;

  assert(part_id < num_parts);

  if (idx < num_elements) {
    global[idx] = (local[idx] * num_parts) + part_id;
  }
}

/**
 * @brief Device function to perform a binary search to find to which partition
 * a given ID belongs.
 *
 * @tparam RangeType The type of range.
 * @param range The prefix-sum of IDs assigned to partitions.
 * @param num_parts The number of partitions.
 * @param target The element ID to find the partition of.
 *
 * @return The partition.
 */
template <typename RangeType>
__device__ RangeType _SearchRange(
    const RangeType* const range, const int num_parts, const RangeType target) {
  int start = 0;
  int end = num_parts;
  int cur = (end + start) / 2;

  assert(range[0] == 0);
  assert(target < range[num_parts]);
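  // `range` is a prefix-sum of length num_parts + 1, so partition p owns the
  // IDs in [range[p], range[p + 1]). The binary search keeps the invariant
  // range[start] <= target < range[end]; for illustration, with
  // range = {0, 4, 9, 12} and target = 10 it returns partition 2.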

  while (start + 1 < end) {
    if (target < range[cur]) {
      end = cur;
    } else {
      start = cur;
    }
    cur = (start + end) / 2;
  }

  return cur;
}

/**
 * @brief Kernel to map element IDs to partition IDs.
 *
 * @tparam IdType The type of element ID.
 * @tparam RangeType The type of the range.
 * @param range The prefix-sum of IDs assigned to partitions.
 * @param global The global element IDs.
 * @param num_elements The number of element IDs.
 * @param num_parts The number of partitions.
 * @param part_id The partition ID assigned to each element (output).
 */
template <typename IdType, typename RangeType>
__global__ void _MapProcByRangeKernel(
    const RangeType* const range, const IdType* const global,
    const int64_t num_elements, const int64_t num_parts,
    IdType* const part_id) {
  assert(num_elements <= gridDim.x * blockDim.x);
  const int64_t idx =
      blockDim.x * static_cast<int64_t>(blockIdx.x) + threadIdx.x;

  // rely on caching to load the range into L1 cache
  if (idx < num_elements) {
    part_id[idx] = static_cast<IdType>(_SearchRange(
        range, static_cast<int>(num_parts),
        static_cast<RangeType>(global[idx])));
  }
}

/**
 * @brief Kernel to map global element IDs to their ID within their respective
 * partition.
 *
 * @tparam IdType The type of element ID.
 * @tparam RangeType The type of the range.
 * @param range The prefix-sum of IDs assigned to partitions.
 * @param global The global element IDs.
 * @param num_elements The number of elements.
 * @param num_parts The number of partitions.
 * @param local The local element IDs (output).
 */
template <typename IdType, typename RangeType>
__global__ void _MapLocalIndexByRangeKernel(
    const RangeType* const range, const IdType* const global,
    const int64_t num_elements, const int num_parts, IdType* const local) {
  assert(num_elements <= gridDim.x * blockDim.x);
  const int64_t idx = threadIdx.x + blockDim.x * blockIdx.x;

  // rely on caching to load the range into L1 cache
  if (idx < num_elements) {
    const int proc = _SearchRange(
        range, static_cast<int>(num_parts),
        static_cast<RangeType>(global[idx]));
    local[idx] = global[idx] - range[proc];
  }
}

/**
 * @brief Kernel to map local element IDs within a partition to their global
 * IDs.
 *
 * @tparam IdType The type of ID.
 * @tparam RangeType The type of the range.
 * @param range The prefix-sum of IDs assigned to partitions.
 * @param local The local element IDs.
 * @param part_id The partition to map local elements from.
 * @param num_elements The number of elements to map.
 * @param num_parts The number of partitions.
 * @param global The global element IDs (output).
 */
template <typename IdType, typename RangeType>
__global__ void _MapGlobalIndexByRangeKernel(
    const RangeType* const range, const IdType* const local, const int part_id,
    const int64_t num_elements, const int num_parts, IdType* const global) {
  assert(num_elements <= gridDim.x * blockDim.x);
  const int64_t idx = threadIdx.x + blockDim.x * blockIdx.x;

  assert(part_id < num_parts);

  // rely on caching to load the range into L1 cache
  if (idx < num_elements) {
    global[idx] = local[idx] + range[part_id];
  }
}
}  // namespace

// Remainder Based Partition Operations

template <DGLDeviceType XPU, typename IdType>
std::pair<IdArray, NDArray> GeneratePermutationFromRemainder(
    int64_t array_size, int num_parts, IdArray in_idx) {
  std::pair<IdArray, NDArray> result;

  const auto& ctx = in_idx->ctx;
  auto device = DeviceAPI::Get(ctx);
  cudaStream_t stream = runtime::getCurrentCUDAStream();

  const int64_t num_in = in_idx->shape[0];

  CHECK_GE(num_parts, 1) << "The number of partitions (" << num_parts
                         << ") must be at least 1.";
  if (num_parts == 1) {
    // no permutation
    result.first = aten::Range(0, num_in, sizeof(IdType) * 8, ctx);
    result.second = aten::Full(num_in, num_parts, sizeof(int64_t) * 8, ctx);

    return result;
  }

  result.first = aten::NewIdArray(num_in, ctx, sizeof(IdType) * 8);
  result.second = aten::Full(0, num_parts, sizeof(int64_t) * 8, ctx);
  int64_t* out_counts = static_cast<int64_t*>(result.second->data);
  if (num_in == 0) {
    // now that we've zero'd out_counts, nothing left to do for an empty
    // mapping
    return result;
  }

  const int64_t part_bits =
      static_cast<int64_t>(std::ceil(std::log2(num_parts)));
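  // part_bits is the number of significant bits in a partition ID: num_parts
  // is a power of two exactly when num_parts == (1 << part_bits), which
  // selects the mask-based kernel below, and the radix sort only needs to
  // order keys on bits [0, part_bits).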

  // First, generate a mapping of indexes to processors
  Workspace<IdType> proc_id_in(device, ctx, num_in);
  {
    const dim3 block(256);
    const dim3 grid((num_in + block.x - 1) / block.x);

    if (num_parts < (1 << part_bits)) {
      // num_parts is not a power of 2
      CUDA_KERNEL_CALL(
          _MapProcByRemainderKernel, grid, block, 0, stream,
          static_cast<const IdType*>(in_idx->data), num_in, num_parts,
          proc_id_in.get());
    } else {
      // num_parts is a power of 2
      CUDA_KERNEL_CALL(
          _MapProcByMaskRemainderKernel, grid, block, 0, stream,
          static_cast<const IdType*>(in_idx->data), num_in,
          static_cast<IdType>(num_parts - 1),  // bit mask
          proc_id_in.get());
    }
  }

  // then create a permutation array that groups processors together by
  // performing a radix sort
  Workspace<IdType> proc_id_out(device, ctx, num_in);
  IdType* perm_out = static_cast<IdType*>(result.first->data);
  {
    IdArray perm_in = aten::Range(0, num_in, sizeof(IdType) * 8, ctx);

    size_t sort_workspace_size;
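    // Standard two-phase cub pattern: the first SortPairs call (null
    // workspace) only reports the temporary storage required; the second call
    // performs the actual (proc_id, index) key-value sort on the low
    // part_bits bits.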
    CUDA_CALL(cub::DeviceRadixSort::SortPairs(
        nullptr, sort_workspace_size, proc_id_in.get(), proc_id_out.get(),
        static_cast<IdType*>(perm_in->data), perm_out, num_in, 0, part_bits,
        stream));

    Workspace<void> sort_workspace(device, ctx, sort_workspace_size);
    CUDA_CALL(cub::DeviceRadixSort::SortPairs(
        sort_workspace.get(), sort_workspace_size, proc_id_in.get(),
        proc_id_out.get(), static_cast<IdType*>(perm_in->data), perm_out,
        num_in, 0, part_bits, stream));
  }
  // explicitly free so workspace can be re-used
  proc_id_in.free();

  // perform a histogram and then prefixsum on the sorted proc_id vector

  // Count the number of values to be sent to each processor
  {
    using AtomicCount = unsigned long long;  // NOLINT
    static_assert(
        sizeof(AtomicCount) == sizeof(*out_counts),
        "AtomicCount must be the same width as int64_t for atomicAdd "
        "in cub::DeviceHistogram::HistogramEven() to work");

    // TODO(dlasalle): Once https://github.com/NVIDIA/cub/pull/287 is merged,
    // add a compile time check against the cub version to allow
    // num_in > (2 << 31).
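
    // Same two-phase cub pattern as the sort above: the first HistogramEven
    // call only sizes the workspace, the second fills out_counts. The
    // num_parts + 1 levels define num_parts evenly spaced bins, intended to
    // count one partition each.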
    CHECK(num_in < static_cast<int64_t>(std::numeric_limits<int>::max()))
        << "number of values to insert into histogram must be less than max "
           "value of int.";

    size_t hist_workspace_size;
    CUDA_CALL(cub::DeviceHistogram::HistogramEven(
        nullptr, hist_workspace_size, proc_id_out.get(),
        reinterpret_cast<AtomicCount*>(out_counts), num_parts + 1,
        static_cast<IdType>(0), static_cast<IdType>(num_parts + 1),
        static_cast<int>(num_in), stream));

    Workspace<void> hist_workspace(device, ctx, hist_workspace_size);
    CUDA_CALL(cub::DeviceHistogram::HistogramEven(
        hist_workspace.get(), hist_workspace_size, proc_id_out.get(),
        reinterpret_cast<AtomicCount*>(out_counts), num_parts + 1,
        static_cast<IdType>(0), static_cast<IdType>(num_parts + 1),
        static_cast<int>(num_in), stream));
  }

  return result;
}
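
// For illustration: with num_parts = 2 and in_idx = [5, 2, 4, 7], the
// remainder policy assigns partitions [1, 0, 0, 1], so the returned
// permutation is [1, 2, 0, 3] (partition 0 first, then partition 1) and the
// per-partition counts are [2, 2].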

template std::pair<IdArray, IdArray> GeneratePermutationFromRemainder<
    kDGLCUDA, int32_t>(int64_t array_size, int num_parts, IdArray in_idx);
template std::pair<IdArray, IdArray> GeneratePermutationFromRemainder<
    kDGLCUDA, int64_t>(int64_t array_size, int num_parts, IdArray in_idx);

template <DGLDeviceType XPU, typename IdType>
IdArray MapToLocalFromRemainder(const int num_parts, IdArray global_idx) {
  const auto& ctx = global_idx->ctx;
  cudaStream_t stream = runtime::getCurrentCUDAStream();

  if (num_parts > 1) {
    IdArray local_idx =
        aten::NewIdArray(global_idx->shape[0], ctx, sizeof(IdType) * 8);

    const dim3 block(128);
    const dim3 grid((global_idx->shape[0] + block.x - 1) / block.x);

    CUDA_KERNEL_CALL(
        _MapLocalIndexByRemainderKernel, grid, block, 0, stream,
        static_cast<const IdType*>(global_idx->data), global_idx->shape[0],
        num_parts, static_cast<IdType*>(local_idx->data));

    return local_idx;
  } else {
    // no mapping to be done
    return global_idx;
  }
}
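
// For illustration: with num_parts = 4, global IDs [9, 14, 7] belong to
// partitions [1, 2, 3] and map to local IDs [2, 3, 1] (local = global / 4).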

template IdArray MapToLocalFromRemainder<kDGLCUDA, int32_t>(
    int num_parts, IdArray in_idx);
template IdArray MapToLocalFromRemainder<kDGLCUDA, int64_t>(
    int num_parts, IdArray in_idx);

template <DGLDeviceType XPU, typename IdType>
IdArray MapToGlobalFromRemainder(
    const int num_parts, IdArray local_idx, const int part_id) {
  CHECK_LT(part_id, num_parts)
      << "Invalid partition id " << part_id << "/" << num_parts;
  CHECK_GE(part_id, 0) << "Invalid partition id " << part_id << "/"
                       << num_parts;

  const auto& ctx = local_idx->ctx;
  cudaStream_t stream = runtime::getCurrentCUDAStream();

  if (num_parts > 1) {
    IdArray global_idx =
        aten::NewIdArray(local_idx->shape[0], ctx, sizeof(IdType) * 8);

    const dim3 block(128);
    const dim3 grid((local_idx->shape[0] + block.x - 1) / block.x);

    CUDA_KERNEL_CALL(
        _MapGlobalIndexByRemainderKernel, grid, block, 0, stream,
        static_cast<const IdType*>(local_idx->data), part_id,
        global_idx->shape[0], num_parts,
        static_cast<IdType*>(global_idx->data));

    return global_idx;
  } else {
    // no mapping to be done
    return local_idx;
  }
}
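
// For illustration: with num_parts = 4 and part_id = 1, local IDs [2, 0] map
// back to global IDs [9, 1] (global = local * num_parts + part_id).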

template IdArray MapToGlobalFromRemainder<kDGLCUDA, int32_t>(
    int num_parts, IdArray in_idx, int part_id);
template IdArray MapToGlobalFromRemainder<kDGLCUDA, int64_t>(
    int num_parts, IdArray in_idx, int part_id);

// Range Based Partition Operations

template <DGLDeviceType XPU, typename IdType, typename RangeType>
std::pair<IdArray, NDArray> GeneratePermutationFromRange(
    int64_t array_size, int num_parts, IdArray range, IdArray in_idx) {
  std::pair<IdArray, NDArray> result;

  const auto& ctx = in_idx->ctx;
  auto device = DeviceAPI::Get(ctx);
  cudaStream_t stream = runtime::getCurrentCUDAStream();

  const int64_t num_in = in_idx->shape[0];

  CHECK_GE(num_parts, 1) << "The number of partitions (" << num_parts
                         << ") must be at least 1.";
  if (num_parts == 1) {
    // no permutation
    result.first = aten::Range(0, num_in, sizeof(IdType) * 8, ctx);
    result.second = aten::Full(num_in, num_parts, sizeof(int64_t) * 8, ctx);

    return result;
  }

  result.first = aten::NewIdArray(num_in, ctx, sizeof(IdType) * 8);
  result.second = aten::Full(0, num_parts, sizeof(int64_t) * 8, ctx);
  int64_t* out_counts = static_cast<int64_t*>(result.second->data);
  if (num_in == 0) {
    // now that we've zero'd out_counts, nothing left to do for an empty
    // mapping
    return result;
  }

  const int64_t part_bits =
      static_cast<int64_t>(std::ceil(std::log2(num_parts)));

  // First, generate a mapping of indexes to processors
  Workspace<IdType> proc_id_in(device, ctx, num_in);
  {
    const dim3 block(256);
    const dim3 grid((num_in + block.x - 1) / block.x);

    CUDA_KERNEL_CALL(
        _MapProcByRangeKernel, grid, block, 0, stream,
        static_cast<const RangeType*>(range->data),
        static_cast<const IdType*>(in_idx->data), num_in, num_parts,
        proc_id_in.get());
  }

  // then create a permutation array that groups processors together by
  // performing a radix sort
  Workspace<IdType> proc_id_out(device, ctx, num_in);
  IdType* perm_out = static_cast<IdType*>(result.first->data);
  {
    IdArray perm_in = aten::Range(0, num_in, sizeof(IdType) * 8, ctx);

    size_t sort_workspace_size;
    CUDA_CALL(cub::DeviceRadixSort::SortPairs(
        nullptr, sort_workspace_size, proc_id_in.get(), proc_id_out.get(),
        static_cast<IdType*>(perm_in->data), perm_out, num_in, 0, part_bits,
        stream));

    Workspace<void> sort_workspace(device, ctx, sort_workspace_size);
    CUDA_CALL(cub::DeviceRadixSort::SortPairs(
        sort_workspace.get(), sort_workspace_size, proc_id_in.get(),
        proc_id_out.get(), static_cast<IdType*>(perm_in->data), perm_out,
        num_in, 0, part_bits, stream));
  }
  // explicitly free so workspace can be re-used
  proc_id_in.free();

  // perform a histogram and then prefixsum on the sorted proc_id vector

  // Count the number of values to be sent to each processor
  {
    using AtomicCount = unsigned long long;  // NOLINT
    static_assert(
        sizeof(AtomicCount) == sizeof(*out_counts),
        "AtomicCount must be the same width as int64_t for atomicAdd "
        "in cub::DeviceHistogram::HistogramEven() to work");

    // TODO(dlasalle): Once https://github.com/NVIDIA/cub/pull/287 is merged,
    // add a compile time check against the cub version to allow
    // num_in > (2 << 31).
    CHECK(num_in < static_cast<int64_t>(std::numeric_limits<int>::max()))
        << "number of values to insert into histogram must be less than max "
           "value of int.";

    size_t hist_workspace_size;
    CUDA_CALL(cub::DeviceHistogram::HistogramEven(
        nullptr, hist_workspace_size, proc_id_out.get(),
        reinterpret_cast<AtomicCount*>(out_counts), num_parts + 1,
        static_cast<IdType>(0), static_cast<IdType>(num_parts + 1),
        static_cast<int>(num_in), stream));

    Workspace<void> hist_workspace(device, ctx, hist_workspace_size);
    CUDA_CALL(cub::DeviceHistogram::HistogramEven(
        hist_workspace.get(), hist_workspace_size, proc_id_out.get(),
        reinterpret_cast<AtomicCount*>(out_counts), num_parts + 1,
        static_cast<IdType>(0), static_cast<IdType>(num_parts + 1),
        static_cast<int>(num_in), stream));
  }

  return result;
}
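
// For illustration: with three partitions, range = {0, 4, 9, 12}, and
// in_idx = [10, 2, 5], the range policy assigns partitions [2, 0, 1], so the
// returned permutation is [1, 2, 0] and the per-partition counts are
// [1, 1, 1].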

template std::pair<IdArray, IdArray>
GeneratePermutationFromRange<kDGLCUDA, int32_t, int32_t>(
    int64_t array_size, int num_parts, IdArray range, IdArray in_idx);
template std::pair<IdArray, IdArray>
GeneratePermutationFromRange<kDGLCUDA, int64_t, int32_t>(
    int64_t array_size, int num_parts, IdArray range, IdArray in_idx);
template std::pair<IdArray, IdArray>
GeneratePermutationFromRange<kDGLCUDA, int32_t, int64_t>(
    int64_t array_size, int num_parts, IdArray range, IdArray in_idx);
template std::pair<IdArray, IdArray>
GeneratePermutationFromRange<kDGLCUDA, int64_t, int64_t>(
    int64_t array_size, int num_parts, IdArray range, IdArray in_idx);

template <DGLDeviceType XPU, typename IdType, typename RangeType>
IdArray MapToLocalFromRange(
    const int num_parts, IdArray range, IdArray global_idx) {
  const auto& ctx = global_idx->ctx;
  cudaStream_t stream = runtime::getCurrentCUDAStream();

  if (num_parts > 1 && global_idx->shape[0] > 0) {
    IdArray local_idx =
        aten::NewIdArray(global_idx->shape[0], ctx, sizeof(IdType) * 8);

    const dim3 block(128);
    const dim3 grid((global_idx->shape[0] + block.x - 1) / block.x);

    CUDA_KERNEL_CALL(
        _MapLocalIndexByRangeKernel, grid, block, 0, stream,
        static_cast<const RangeType*>(range->data),
        static_cast<const IdType*>(global_idx->data), global_idx->shape[0],
        num_parts, static_cast<IdType*>(local_idx->data));

    return local_idx;
  } else {
    // no mapping to be done
    return global_idx;
  }
}
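
// For illustration: with range = {0, 4, 9, 12}, global ID 10 falls in
// partition 2 and maps to local ID 10 - range[2] = 1.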

template IdArray MapToLocalFromRange<kDGLCUDA, int32_t, int32_t>(
    int num_parts, IdArray range, IdArray in_idx);
template IdArray MapToLocalFromRange<kDGLCUDA, int64_t, int32_t>(
    int num_parts, IdArray range, IdArray in_idx);
template IdArray MapToLocalFromRange<kDGLCUDA, int32_t, int64_t>(
    int num_parts, IdArray range, IdArray in_idx);
template IdArray MapToLocalFromRange<kDGLCUDA, int64_t, int64_t>(
    int num_parts, IdArray range, IdArray in_idx);

template <DGLDeviceType XPU, typename IdType, typename RangeType>
IdArray MapToGlobalFromRange(
    const int num_parts, IdArray range, IdArray local_idx, const int part_id) {
  CHECK_LT(part_id, num_parts)
      << "Invalid partition id " << part_id << "/" << num_parts;
  CHECK_GE(part_id, 0) << "Invalid partition id " << part_id << "/"
                       << num_parts;

  const auto& ctx = local_idx->ctx;
  cudaStream_t stream = runtime::getCurrentCUDAStream();

  if (num_parts > 1 && local_idx->shape[0] > 0) {
    IdArray global_idx =
        aten::NewIdArray(local_idx->shape[0], ctx, sizeof(IdType) * 8);

    const dim3 block(128);
    const dim3 grid((local_idx->shape[0] + block.x - 1) / block.x);

    CUDA_KERNEL_CALL(
        _MapGlobalIndexByRangeKernel, grid, block, 0, stream,
        static_cast<const RangeType*>(range->data),
        static_cast<const IdType*>(local_idx->data), part_id,
        global_idx->shape[0], num_parts,
        static_cast<IdType*>(global_idx->data));

    return global_idx;
  } else {
    // no mapping to be done
    return local_idx;
  }
}
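
// For illustration: with range = {0, 4, 9, 12} and part_id = 2, local ID 1
// maps back to global ID range[2] + 1 = 10.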

template IdArray MapToGlobalFromRange<kDGLCUDA, int32_t, int32_t>(
    int num_parts, IdArray range, IdArray in_idx, int part_id);
template IdArray MapToGlobalFromRange<kDGLCUDA, int64_t, int32_t>(
    int num_parts, IdArray range, IdArray in_idx, int part_id);
template IdArray MapToGlobalFromRange<kDGLCUDA, int32_t, int64_t>(
    int num_parts, IdArray range, IdArray in_idx, int part_id);
template IdArray MapToGlobalFromRange<kDGLCUDA, int64_t, int64_t>(
    int num_parts, IdArray range, IdArray in_idx, int part_id);

}  // namespace impl
}  // namespace partition
}  // namespace dgl