// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
/**
 *  Copyright (c) 2021 by Contributors
 * @file partition_op.hip
 * @brief Operations on partitions implemented in CUDA.
 */

#include <dgl/runtime/device_api.h>

#include <hipcub/hipcub.hpp>

#include "../../runtime/cuda/cuda_common.h"
#include "../../runtime/workspace.h"
#include "../partition_op.h"

using namespace dgl::runtime;

namespace dgl {
namespace partition {
namespace impl {

namespace {

/**
 * @brief Kernel to map global element IDs to partition IDs by remainder.
 *
 * @tparam IdType The type of ID.
 * @param global The global element IDs.
 * @param num_elements The number of element IDs.
 * @param num_parts The number of partitions.
 * @param part_id The mapped partition ID (output).
 */
template <typename IdType>
__global__ void _MapProcByRemainderKernel(
    const IdType* const global, const int64_t num_elements,
    const int64_t num_parts, IdType* const part_id) {
  assert(num_elements <= gridDim.x * blockDim.x);
  const int64_t idx =
      blockDim.x * static_cast<int64_t>(blockIdx.x) + threadIdx.x;

  if (idx < num_elements) {
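    // e.g., with num_parts == 3, global IDs 0..5 map to partitions
    // {0, 1, 2, 0, 1, 2}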
    part_id[idx] = global[idx] % num_parts;
  }
}

/**
 * @brief Kernel to map global element IDs to partition IDs, using a bit-mask.
 * The number of partitions must be a power of two.
 *
 * @tparam IdType The type of ID.
 * @param global The global element IDs.
 * @param num_elements The number of element IDs.
 * @param mask The bit-mask with 1's for each bit to keep from the element ID to
 * extract the partition ID (e.g., an 8 partition mask would be 0x07).
 * @param part_id The mapped partition ID (output).
 */
template <typename IdType>
__global__ void _MapProcByMaskRemainderKernel(
    const IdType* const global, const int64_t num_elements, const IdType mask,
    IdType* const part_id) {
  assert(num_elements <= gridDim.x * blockDim.x);
  const int64_t idx =
      blockDim.x * static_cast<int64_t>(blockIdx.x) + threadIdx.x;

  if (idx < num_elements) {
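    // e.g., with 8 partitions the mask is 0x7, so global ID 13 maps to
    // partition 13 & 0x7 == 5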
    part_id[idx] = global[idx] & mask;
  }
}

/**
 * @brief Kernel to map global element IDs to local element IDs.
 *
 * @tparam IdType The type of ID.
 * @param global The global element IDs.
 * @param num_elements The number of IDs.
 * @param num_parts The number of partitions.
 * @param local The local element IDs (output).
 */
template <typename IdType>
__global__ void _MapLocalIndexByRemainderKernel(
    const IdType* const global, const int64_t num_elements, const int num_parts,
    IdType* const local) {
  assert(num_elements <= gridDim.x * blockDim.x);
  const int64_t idx = threadIdx.x + blockDim.x * blockIdx.x;

  if (idx < num_elements) {
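    // e.g., with 4 partitions, global ID 10 lives in partition 2 (10 % 4)
    // and its local ID is 10 / 4 == 2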
    local[idx] = global[idx] / num_parts;
  }
}

/**
 * @brief Kernel to map local element IDs within a partition to their global
 * IDs, using the remainder over the number of partitions.
 *
 * @tparam IdType The type of ID.
 * @param local The local element IDs.
 * @param part_id The partition to map local elements from.
 * @param num_elements The number of elements to map.
 * @param num_parts The number of partitions.
 * @param global The global element IDs (output).
 */
template <typename IdType>
__global__ void _MapGlobalIndexByRemainderKernel(
    const IdType* const local, const int part_id, const int64_t num_elements,
    const int num_parts, IdType* const global) {
  assert(num_elements <= gridDim.x * blockDim.x);
  const int64_t idx = threadIdx.x + blockDim.x * blockIdx.x;

  assert(part_id < num_parts);

  if (idx < num_elements) {
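    // inverse of the local mapping: e.g., with 4 partitions, local ID 2 in
    // partition 2 maps back to global ID 2 * 4 + 2 == 10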
    global[idx] = (local[idx] * num_parts) + part_id;
  }
}

/**
 * @brief Device function to perform a binary search to find to which partition
 * a given ID belongs.
 *
 * @tparam RangeType The type of range.
 * @param range The prefix-sum of IDs assigned to partitions.
 * @param num_parts The number of partitions.
 * @param target The element ID to find the partition of.
 *
 * @return The partition.
 */
template <typename RangeType>
__device__ RangeType _SearchRange(
    const RangeType* const range, const int num_parts, const RangeType target) {
  int start = 0;
  int end = num_parts;
  int cur = (end + start) / 2;

  assert(range[0] == 0);
  assert(target < range[num_parts]);
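  // e.g., with range = {0, 4, 9, 12} and num_parts == 3, target 10 yields
  // partition 2, since range[2] <= 10 < range[3]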

  while (start + 1 < end) {
    if (target < range[cur]) {
      end = cur;
    } else {
      start = cur;
    }
    cur = (start + end) / 2;
  }

  return cur;
}

/**
 * @brief Kernel to map element IDs to partition IDs.
 *
 * @tparam IdType The type of element ID.
 * @tparam RangeType The type of the range.
 * @param range The prefix-sum of IDs assigned to partitions.
 * @param global The global element IDs.
 * @param num_elements The number of element IDs.
 * @param num_parts The number of partitions.
 * @param part_id The partition ID assigned to each element (output).
 */
template <typename IdType, typename RangeType>
__global__ void _MapProcByRangeKernel(
    const RangeType* const range, const IdType* const global,
    const int64_t num_elements, const int64_t num_parts,
    IdType* const part_id) {
  assert(num_elements <= gridDim.x * blockDim.x);
  const int64_t idx =
      blockDim.x * static_cast<int64_t>(blockIdx.x) + threadIdx.x;

  // rely on caching to load the range into L1 cache
  if (idx < num_elements) {
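    // e.g., with range = {0, 4, 9, 12}, global ID 5 falls in [4, 9) and is
    // assigned partition 1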
    part_id[idx] = static_cast<IdType>(_SearchRange(
        range, static_cast<int>(num_parts),
        static_cast<RangeType>(global[idx])));
  }
}

/**
 * @brief Kernel to map global element IDs to their ID within their respective
 * partition.
 *
 * @tparam IdType The type of element ID.
 * @tparam RangeType The type of the range.
 * @param range The prefix-sum of IDs assigned to partitions.
 * @param global The global element IDs.
 * @param num_elements The number of elements.
 * @param num_parts The number of partitions.
 * @param local The local element IDs (output).
 */
template <typename IdType, typename RangeType>
__global__ void _MapLocalIndexByRangeKernel(
    const RangeType* const range, const IdType* const global,
    const int64_t num_elements, const int num_parts, IdType* const local) {
  assert(num_elements <= gridDim.x * blockDim.x);
  const int64_t idx = threadIdx.x + blockDim.x * blockIdx.x;

  // rely on caching to load the range into L1 cache
  if (idx < num_elements) {
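    // e.g., with range = {0, 4, 9, 12}, global ID 10 belongs to partition 2,
    // so its local ID is 10 - range[2] == 1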
    const int proc = _SearchRange(
        range, static_cast<int>(num_parts),
        static_cast<RangeType>(global[idx]));
    local[idx] = global[idx] - range[proc];
  }
}

/**
 * @brief Kernel to map local element IDs within a partition to their global
 * IDs.
 *
 * @tparam IdType The type of ID.
 * @tparam RangeType The type of the range.
 * @param range The prefix-sum of IDs assigned to partitions.
 * @param local The local element IDs.
 * @param part_id The partition to map local elements from.
 * @param num_elements The number of elements to map.
 * @param num_parts The number of partitions.
 * @param global The global element IDs (output).
 */
template <typename IdType, typename RangeType>
__global__ void _MapGlobalIndexByRangeKernel(
    const RangeType* const range, const IdType* const local, const int part_id,
    const int64_t num_elements, const int num_parts, IdType* const global) {
  assert(num_elements <= gridDim.x * blockDim.x);
  const int64_t idx = threadIdx.x + blockDim.x * blockIdx.x;

  assert(part_id < num_parts);

  // rely on caching to load the range into L1 cache
  if (idx < num_elements) {
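    // inverse of the local mapping: e.g., with range = {0, 4, 9, 12}, local
    // ID 1 in partition 2 maps back to global ID 1 + range[2] == 10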
    global[idx] = local[idx] + range[part_id];
  }
}
}  // namespace

// Remainder Based Partition Operations

template <DGLDeviceType XPU, typename IdType>
std::pair<IdArray, NDArray> GeneratePermutationFromRemainder(
    int64_t array_size, int num_parts, IdArray in_idx) {
  std::pair<IdArray, NDArray> result;

  const auto& ctx = in_idx->ctx;
  auto device = DeviceAPI::Get(ctx);
  hipStream_t stream = runtime::getCurrentHIPStreamMasqueradingAsCUDA();

  const int64_t num_in = in_idx->shape[0];

  CHECK_GE(num_parts, 1) << "The number of partitions (" << num_parts
                         << ") must be at least 1.";
  if (num_parts == 1) {
    // no permutation
    result.first = aten::Range(0, num_in, sizeof(IdType) * 8, ctx);
    result.second = aten::Full(num_in, num_parts, sizeof(int64_t) * 8, ctx);

    return result;
  }

  result.first = aten::NewIdArray(num_in, ctx, sizeof(IdType) * 8);
  result.second = aten::Full(0, num_parts, sizeof(int64_t) * 8, ctx);
  int64_t* out_counts = static_cast<int64_t*>(result.second->data);
  if (num_in == 0) {
    // now that we've zero'd out_counts, nothing left to do for an empty
    // mapping
    return result;
  }

  const int64_t part_bits =
      static_cast<int64_t>(::ceil(std::log2(num_parts)));
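  // part_bits is the number of low-order key bits the radix sort below must
  // inspect to order the partition IDs.
  //
  // Overall flow, e.g. for num_parts = 2 and in_idx = {5, 2, 4, 7}:
  //   proc_id_in  = {1, 0, 0, 1}                  (id % num_parts)
  //   radix sort  -> result.first  = {1, 2, 0, 3} (indices grouped by partition)
  //   histogram   -> result.second = {2, 2}       (elements per partition)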

  // First, generate a mapping of indexes to processors
  Workspace<IdType> proc_id_in(device, ctx, num_in);
  {
    const dim3 block(256);
    const dim3 grid((num_in + block.x - 1) / block.x);

    if (num_parts < (1 << part_bits)) {
      // num_parts is not a power of 2
      CUDA_KERNEL_CALL(
          _MapProcByRemainderKernel, grid, block, 0, stream,
          static_cast<const IdType*>(in_idx->data), num_in, num_parts,
          proc_id_in.get());
    } else {
      // num_parts is a power of 2
      CUDA_KERNEL_CALL(
          _MapProcByMaskRemainderKernel, grid, block, 0, stream,
          static_cast<const IdType*>(in_idx->data), num_in,
          static_cast<IdType>(num_parts - 1),  // bit mask
          proc_id_in.get());
    }
  }

  // then create a permutation array that groups processors together by
  // performing a radix sort
  Workspace<IdType> proc_id_out(device, ctx, num_in);
  IdType* perm_out = static_cast<IdType*>(result.first->data);
  {
    IdArray perm_in = aten::Range(0, num_in, sizeof(IdType) * 8, ctx);

    size_t sort_workspace_size;
    CUDA_CALL(hipcub::DeviceRadixSort::SortPairs(
        nullptr, sort_workspace_size, proc_id_in.get(), proc_id_out.get(),
        static_cast<IdType*>(perm_in->data), perm_out, num_in, 0, part_bits,
        stream));

    Workspace<void> sort_workspace(device, ctx, sort_workspace_size);
    CUDA_CALL(hipcub::DeviceRadixSort::SortPairs(
        sort_workspace.get(), sort_workspace_size, proc_id_in.get(),
        proc_id_out.get(), static_cast<IdType*>(perm_in->data), perm_out,
        num_in, 0, part_bits, stream));
  }
  // explicitly free so workspace can be re-used
  proc_id_in.free();

  // perform a histogram and then prefixsum on the sorted proc_id vector

  // Count the number of values to be sent to each processor
  {
    using AtomicCount = unsigned long long;  // NOLINT
    static_assert(
        sizeof(AtomicCount) == sizeof(*out_counts),
        "AtomicCount must be the same width as int64_t for atomicAdd "
        "in hipcub::DeviceHistogram::HistogramEven() to work");

    // TODO(dlasalle): Once https://github.com/NVIDIA/cub/pull/287 is merged,
    // add a compile time check against the cub version to allow
    // num_in > (2 << 31).
    CHECK(num_in < static_cast<int64_t>(std::numeric_limits<int>::max()))
        << "number of values to insert into histogram must be less than max "
           "value of int.";

    size_t hist_workspace_size;
    CUDA_CALL(hipcub::DeviceHistogram::HistogramEven(
        nullptr, hist_workspace_size, proc_id_out.get(),
        reinterpret_cast<AtomicCount*>(out_counts), num_parts + 1,
        static_cast<IdType>(0), static_cast<IdType>(num_parts),
        static_cast<int>(num_in), stream));

    Workspace<void> hist_workspace(device, ctx, hist_workspace_size);
    CUDA_CALL(hipcub::DeviceHistogram::HistogramEven(
        hist_workspace.get(), hist_workspace_size, proc_id_out.get(),
        reinterpret_cast<AtomicCount*>(out_counts), num_parts + 1,
        static_cast<IdType>(0), static_cast<IdType>(num_parts),
        static_cast<int>(num_in), stream));
  }

  return result;
}

template std::pair<IdArray, IdArray> GeneratePermutationFromRemainder<
    kDGLCUDA, int32_t>(int64_t array_size, int num_parts, IdArray in_idx);
template std::pair<IdArray, IdArray> GeneratePermutationFromRemainder<
    kDGLCUDA, int64_t>(int64_t array_size, int num_parts, IdArray in_idx);

template <DGLDeviceType XPU, typename IdType>
IdArray MapToLocalFromRemainder(const int num_parts, IdArray global_idx) {
  const auto& ctx = global_idx->ctx;
  hipStream_t stream = runtime::getCurrentHIPStreamMasqueradingAsCUDA();

  if (num_parts > 1) {
    IdArray local_idx =
        aten::NewIdArray(global_idx->shape[0], ctx, sizeof(IdType) * 8);

    const dim3 block(128);
    const dim3 grid((global_idx->shape[0] + block.x - 1) / block.x);

    CUDA_KERNEL_CALL(
        _MapLocalIndexByRemainderKernel, grid, block, 0, stream,
        static_cast<const IdType*>(global_idx->data), global_idx->shape[0],
        num_parts, static_cast<IdType*>(local_idx->data));

    return local_idx;
  } else {
    // no mapping to be done
    return global_idx;
  }
}

template IdArray MapToLocalFromRemainder<kDGLCUDA, int32_t>(
    int num_parts, IdArray in_idx);
template IdArray MapToLocalFromRemainder<kDGLCUDA, int64_t>(
    int num_parts, IdArray in_idx);

template <DGLDeviceType XPU, typename IdType>
IdArray MapToGlobalFromRemainder(
    const int num_parts, IdArray local_idx, const int part_id) {
  CHECK_LT(part_id, num_parts)
      << "Invalid partition id " << part_id << "/" << num_parts;
  CHECK_GE(part_id, 0) << "Invalid partition id " << part_id << "/"
                       << num_parts;

  const auto& ctx = local_idx->ctx;
  hipStream_t stream = runtime::getCurrentHIPStreamMasqueradingAsCUDA();

  if (num_parts > 1) {
    IdArray global_idx =
        aten::NewIdArray(local_idx->shape[0], ctx, sizeof(IdType) * 8);

    const dim3 block(128);
    const dim3 grid((local_idx->shape[0] + block.x - 1) / block.x);

    CUDA_KERNEL_CALL(
        _MapGlobalIndexByRemainderKernel, grid, block, 0, stream,
        static_cast<const IdType*>(local_idx->data), part_id,
        global_idx->shape[0], num_parts,
        static_cast<IdType*>(global_idx->data));

    return global_idx;
  } else {
    // no mapping to be done
    return local_idx;
  }
}

template IdArray MapToGlobalFromRemainder<kDGLCUDA, int32_t>(
    int num_parts, IdArray in_idx, int part_id);
template IdArray MapToGlobalFromRemainder<kDGLCUDA, int64_t>(
    int num_parts, IdArray in_idx, int part_id);

// Range Based Partition Operations

template <DGLDeviceType XPU, typename IdType, typename RangeType>
std::pair<IdArray, NDArray> GeneratePermutationFromRange(
    int64_t array_size, int num_parts, IdArray range, IdArray in_idx) {
  std::pair<IdArray, NDArray> result;

  const auto& ctx = in_idx->ctx;
  auto device = DeviceAPI::Get(ctx);
  hipStream_t stream = runtime::getCurrentHIPStreamMasqueradingAsCUDA();

  const int64_t num_in = in_idx->shape[0];

  CHECK_GE(num_parts, 1) << "The number of partitions (" << num_parts
                         << ") must be at least 1.";
  if (num_parts == 1) {
    // no permutation
    result.first = aten::Range(0, num_in, sizeof(IdType) * 8, ctx);
    result.second = aten::Full(num_in, num_parts, sizeof(int64_t) * 8, ctx);

    return result;
  }

  result.first = aten::NewIdArray(num_in, ctx, sizeof(IdType) * 8);
  result.second = aten::Full(0, num_parts, sizeof(int64_t) * 8, ctx);
  int64_t* out_counts = static_cast<int64_t*>(result.second->data);
  if (num_in == 0) {
    // now that we've zero'd out_counts, nothing left to do for an empty
    // mapping
    return result;
  }

  const int64_t part_bits =
      static_cast<int64_t>(::ceil(std::log2(num_parts)));
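  // Overall flow, e.g. for range = {0, 4, 8} (2 partitions) and
  // in_idx = {5, 2, 6, 1}:
  //   proc_id_in  = {1, 0, 1, 0}                  (binary search in range)
  //   radix sort  -> result.first  = {1, 3, 0, 2} (indices grouped by partition)
  //   histogram   -> result.second = {2, 2}       (elements per partition)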

  // First, generate a mapping of indexes to processors
  Workspace<IdType> proc_id_in(device, ctx, num_in);
  {
    const dim3 block(256);
    const dim3 grid((num_in + block.x - 1) / block.x);

    CUDA_KERNEL_CALL(
        _MapProcByRangeKernel, grid, block, 0, stream,
        static_cast<const RangeType*>(range->data),
        static_cast<const IdType*>(in_idx->data), num_in, num_parts,
        proc_id_in.get());
  }

  // then create a permutation array that groups processors together by
  // performing a radix sort
  Workspace<IdType> proc_id_out(device, ctx, num_in);
  IdType* perm_out = static_cast<IdType*>(result.first->data);
  {
    IdArray perm_in = aten::Range(0, num_in, sizeof(IdType) * 8, ctx);

    size_t sort_workspace_size;
    CUDA_CALL(hipcub::DeviceRadixSort::SortPairs(
        nullptr, sort_workspace_size, proc_id_in.get(), proc_id_out.get(),
        static_cast<IdType*>(perm_in->data), perm_out, num_in, 0, part_bits,
        stream));

    Workspace<void> sort_workspace(device, ctx, sort_workspace_size);
    CUDA_CALL(hipcub::DeviceRadixSort::SortPairs(
        sort_workspace.get(), sort_workspace_size, proc_id_in.get(),
        proc_id_out.get(), static_cast<IdType*>(perm_in->data), perm_out,
        num_in, 0, part_bits, stream));
  }
  // explicitly free so workspace can be re-used
  proc_id_in.free();

  // perform a histogram and then prefixsum on the sorted proc_id vector

  // Count the number of values to be sent to each processor
  {
    using AtomicCount = unsigned long long;  // NOLINT
    static_assert(
        sizeof(AtomicCount) == sizeof(*out_counts),
        "AtomicCount must be the same width as int64_t for atomicAdd "
        "in hipcub::DeviceHistogram::HistogramEven() to work");

    // TODO(dlasalle): Once https://github.com/NVIDIA/cub/pull/287 is merged,
    // add a compile time check against the cub version to allow
    // num_in > (2 << 31).
    CHECK(num_in < static_cast<int64_t>(std::numeric_limits<int>::max()))
        << "number of values to insert into histogram must be less than max "
           "value of int.";

    size_t hist_workspace_size;
    CUDA_CALL(hipcub::DeviceHistogram::HistogramEven(
        nullptr, hist_workspace_size, proc_id_out.get(),
        reinterpret_cast<AtomicCount*>(out_counts), num_parts + 1,
        static_cast<IdType>(0), static_cast<IdType>(num_parts),
        static_cast<int>(num_in), stream));

    Workspace<void> hist_workspace(device, ctx, hist_workspace_size);
    CUDA_CALL(hipcub::DeviceHistogram::HistogramEven(
        hist_workspace.get(), hist_workspace_size, proc_id_out.get(),
        reinterpret_cast<AtomicCount*>(out_counts), num_parts + 1,
        static_cast<IdType>(0), static_cast<IdType>(num_parts),
        static_cast<int>(num_in), stream));
  }

  return result;
}

template std::pair<IdArray, IdArray>
GeneratePermutationFromRange<kDGLCUDA, int32_t, int32_t>(
    int64_t array_size, int num_parts, IdArray range, IdArray in_idx);
template std::pair<IdArray, IdArray>
GeneratePermutationFromRange<kDGLCUDA, int64_t, int32_t>(
    int64_t array_size, int num_parts, IdArray range, IdArray in_idx);
template std::pair<IdArray, IdArray>
GeneratePermutationFromRange<kDGLCUDA, int32_t, int64_t>(
    int64_t array_size, int num_parts, IdArray range, IdArray in_idx);
template std::pair<IdArray, IdArray>
GeneratePermutationFromRange<kDGLCUDA, int64_t, int64_t>(
    int64_t array_size, int num_parts, IdArray range, IdArray in_idx);

template <DGLDeviceType XPU, typename IdType, typename RangeType>
IdArray MapToLocalFromRange(
    const int num_parts, IdArray range, IdArray global_idx) {
  const auto& ctx = global_idx->ctx;
  hipStream_t stream = runtime::getCurrentHIPStreamMasqueradingAsCUDA();

  if (num_parts > 1 && global_idx->shape[0] > 0) {
    IdArray local_idx =
        aten::NewIdArray(global_idx->shape[0], ctx, sizeof(IdType) * 8);

    const dim3 block(128);
    const dim3 grid((global_idx->shape[0] + block.x - 1) / block.x);

    CUDA_KERNEL_CALL(
        _MapLocalIndexByRangeKernel, grid, block, 0, stream,
        static_cast<const RangeType*>(range->data),
        static_cast<const IdType*>(global_idx->data), global_idx->shape[0],
        num_parts, static_cast<IdType*>(local_idx->data));

    return local_idx;
  } else {
    // no mapping to be done
    return global_idx;
  }
}

template IdArray MapToLocalFromRange<kDGLCUDA, int32_t, int32_t>(
    int num_parts, IdArray range, IdArray in_idx);
template IdArray MapToLocalFromRange<kDGLCUDA, int64_t, int32_t>(
    int num_parts, IdArray range, IdArray in_idx);
template IdArray MapToLocalFromRange<kDGLCUDA, int32_t, int64_t>(
    int num_parts, IdArray range, IdArray in_idx);
template IdArray MapToLocalFromRange<kDGLCUDA, int64_t, int64_t>(
    int num_parts, IdArray range, IdArray in_idx);

template <DGLDeviceType XPU, typename IdType, typename RangeType>
IdArray MapToGlobalFromRange(
    const int num_parts, IdArray range, IdArray local_idx, const int part_id) {
  CHECK_LT(part_id, num_parts)
      << "Invalid partition id " << part_id << "/" << num_parts;
  CHECK_GE(part_id, 0) << "Invalid partition id " << part_id << "/"
                       << num_parts;

  const auto& ctx = local_idx->ctx;
  hipStream_t stream = runtime::getCurrentHIPStreamMasqueradingAsCUDA();

  if (num_parts > 1 && local_idx->shape[0] > 0) {
    IdArray global_idx =
        aten::NewIdArray(local_idx->shape[0], ctx, sizeof(IdType) * 8);

    const dim3 block(128);
    const dim3 grid((local_idx->shape[0] + block.x - 1) / block.x);

    CUDA_KERNEL_CALL(
        _MapGlobalIndexByRangeKernel, grid, block, 0, stream,
        static_cast<const RangeType*>(range->data),
        static_cast<const IdType*>(local_idx->data), part_id,
        global_idx->shape[0], num_parts,
        static_cast<IdType*>(global_idx->data));

    return global_idx;
  } else {
    // no mapping to be done
    return local_idx;
  }
}

template IdArray MapToGlobalFromRange<kDGLCUDA, int32_t, int32_t>(
    int num_parts, IdArray range, IdArray in_idx, int part_id);
template IdArray MapToGlobalFromRange<kDGLCUDA, int64_t, int32_t>(
    int num_parts, IdArray range, IdArray in_idx, int part_id);
template IdArray MapToGlobalFromRange<kDGLCUDA, int32_t, int64_t>(
    int num_parts, IdArray range, IdArray in_idx, int part_id);
template IdArray MapToGlobalFromRange<kDGLCUDA, int64_t, int64_t>(
    int num_parts, IdArray range, IdArray in_idx, int part_id);

}  // namespace impl
}  // namespace partition
}  // namespace dgl