/**
 *  Copyright (c) 2021-2022 by Contributors
 *
 *  Licensed under the Apache License, Version 2.0 (the "License");
 *  you may not use this file except in compliance with the License.
 *  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 *
 * @file nccl_api.cu
 * @brief Implementation of wrapper around NCCL routines.
 */

#include <cuda_fp16.h>
#include <cuda_runtime.h>

#include <dgl/array.h>
#include <dgl/aten/array_ops.h>
#include <dgl/packed_func_ext.h>
#include <dgl/runtime/container.h>
#include <dgl/runtime/device_api.h>
#include <dgl/runtime/registry.h>

#include <algorithm>
#include <cmath>
#include <iomanip>
#include <limits>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include "../../array/cuda/array_index_select.cuh"
#include "../../array/cuda/dgl_cub.cuh"
#include "../../partition/ndarray_partition.h"
#include "../../runtime/workspace.h"
#include "cuda_common.h"
#include "nccl_api.h"

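// Check the result of an NCCL call and abort with a fatal log message if it
// did not return ncclSuccess.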
#define NCCL_CALL(func)                                                   \
  {                                                                       \
    ncclResult_t result = func;                                           \
    if (result != ncclSuccess) {                                          \
      LOG(FATAL) << "NCCLError: " #func " failed with error: " << result; \
    }                                                                     \
  }

namespace dgl {

using namespace partition;

namespace runtime {
namespace cuda {

namespace {

#ifdef DGL_USE_NCCL

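// Map C++ element types to the corresponding ncclDataType_t values used by
// the send/recv calls below.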
template <typename T>
ncclDataType_t NCCLType();
template <>
ncclDataType_t NCCLType<int32_t>() {
  return ncclInt32;
}
template <>
ncclDataType_t NCCLType<int64_t>() {
  return ncclInt64;
}
template <>
ncclDataType_t NCCLType<__half>() {
  return ncclHalf;
}
template <>
ncclDataType_t NCCLType<float>() {
  return ncclFloat32;
}
template <>
ncclDataType_t NCCLType<double>() {
  return ncclFloat64;
}

#endif  // DGL_USE_NCCL

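/**
 * @brief Gather both an index array and its associated value rows according
 * to a permutation: out_idx[i] = in_idx[perm[i]], and row i of out_value is
 * row perm[i] of in_value. For multi-element rows (num_feat > 1) each block
 * copies its rows cooperatively, one feature per thread.
 */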
template <typename IdType, typename DType>
__global__ void _DualPermKernel(
    const IdType* const in_idx, const DType* const in_value,
    const IdType* const perm, const int64_t num_in, const int64_t num_feat,
    IdType* const out_idx, DType* const out_value) {
  // set index permutation
  const int64_t tidx =
      blockDim.x * static_cast<int64_t>(blockIdx.x) + threadIdx.x;

  if (tidx < num_in) {
    const IdType perm_idx = perm[tidx];
    assert(perm_idx < num_in);
    out_idx[tidx] = in_idx[perm_idx];
  }

  if (num_feat > 1) {
    for (int d = 0; d < blockDim.x; ++d) {
      const int64_t bidx = blockDim.x * static_cast<int64_t>(blockIdx.x) + d;
      if (bidx < num_in) {
        const IdType perm_idx = perm[bidx];
        for (int64_t f = threadIdx.x; f < num_feat; f += blockDim.x) {
          out_value[bidx * num_feat + f] = in_value[perm_idx * num_feat + f];
        }
      }
    }
  } else {
    if (tidx < num_in) {
      const IdType perm_idx = perm[tidx];
      out_value[tidx] = in_value[perm_idx];
    }
  }
}

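/**
 * @brief Scatter rows back to their original positions: row i of `array` is
 * written to row perm[i] of `out`. Each y-slot of a block handles one row
 * while the x-dimension strides across the num_feat columns.
 */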
template <typename DType, typename IdType>
__global__ void _InversePermKernel(
    const DType* const array, const int64_t num_feat, int64_t length,
    const IdType* const perm, DType* const out) {
  int64_t in_row = blockIdx.x * blockDim.y + threadIdx.y;

  const int64_t stride = blockDim.y * gridDim.x;

  while (in_row < length) {
    int64_t col = threadIdx.x;
    const int64_t out_row = perm[in_row];
    while (col < num_feat) {
      out[out_row * num_feat + col] = array[in_row * num_feat + col];
      col += blockDim.x;
    }
    in_row += stride;
  }
}

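/**
 * @brief Push sparse (index, value) pairs to the processes that own them.
 * The pairs are grouped by destination rank using the partition's
 * permutation, the per-rank counts are exchanged with an all-to-all, and the
 * permuted pairs are then exchanged with SparseAllToAll. Returns the indices
 * and value rows received from all other processes.
 */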
template <typename IdType, typename DType>
std::pair<IdArray, NDArray> SparsePush(
    NCCLCommunicatorRef comm, IdArray in_idx, NDArray in_value,
    NDArrayPartitionRef part) {
  const auto& ctx = in_idx->ctx;
  CHECK_EQ(ctx, in_value->ctx) << "Indices and values must be on the same "
                                  "device";
  auto device = DeviceAPI::Get(ctx);

  cudaStream_t stream = runtime::getCurrentCUDAStream();

  CHECK_LE(in_idx->ndim, 1) << "The tensor of sending indices must be of "
                               "dimension one (or empty).";
  const int64_t num_in = in_idx->ndim > 0 ? in_idx->shape[0] : 0;

  CHECK_EQ(num_in, in_value->ndim > 0 ? in_value->shape[0] : 0)
      << "Leading dimension of indices (" << num_in
      << ") must match "
         "leading dimension of values ("
      << (in_value->ndim > 0 ? in_value->shape[0] : 0) << ").";

  int64_t num_feat = 1;
  for (int d = 1; d < in_value->ndim; ++d) {
    num_feat *= in_value->shape[d];
  }

  const int64_t comm_size = comm->size();

  if (comm_size == 1) {
    // nothing to do, just return original arrays
    return std::pair<IdArray, NDArray>(in_idx, in_value);
  }

  std::pair<IdArray, NDArray> part_perm = part->GeneratePermutation(in_idx);
  const IdType* const perm = static_cast<const IdType*>(part_perm.first->data);
  const int64_t* const send_sum =
      static_cast<const int64_t*>(part_perm.second->data);

  Workspace<IdType> send_idx(device, ctx, num_in);
  Workspace<DType> send_value(device, ctx, num_in * num_feat);

  // permute the indices and values
  if (num_in > 0) {
    const dim3 block(256);
    const dim3 grid((num_in + block.x - 1) / block.x);

    CUDA_KERNEL_CALL(
        _DualPermKernel, grid, block, 0, stream,
        static_cast<const IdType*>(in_idx->data),
        static_cast<const DType*>(in_value->data), perm, num_in, num_feat,
        send_idx.get(), send_value.get());
  }

  // compute the prefix sum of the send values
  Workspace<int64_t> send_prefix(device, ctx, comm_size + 1);
  {
    size_t prefix_workspace_size;
    CUDA_CALL(cub::DeviceScan::ExclusiveSum(
        nullptr, prefix_workspace_size, send_sum, send_prefix.get(),
        comm_size + 1, stream));

    Workspace<void> prefix_workspace(device, ctx, prefix_workspace_size);
    CUDA_CALL(cub::DeviceScan::ExclusiveSum(
        prefix_workspace.get(), prefix_workspace_size, send_sum,
        send_prefix.get(), comm_size + 1, stream));
  }

  std::vector<int64_t> send_prefix_host(comm_size + 1);
  // copy using the same stream (local current stream), no need to sync
  device->CopyDataFromTo(
      send_prefix.get(), 0, send_prefix_host.data(), 0,
      send_prefix_host.size() * sizeof(*send_prefix.get()), ctx,
      DGLContext{kDGLCPU, 0},
      DGLDataType{kDGLInt, sizeof(*send_prefix.get()) * 8, 1});
  send_prefix.free();

  CHECK_EQ(send_prefix_host.back(), num_in)
      << "Internal Error: "
         "send_prefix_host.back() = "
      << send_prefix_host.back() << ", and num_in = " << num_in;

  // communicate the amount to send
  Workspace<int64_t> recv_sum(device, ctx, comm_size + 1);
  comm->AllToAll(send_sum, recv_sum.get(), 1, stream);

  cudaEvent_t d2h;
  CUDA_CALL(cudaEventCreate(&d2h));

  // compute the prefix sum of the recv values
  Workspace<int64_t> recv_prefix(device, ctx, comm_size + 1);
  {
    size_t prefix_workspace_size;
    CUDA_CALL(cub::DeviceScan::ExclusiveSum(
        nullptr, prefix_workspace_size, recv_sum.get(), recv_prefix.get(),
        comm_size + 1, stream));

    Workspace<void> prefix_workspace(device, ctx, prefix_workspace_size);
    CUDA_CALL(cub::DeviceScan::ExclusiveSum(
        prefix_workspace.get(), prefix_workspace_size, recv_sum.get(),
        recv_prefix.get(), comm_size + 1, stream));
  }
  recv_sum.free();

  // finally copy the prefix sum down to the host
  std::vector<int64_t> recv_prefix_host(comm_size + 1);
  // copy using the same stream (local current stream), no need to sync
  device->CopyDataFromTo(
      recv_prefix.get(), 0, recv_prefix_host.data(), 0,
      recv_prefix_host.size() * sizeof(*recv_prefix.get()), ctx,
      DGLContext{kDGLCPU, 0},
      DGLDataType{kDGLInt, sizeof(*recv_prefix.get()) * 8, 1});
  recv_prefix.free();

  // use an event to track when copying is done
  CUDA_CALL(cudaEventRecord(d2h, stream));

  // allocate output space
  CUDA_CALL(cudaEventSynchronize(d2h));
  CUDA_CALL(cudaEventDestroy(d2h));

  IdArray recv_idx =
      aten::NewIdArray(recv_prefix_host.back(), ctx, sizeof(IdType) * 8);

  std::vector<int64_t> value_shape(in_value->ndim, 0);
  value_shape[0] = recv_prefix_host.back();
  for (int d = 1; d < in_value->ndim; ++d) {
    value_shape[d] = in_value->shape[d];
  }
  NDArray recv_value = NDArray::Empty(value_shape, in_value->dtype, ctx);

  // send data
  comm->SparseAllToAll(
      send_idx.get(), send_value.get(), num_feat, send_prefix_host.data(),
      static_cast<IdType*>(recv_idx->data),
      static_cast<DType*>(recv_value->data), recv_prefix_host.data(), stream);

  return std::pair<IdArray, NDArray>(recv_idx, recv_value);
}

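/**
 * @brief Pull rows of a partitioned tensor by global index. The requested
 * indices are routed to their owning ranks, each rank gathers the requested
 * rows from its `local_tensor`, the rows are sent back with an all-to-all,
 * and the received rows are finally permuted back into the original request
 * order.
 */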
template <typename IdType, typename DType>
NDArray SparsePull(
    NCCLCommunicatorRef comm, IdArray req_idx, NDArray local_tensor,
    NDArrayPartitionRef part) {
  const auto& ctx = req_idx->ctx;
  CHECK_EQ(ctx, local_tensor->ctx) << "The request indices and set of local "
                                      "values must be on the same device";
  auto device = DeviceAPI::Get(ctx);

  cudaStream_t stream = runtime::getCurrentCUDAStream();

  CHECK_LE(req_idx->ndim, 1) << "The tensor of requested indices must be of "
                                "dimension one (or empty).";
  const int64_t num_in = req_idx->ndim > 0 ? req_idx->shape[0] : 0;
  int64_t num_feat = 1;
  for (int d = 1; d < local_tensor->ndim; ++d) {
    num_feat *= local_tensor->shape[d];
  }

  const int64_t comm_size = comm->size();

  if (comm_size == 1) {
    // Just return index selection from current local_tensor
    return aten::IndexSelect(local_tensor, req_idx);
  }

  // First we need to send our requests to other processors. This means
  // re-ordering our index array to be contiguous among processors, and
  // counting the number of indices we are sending each processor. For now,
  // we assume a poorly partitioned graph, and that there exists the
  // possibility that each processor could request data from this one.

  // the buffer for us to re-order our requests in
  Workspace<IdType> send_idx(device, ctx, num_in);

  std::pair<IdArray, NDArray> part_perm = part->GeneratePermutation(req_idx);
  const IdType* const perm = static_cast<const IdType*>(part_perm.first->data);
  const int64_t* const send_sum =
      static_cast<const int64_t*>(part_perm.second->data);

  // permute requests
  if (num_in > 0) {
    const dim3 block(256);
    const dim3 grid((num_in + block.x - 1) / block.x);

    CUDA_KERNEL_CALL(
        aten::impl::IndexSelectSingleKernel, grid, block, 0, stream,
        static_cast<const IdType*>(req_idx->data), perm, num_in,
        req_idx->shape[0], send_idx.get());
  }

  // compute the prefix sum of the indexes this process is requesting
  Workspace<int64_t> request_prefix(device, ctx, comm_size + 1);
  {
    size_t prefix_workspace_size;
    CUDA_CALL(cub::DeviceScan::ExclusiveSum(
        nullptr, prefix_workspace_size, send_sum, request_prefix.get(),
        comm_size + 1, stream));

    Workspace<void> prefix_workspace(device, ctx, prefix_workspace_size);
    CUDA_CALL(cub::DeviceScan::ExclusiveSum(
        prefix_workspace.get(), prefix_workspace_size, send_sum,
        request_prefix.get(), comm_size + 1, stream));
  }

  cudaEvent_t d2h;
  CUDA_CALL(cudaEventCreate(&d2h));

  std::vector<int64_t> request_prefix_host(comm_size + 1);
  // copy using the same stream (local current stream), no need to sync
  device->CopyDataFromTo(
      request_prefix.get(), 0, request_prefix_host.data(), 0,
      request_prefix_host.size() * sizeof(*request_prefix.get()), ctx,
      DGLContext{kDGLCPU, 0},
      DGLDataType{kDGLInt, sizeof(*request_prefix.get()) * 8, 1});
  request_prefix.free();

  CHECK_EQ(request_prefix_host.back(), num_in)
      << "Internal Error: "
         "request_prefix_host.back() = "
      << request_prefix_host.back() << ", num_in = " << num_in;

  // communicate the amount requested
  Workspace<int64_t> recv_sum(device, ctx, comm_size + 1);
  comm->AllToAll(send_sum, recv_sum.get(), 1, stream);

  // compute the prefix sum of the requested indexes
  Workspace<int64_t> response_prefix(device, ctx, comm_size + 1);
  {
    size_t prefix_workspace_size;
    CUDA_CALL(cub::DeviceScan::ExclusiveSum(
        nullptr, prefix_workspace_size, recv_sum.get(), response_prefix.get(),
        comm_size + 1, stream));

    Workspace<void> prefix_workspace(device, ctx, prefix_workspace_size);
    CUDA_CALL(cub::DeviceScan::ExclusiveSum(
        prefix_workspace.get(), prefix_workspace_size, recv_sum.get(),
        response_prefix.get(), comm_size + 1, stream));
  }
  recv_sum.free();

  // finally copy the prefix sum down to the host
  std::vector<int64_t> response_prefix_host(comm_size + 1);
  // copy using the same stream (local current stream), no need to sync
  device->CopyDataFromTo(
      response_prefix.get(), 0, response_prefix_host.data(), 0,
      response_prefix_host.size() * sizeof(*response_prefix.get()), ctx,
      DGLContext{kDGLCPU, 0},
      DGLDataType{kDGLInt, sizeof(*response_prefix.get()) * 8, 1});
  response_prefix.free();

  // use an event to track when copying is done
  CUDA_CALL(cudaEventRecord(d2h, stream));

  // allocate output space
  CUDA_CALL(cudaEventSynchronize(d2h));
  CUDA_CALL(cudaEventDestroy(d2h));

  // gather requested indexes
  IdArray recv_idx =
      aten::NewIdArray(response_prefix_host.back(), ctx, sizeof(IdType) * 8);
  comm->AllToAllV(
      send_idx.get(), request_prefix_host.data(),
      static_cast<IdType*>(recv_idx->data), response_prefix_host.data(),
      stream);
  send_idx.free();

  // convert requested indices to local indices depending on partition
  if (response_prefix_host.back() > 0) {
    recv_idx = part->MapToLocal(recv_idx);
  }

  // and then index select them into place
  Workspace<DType> filled_response_value(
      device, ctx, response_prefix_host.back() * num_feat);
  if (response_prefix_host.back() > 0) {
    dim3 block(256, 1);
    while (block.x >= 2 * num_feat) {
      block.x /= 2;
      block.y *= 2;
    }
    const dim3 grid((response_prefix_host.back() + block.y - 1) / block.y);

    CUDA_KERNEL_CALL(
        aten::impl::IndexSelectMultiKernel, grid, block, 0, stream,
        static_cast<const DType*>(local_tensor->data), num_feat,
        static_cast<IdType*>(recv_idx->data), response_prefix_host.back(),
        local_tensor->shape[0], filled_response_value.get());
  }

  // we will collect received values in this array
  std::vector<int64_t> value_shape(local_tensor->ndim, 0);
  value_shape[0] = request_prefix_host.back();
  for (int d = 1; d < local_tensor->ndim; ++d) {
    value_shape[d] = local_tensor->shape[d];
  }
  Workspace<DType> filled_request_value(
      device, ctx, request_prefix_host.back() * num_feat);

  // multiply the prefixes by the number of features being sent
  for (auto& v : request_prefix_host) {
    v *= num_feat;
  }
  for (auto& v : response_prefix_host) {
    v *= num_feat;
  }

  // send the values
  comm->AllToAllV(
      filled_response_value.get(), response_prefix_host.data(),
      filled_request_value.get(), request_prefix_host.data(), stream);
  filled_response_value.free();

  // finally, we need to permute the values back into the requested order
  NDArray result = NDArray::Empty(value_shape, local_tensor->dtype, ctx);
  if (num_in > 0) {
    dim3 block(256, 1);
    while (block.x >= 2 * num_feat) {
      block.x /= 2;
      block.y *= 2;
    }
    const dim3 grid((num_in + block.y - 1) / block.y);

    CUDA_KERNEL_CALL(
        _InversePermKernel, grid, block, 0, stream, filled_request_value.get(),
        num_feat, num_in, perm, static_cast<DType*>(result->data));
  }

  return result;
}

}  // namespace

/* NCCLUniqueId **************************************************************/

NCCLUniqueId::NCCLUniqueId() : id_() {
#ifdef DGL_USE_NCCL
  // this ID is unique to the process, not to each call of this function
  NCCL_CALL(ncclGetUniqueId(&id_));
#else
  // when NCCL isn't enabled, use all zeros
  std::fill(
      id_.internal, id_.internal + NCCL_UNIQUE_ID_BYTES, static_cast<char>(0));
#endif
}

ncclUniqueId NCCLUniqueId::Get() const { return id_; }

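// Serialize the opaque NCCL unique id as a hex string, two characters per
// byte, so it can be passed between processes as text.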
std::string NCCLUniqueId::ToString() const {
  std::ostringstream oss;

  oss << std::hex;

  for (size_t b = 0; b < NCCL_UNIQUE_ID_BYTES; ++b) {
    const int num = static_cast<uint8_t>(id_.internal[b]);
    oss << std::setw(2) << std::setfill('0') << num;
  }

  std::string result = oss.str();
  CHECK_EQ(result.length(), NCCL_UNIQUE_ID_BYTES * 2)
      << "Invalid NCCL ID format: '" << result << "'";

  return result;
}

void NCCLUniqueId::FromString(const std::string& str) {
  // must be exactly 256 hex characters
  CHECK_EQ(str.length(), NCCL_UNIQUE_ID_BYTES * 2)
      << "Invalid NCCL ID format: '" << str << "'";

  for (size_t b = 0; b < NCCL_UNIQUE_ID_BYTES; ++b) {
    id_.internal[b] = std::strtol(str.substr(b * 2, 2).c_str(), nullptr, 16);
  }
}

/* NCCLCommunicator **********************************************************/

NCCLCommunicator::NCCLCommunicator(
    const int size, const int rank, ncclUniqueId id)
    : comm_(), size_(size), rank_(rank) {
  CHECK_LT(rank, size) << "The rank (" << rank
                       << ") must be smaller than "
                          "the size of the communicator ("
                       << size << ").";
  CHECK_GE(rank, 0) << "The rank (" << rank
                    << ") must be greater than or "
                       "equal to 0.";

#ifdef DGL_USE_NCCL
  NCCL_CALL(ncclCommInitRank(&comm_, size_, id, rank_));
#else
  CHECK_EQ(size, 1)
      << "Cannot create a communicator of size " << size
      << ". "
         "To use a communicator size greater than 1, compile DGL with NCCL "
         "support.";
#endif
}

NCCLCommunicator::~NCCLCommunicator() {
#ifdef DGL_USE_NCCL
  ncclCommDestroy(comm_);
#endif
}

ncclComm_t NCCLCommunicator::Get() { return comm_; }

template <typename DType>
void NCCLCommunicator::AllToAllV(
    const DType* const send, const int64_t* const send_prefix,
    DType* const recv, const int64_t* const recv_prefix, cudaStream_t stream) {
#ifdef DGL_USE_NCCL
  const ncclDataType_t type = NCCLType<DType>();

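  // Exchange variable-sized chunks: within a single NCCL group, post a send
  // and a receive for every peer, using the exclusive prefix sums as offsets
  // into the send and receive buffers.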
  NCCL_CALL(ncclGroupStart());
  for (int r = 0; r < size_; ++r) {
    const int64_t send_size = send_prefix[r + 1] - send_prefix[r];
    if (send_size > 0) {
      NCCL_CALL(
          ncclSend(send + send_prefix[r], send_size, type, r, comm_, stream));
    }
    const int64_t recv_size = recv_prefix[r + 1] - recv_prefix[r];
    if (recv_size > 0) {
      NCCL_CALL(
          ncclRecv(recv + recv_prefix[r], recv_size, type, r, comm_, stream));
    }
  }
  NCCL_CALL(ncclGroupEnd());
#else
  CHECK_EQ(send_prefix[1] - send_prefix[0], recv_prefix[1] - recv_prefix[0])
      << "Send message size must equal receive message size.";

  int dev_id;
  CUDA_CALL(cudaGetDevice(&dev_id));
  DGLContext ctx{kDGLCUDA, dev_id};

  auto device = runtime::DeviceAPI::Get(ctx);
  auto dtype = DGLDataTypeTraits<DType>::dtype;

  // copy using the same stream (local current stream), no need to sync
  device->CopyDataFromTo(
      send, send_prefix[0], recv, recv_prefix[0],
      sizeof(DType) * (send_prefix[1] - send_prefix[0]), ctx, ctx, dtype);
#endif
}

template void NCCLCommunicator::AllToAllV<int32_t>(
    const int32_t* const send, const int64_t* send_prefix, int32_t* const recv,
    const int64_t* recv_prefix, cudaStream_t stream);
template void NCCLCommunicator::AllToAllV<int64_t>(
    const int64_t* const send, const int64_t* send_prefix, int64_t* const recv,
    const int64_t* recv_prefix, cudaStream_t stream);
template void NCCLCommunicator::AllToAllV<float>(
    const float* const send, const int64_t* send_prefix, float* const recv,
    const int64_t* recv_prefix, cudaStream_t stream);
template void NCCLCommunicator::AllToAllV<__half>(
    const __half* const send, const int64_t* send_prefix, __half* const recv,
    const int64_t* recv_prefix, cudaStream_t stream);

template <typename IdType>
void NCCLCommunicator::AllToAll(
    const IdType* const send, IdType* const recv, const int64_t count,
    cudaStream_t stream) {
#ifdef DGL_USE_NCCL
  const ncclDataType_t type = NCCLType<IdType>();

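  // Fixed-size exchange: send `count` elements to and receive `count`
  // elements from every rank, offset by the rank index.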
  NCCL_CALL(ncclGroupStart());
  for (int r = 0; r < size_; ++r) {
    NCCL_CALL(ncclSend(send + (r * count), count, type, r, comm_, stream));
    NCCL_CALL(ncclRecv(recv + (r * count), count, type, r, comm_, stream));
  }
  NCCL_CALL(ncclGroupEnd());
#else
  int dev_id;
  CUDA_CALL(cudaGetDevice(&dev_id));
  DGLContext ctx{kDGLCUDA, dev_id};

  auto device = runtime::DeviceAPI::Get(ctx);
  auto dtype = DGLDataTypeTraits<IdType>::dtype;

  // copy using the same stream (local current stream), no need to sync
  device->CopyDataFromTo(
      send, 0, recv, 0, count * sizeof(IdType), ctx, ctx, dtype);
#endif
}

template void NCCLCommunicator::AllToAll<int32_t>(
    const int32_t* const send, int32_t* const recv, const int64_t count,
    cudaStream_t stream);
template void NCCLCommunicator::AllToAll<int64_t>(
    const int64_t* const send, int64_t* const recv, const int64_t count,
    cudaStream_t stream);

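// Exchange (index, value-row) pairs: the indices are exchanged first, then
// the value rows, with the same prefixes scaled by the number of features
// per row.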
template <typename IdType, typename DType>
void NCCLCommunicator::SparseAllToAll(
    const IdType* const send_idx, const DType* const send_value,
    const int64_t num_feat, const int64_t* const send_prefix,
    IdType* const recv_idx, DType* const recv_value,
    const int64_t* const recv_prefix, cudaStream_t stream) {
  // idxs
  AllToAllV(send_idx, send_prefix, recv_idx, recv_prefix, stream);

  // scale prefixes by number of features
  std::vector<int64_t> value_send_prefix(size_ + 1);
  for (int r = 0; r < size_ + 1; ++r) {
    value_send_prefix[r] = send_prefix[r] * num_feat;
  }
  std::vector<int64_t> value_recv_prefix(size_ + 1);
  for (int r = 0; r < size_ + 1; ++r) {
    value_recv_prefix[r] = recv_prefix[r] * num_feat;
  }
  AllToAllV(
      send_value, value_send_prefix.data(), recv_value,
      value_recv_prefix.data(), stream);
}

template void NCCLCommunicator::SparseAllToAll<int32_t, __half>(
    const int32_t* const send_idx, const __half* const send_value,
    const int64_t num_feat, const int64_t* const send_prefix,
    int32_t* const recv_idx, __half* const recv_value,
    const int64_t* const recv_prefix, cudaStream_t stream);
template void NCCLCommunicator::SparseAllToAll<int64_t, __half>(
    const int64_t* const send_idx, const __half* const send_value,
    const int64_t num_feat, const int64_t* const send_prefix,
    int64_t* const recv_idx, __half* const recv_value,
    const int64_t* const recv_prefix, cudaStream_t stream);
int NCCLCommunicator::size() const { return size_; }

int NCCLCommunicator::rank() const { return rank_; }

/* CAPI **********************************************************************/

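// The packed functions below are registered with DGL's FFI and are invoked
// from the Python side under the names given to DGL_REGISTER_GLOBAL
// (e.g. "cuda.nccl._CAPI_DGLNCCLGetUniqueId").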
DGL_REGISTER_GLOBAL("cuda.nccl._CAPI_DGLNCCLGetUniqueId")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      *rv = NCCLUniqueIdRef(std::make_shared<NCCLUniqueId>());
    });

DGL_REGISTER_GLOBAL("cuda.nccl._CAPI_DGLNCCLUniqueIdToString")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      NCCLUniqueIdRef idObj = args[0];
      *rv = idObj->ToString();
    });

DGL_REGISTER_GLOBAL("cuda.nccl._CAPI_DGLNCCLUniqueIdFromString")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const std::string str = args[0];

      NCCLUniqueIdRef ref(std::make_shared<NCCLUniqueId>());
      ref->FromString(str);
      *rv = ref;
    });

DGL_REGISTER_GLOBAL("cuda.nccl._CAPI_DGLNCCLCreateComm")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const int size = args[0];
      const int rank = args[1];
      NCCLUniqueIdRef idObj = args[2];

      *rv = NCCLCommunicatorRef(
          std::make_shared<NCCLCommunicator>(size, rank, idObj->Get()));
    });

DGL_REGISTER_GLOBAL("cuda.nccl._CAPI_DGLNCCLSparseAllToAllPush")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      NCCLCommunicatorRef comm = args[0];
      IdArray in_idx = args[1];
      NDArray in_values = args[2];
      NDArrayPartitionRef part = args[3];

      List<ObjectRef> ret;
      ATEN_ID_TYPE_SWITCH(in_idx->dtype, IdType, {
        ATEN_DTYPE_SWITCH(in_values->dtype, DType, "values", {
          auto result =
              SparsePush<IdType, DType>(comm, in_idx, in_values, part);
          ret.push_back(Value(MakeValue(result.first)));
          ret.push_back(Value(MakeValue(result.second)));
        });
      });

      *rv = ret;
    });

DGL_REGISTER_GLOBAL("cuda.nccl._CAPI_DGLNCCLSparseAllToAllPull")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      NCCLCommunicatorRef comm = args[0];
      // the indexes this process is requesting from others
      IdArray req_idx = args[1];

      // the tensor this process has to fulfill other requests
      NDArray tensor = args[2];
      NDArrayPartitionRef part = args[3];

      ATEN_ID_TYPE_SWITCH(req_idx->dtype, IdType, {
        ATEN_DTYPE_SWITCH(tensor->dtype, DType, "values", {
          *rv = SparsePull<IdType, DType>(comm, req_idx, tensor, part);
        });
      });
    });

DGL_REGISTER_GLOBAL("cuda.nccl._CAPI_DGLNCCLHasSupport")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
#ifndef DGL_USE_NCCL
      *rv = false;
#else
      *rv = true;
#endif
    });

}  // namespace cuda
}  // namespace runtime
}  // namespace dgl