"vscode:/vscode.git/clone" did not exist on "0407c3e7d0ed844baf3c0b09d9b231d09445e5d8"
/**
 *  Copyright (c) 2021-2022 by Contributors
 *
 *  Licensed under the Apache License, Version 2.0 (the "License");
 *  you may not use this file except in compliance with the License.
 *  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 *
 * @file nccl_api.h
 * @brief Wrapper around NCCL routines.
 */

#ifndef DGL_RUNTIME_CUDA_NCCL_API_H_
#define DGL_RUNTIME_CUDA_NCCL_API_H_

#ifdef DGL_USE_NCCL
#include "nccl.h"
#else
// Fallback declarations for builds without NCCL. They stand in for the
// ncclUniqueId and ncclComm_t names that nccl.h would otherwise provide, so
// the declarations later in this header still compile.
//
// Without NCCL, the communicator class declared in this header only supports
// communicators of size 1.

// Number of bytes of opaque storage in the fallback ncclUniqueId.
#define NCCL_UNIQUE_ID_BYTES 128

// Fixed-size, opaque identifier blob used in place of NCCL's ncclUniqueId.
typedef struct {
  char internal[NCCL_UNIQUE_ID_BYTES];
} ncclUniqueId;

// Placeholder communicator handle type used in place of NCCL's ncclComm_t.
typedef int ncclComm_t;
#endif

#include <dgl/runtime/object.h>

#include <cstdint>
#include <string>

namespace dgl {
namespace runtime {
namespace cuda {

/**
 * @brief Runtime object that holds a single ncclUniqueId value.
 *
 * Only the interface is declared here; the member functions are not defined
 * in this header.
 */
class NCCLUniqueId : public runtime::Object {
 public:
  // Type key under which this object type is registered.
  static constexpr const char* _type_key = "cuda.NCCLUniqueId";
  DGL_DECLARE_OBJECT_TYPE_INFO(NCCLUniqueId, Object);

  /**
   * @brief Default constructor.
   *
   * NOTE(review): presumably this generates a fresh id when NCCL is
   * available -- confirm against the definition.
   */
  NCCLUniqueId();

  /**
   * @brief Load the id from a string.
   *
   * NOTE(review): the expected encoding is defined by the implementation and
   * is presumably whatever ToString() emits -- confirm.
   *
   * @param str The string to read the id from.
   */
  void FromString(const std::string& str);

  /**
   * @brief Produce a string form of the id.
   *
   * @return The id as a std::string.
   */
  std::string ToString() const;

  /**
   * @brief Get the id by value.
   *
   * @return A ncclUniqueId (presumably a copy of the stored id).
   */
  ncclUniqueId Get() const;

 private:
  // The wrapped id.
  ncclUniqueId id_;
};

// Reference type for NCCLUniqueId objects.
DGL_DEFINE_OBJECT_REF(NCCLUniqueIdRef, NCCLUniqueId);

class NCCLCommunicator : public runtime::Object {
 public:
64
  NCCLCommunicator(int size, int rank, ncclUniqueId id);
65
66
67
68
69

  ~NCCLCommunicator();

  // disable copying
  NCCLCommunicator(const NCCLCommunicator& other) = delete;
70
  NCCLCommunicator& operator=(const NCCLCommunicator& other);
71
72
73
74
75
76
77
78
79
80
81

  ncclComm_t Get();

  /**
   * @brief Perform an all-to-all communication.
   *
   * @param send The continous array of data to send.
   * @param recv The continous array of data to recieve.
   * @param count The size of data to send to each rank.
   * @param stream The stream to operate on.
   */
82
  template <typename IdType>
83
  void AllToAll(
84
      const IdType* send, IdType* recv, int64_t count, cudaStream_t stream);
85
86
87
88
89
90
91
92
93
94
95
96

  /**
   * @brief Perform an all-to-all variable sized communication.
   *
   * @tparam DType The type of value to send.
   * @param send The arrays of data to send.
   * @param send_prefix The prefix of each array to send.
   * @param recv The arrays of data to recieve.
   * @param recv_prefix The prefix of each array to recieve.
   * @param type The type of data to send.
   * @param stream The stream to operate on.
   */
97
  template <typename DType>
98
  void AllToAllV(
99
100
      const DType* const send, const int64_t* send_prefix, DType* const recv,
      const int64_t* recv_prefix, cudaStream_t stream);
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118

  /**
   * @brief Perform an all-to-all with sparse data (idx and value pairs). By
   * necessity, the sizes of each message are variable.
   *
   * @tparam IdType The type of index.
   * @tparam DType The type of value.
   * @param send_idx The set of indexes to send on the device.
   * @param send_value The set of values to send on the device.
   * @param num_feat The number of values per index.
   * @param send_prefix The exclusive prefix sum of elements to send on the
   * host.
   * @param recv_idx The set of indexes to recieve on the device.
   * @param recv_value The set of values to recieve on the device.
   * @param recv_prefix The exclusive prefix sum of the number of elements to
   * recieve on the host.
   * @param stream The stream to communicate on.
   */
119
  template <typename IdType, typename DType>
120
  void SparseAllToAll(
121
122
123
      const IdType* send_idx, const DType* send_value, const int64_t num_feat,
      const int64_t* send_prefix, IdType* recv_idx, DType* recv_value,
      const int64_t* recv_prefix, cudaStream_t stream);
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144

  int size() const;

  int rank() const;

  static constexpr const char* _type_key = "cuda.NCCLCommunicator";
  DGL_DECLARE_OBJECT_TYPE_INFO(NCCLCommunicator, Object);

 private:
  ncclComm_t comm_;
  int size_;
  int rank_;
};

DGL_DEFINE_OBJECT_REF(NCCLCommunicatorRef, NCCLCommunicator);

}  // namespace cuda
}  // namespace runtime
}  // namespace dgl

#endif  // DGL_RUNTIME_CUDA_NCCL_API_H_