/*! \file network.h */
#ifndef LIGHTGBM_NETWORK_H_
#define LIGHTGBM_NETWORK_H_

#include <LightGBM/config.h>
#include <LightGBM/meta.h>
#include <LightGBM/utils/log.h>

#include <cstring>
#include <functional>
#include <memory>
#include <vector>

namespace LightGBM {

/*! \brief forward declaration */
class Linkers;

/*! \brief The network structure for all_gather */
class BruckMap {
 public:
  /*! \brief The communication times for one all gather operation */
  int k;
  /*! \brief in_ranks[i] means the incoming rank on i-th communication */
  std::vector<int> in_ranks;
  /*! \brief out_ranks[i] means the out rank on i-th communication */
  std::vector<int> out_ranks;
  /*! \brief Default constructor; defined outside of this header */
  BruckMap();
  /*!
  * \brief Construct from an integer; explicit, so an int never converts implicitly to a BruckMap
  * \param n NOTE(review): presumably the number of communication steps (k) - confirm against the definition
  */
  explicit BruckMap(int n);
  /*!
  * \brief Create the object of bruck map
  * \param rank Rank of this machine
  * \param num_machines The total number of machines
  * \return The object of bruck map
  */
  static BruckMap Construct(int rank, int num_machines);
};

/*!
* \brief Node type in the recursive halving algorithm.
*        When the number of machines is not a power of 2, the machines are grouped so that
*        the number of groups is a power of 2, with at most 2 machines per group.
*        If a group has only 1 machine, that machine is a normal node.
*        If a group has 2 machines, the group has two types of nodes, one of which is the leader.
*        The leader represents its group and communicates with the other groups.
*/
enum RecursiveHalvingNodeType {
  /*! \brief Normal node: its group contains only this one machine */
  Normal = 0,
  /*! \brief Leader of a group that contains 2 machines */
  GroupLeader = 1,
  /*! \brief Non-leader machine of a group */
  Other = 2
};

/*! \brief Network structure for recursive halving algorithm */
class RecursiveHalvingMap {
Nikita Titov's avatar
Nikita Titov committed
53
 public:
Guolin Ke's avatar
Guolin Ke committed
54
55
  /*! \brief Communication times for one recursize halving algorithm  */
  int k;
Guolin Ke's avatar
Guolin Ke committed
56
57
58
59
  /*! \brief Node type */
  RecursiveHalvingNodeType type;
  bool is_power_of_2;
  int neighbor;
Guolin Ke's avatar
Guolin Ke committed
60
61
62
63
64
65
66
67
68
69
70
71
72
  /*! \brief ranks[i] means the machines that will communicate with on i-th communication*/
  std::vector<int> ranks;
  /*! \brief  send_block_start[i] means send block start index at i-th communication*/
  std::vector<int> send_block_start;
  /*! \brief  send_block_start[i] means send block size at i-th communication*/
  std::vector<int> send_block_len;
  /*! \brief  send_block_start[i] means recv block start index at i-th communication*/
  std::vector<int> recv_block_start;
  /*! \brief  send_block_start[i] means recv block size  at i-th communication*/
  std::vector<int> recv_block_len;

  RecursiveHalvingMap();

Guolin Ke's avatar
Guolin Ke committed
73
  RecursiveHalvingMap(int k, RecursiveHalvingNodeType _type, bool _is_power_of_2);
Guolin Ke's avatar
Guolin Ke committed
74
75
76
77
78
79
80
81
82
83
84
85

  /*!
  * \brief Create the object of recursive halving map
  * \param rank Rank of this machine
  * \param num_machines The total number of machines
  * \return The object of recursive halving map
  */
  static RecursiveHalvingMap Construct(int rank, int num_machines);
};

/*! \brief A static class that contains some collective communication algorithm */
class Network {
Nikita Titov's avatar
Nikita Titov committed
86
 public:
Guolin Ke's avatar
Guolin Ke committed
87
88
89
90
  /*!
  * \brief Initialize
  * \param config Config of network setting
  */
Guolin Ke's avatar
Guolin Ke committed
91
  static void Init(Config config);
92
93
94
95
  /*!
  * \brief Initialize
  */
  static void Init(int num_machines, int rank, ReduceScatterFunction reduce_scatter_ext_fun, AllgatherFunction allgather_ext_fun);
Guolin Ke's avatar
Guolin Ke committed
96
97
98
99
100
101
102
103
  /*! \brief Free this static class */
  static void Dispose();
  /*! \brief Get rank of this machine */
  static inline int rank();
  /*! \brief Get total number of machines */
  static inline int num_machines();

  /*!
Qiwei Ye's avatar
Qiwei Ye committed
104
  * \brief Perform all_reduce. if data size is small,
Guolin Ke's avatar
Guolin Ke committed
105
106
107
108
109
110
111
           will perform AllreduceByAllGather, else with call ReduceScatter followed allgather
  * \param input Input data
  * \param input_size The size of input data
  * \param type_size The size of one object in the reduce function
  * \param output Output result
  * \param reducer Reduce function
  */
Guolin Ke's avatar
Guolin Ke committed
112
113
  static void Allreduce(char* input, comm_size_t input_size, int type_size,
                        char* output, const ReduceFunction& reducer);
Guolin Ke's avatar
Guolin Ke committed
114
115

  /*!
Qiwei Ye's avatar
Qiwei Ye committed
116
  * \brief Perform all_reduce by using all_gather. it can be use to reduce communication time when data is small
Guolin Ke's avatar
Guolin Ke committed
117
118
  * \param input Input data
  * \param input_size The size of input data
Guolin Ke's avatar
Guolin Ke committed
119
  * \param type_size The size of one object in the reduce function
Guolin Ke's avatar
Guolin Ke committed
120
121
122
  * \param output Output result
  * \param reducer Reduce function
  */
Guolin Ke's avatar
Guolin Ke committed
123
124
  static void AllreduceByAllGather(char* input, comm_size_t input_size, int type_size, char* output,
                                   const ReduceFunction& reducer);
Guolin Ke's avatar
Guolin Ke committed
125
126

  /*!
Qiwei Ye's avatar
Qiwei Ye committed
127
128
129
  * \brief Performing all_gather by using bruck algorithm. 
           Communication times is O(log(n)), and communication cost is O(send_size * number_machine)
  *        It can be used when all nodes have same input size.
Guolin Ke's avatar
Guolin Ke committed
130
131
132
133
  * \param input Input data
  * \param send_size The size of input data
  * \param output Output result
  */
Guolin Ke's avatar
Guolin Ke committed
134
  static void Allgather(char* input, comm_size_t send_size, char* output);
Guolin Ke's avatar
Guolin Ke committed
135
136

  /*!
Qiwei Ye's avatar
Qiwei Ye committed
137
138
139
  * \brief Performing all_gather by using bruck algorithm. 
           Communication times is O(log(n)), and communication cost is O(all_size)
  *        It can be used when nodes have different input size.
Guolin Ke's avatar
Guolin Ke committed
140
141
142
143
  * \param input Input data
  * \param block_start The block start for different machines
  * \param block_len The block size for different machines
  * \param output Output result
Guolin Ke's avatar
Guolin Ke committed
144
  * \param all_size The size of output data
Guolin Ke's avatar
Guolin Ke committed
145
  */
Guolin Ke's avatar
Guolin Ke committed
146
  static void Allgather(char* input, const comm_size_t* block_start, const comm_size_t* block_len, char* output, comm_size_t all_size);
Guolin Ke's avatar
Guolin Ke committed
147
148

  /*!
Qiwei Ye's avatar
Qiwei Ye committed
149
150
  * \brief Perform reduce scatter by using recursive halving algorithm. 
           Communication times is O(log(n)), and communication cost is O(input_size)
Guolin Ke's avatar
Guolin Ke committed
151
152
  * \param input Input data
  * \param input_size The size of input data
Guolin Ke's avatar
Guolin Ke committed
153
  * \param type_size The size of one object in the reduce function
Guolin Ke's avatar
Guolin Ke committed
154
155
156
  * \param block_start The block start for different machines
  * \param block_len The block size for different machines
  * \param output Output result
Guolin Ke's avatar
Guolin Ke committed
157
  * \param output_size size of output data
Guolin Ke's avatar
Guolin Ke committed
158
159
  * \param reducer Reduce function
  */
Guolin Ke's avatar
Guolin Ke committed
160
161
162
  static void ReduceScatter(char* input, comm_size_t input_size, int type_size,
                            const comm_size_t* block_start, const comm_size_t* block_len, char* output, comm_size_t output_size,
                            const ReduceFunction& reducer);
Guolin Ke's avatar
Guolin Ke committed
163

Guolin Ke's avatar
Guolin Ke committed
164
165
166
167
168
169
  template<class T>
  static T GlobalSyncUpByMin(T& local) {
    T global = local;
    Allreduce(reinterpret_cast<char*>(&local),
              sizeof(local), sizeof(local),
              reinterpret_cast<char*>(&global),
Guolin Ke's avatar
Guolin Ke committed
170
171
              [] (const char* src, char* dst, int type_size, comm_size_t len) {
      comm_size_t used_size = 0;
Guolin Ke's avatar
Guolin Ke committed
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
      const T *p1;
      T *p2;
      while (used_size < len) {
        p1 = reinterpret_cast<const T *>(src);
        p2 = reinterpret_cast<T *>(dst);
        if (*p1 < *p2) {
          std::memcpy(dst, src, type_size);
        }
        src += type_size;
        dst += type_size;
        used_size += type_size;
      }
    });
    return global;
  }

  template<class T>
  static T GlobalSyncUpByMax(T& local) {
    T global = local;
    Allreduce(reinterpret_cast<char*>(&local),
              sizeof(local), sizeof(local),
              reinterpret_cast<char*>(&global),
Guolin Ke's avatar
Guolin Ke committed
194
195
              [] (const char* src, char* dst, int type_size, comm_size_t len) {
      comm_size_t used_size = 0;
Guolin Ke's avatar
Guolin Ke committed
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
      const T *p1;
      T *p2;
      while (used_size < len) {
        p1 = reinterpret_cast<const T *>(src);
        p2 = reinterpret_cast<T *>(dst);
        if (*p1 > *p2) {
          std::memcpy(dst, src, type_size);
        }
        src += type_size;
        dst += type_size;
        used_size += type_size;
      }
    });
    return global;
  }

212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
  template<class T>
  static T GlobalSyncUpByMean(T& local) {
    T global = (T)0;
    Allreduce(reinterpret_cast<char*>(&local),
              sizeof(local), sizeof(local),
              reinterpret_cast<char*>(&global),
              [](const char* src, char* dst, int type_size, comm_size_t len) {
      comm_size_t used_size = 0;
      const T *p1;
      T *p2;
      while (used_size < len) {
        p1 = reinterpret_cast<const T *>(src);
        p2 = reinterpret_cast<T *>(dst);
        *p2 += *p1;
        src += type_size;
        dst += type_size;
        used_size += type_size;
      }
    });
    return static_cast<T>(global / num_machines_);
  }

  template<class T>
  static void GlobalSum(std::vector<T>& local) {
236
    std::vector<T> global(local.size(), 0);
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
    Allreduce(reinterpret_cast<char*>(local.data()),
              static_cast<comm_size_t>(sizeof(T) * local.size()), sizeof(T),
              reinterpret_cast<char*>(global.data()),
              [](const char* src, char* dst, int type_size, comm_size_t len) {
      comm_size_t used_size = 0;
      const T *p1;
      T *p2;
      while (used_size < len) {
        p1 = reinterpret_cast<const T *>(src);
        p2 = reinterpret_cast<T *>(dst);
        *p2 += *p1;
        src += type_size;
        dst += type_size;
        used_size += type_size;
      }
    });
    for (size_t i = 0; i < local.size(); ++i) {
      local[i] = global[i];
    }
  }

Nikita Titov's avatar
Nikita Titov committed
258
 private:
Guolin Ke's avatar
Guolin Ke committed
259
260
261
262
263
264
265
266
267
268
269
270
271
272
  static void AllgatherBruck(char* input, const comm_size_t* block_start, const comm_size_t* block_len, char* output, comm_size_t all_size);

  static void AllgatherRecursiveDoubling(char* input, const comm_size_t* block_start, const comm_size_t* block_len, char* output, comm_size_t all_size);

  static void AllgatherRing(char* input, const comm_size_t* block_start, const comm_size_t* block_len, char* output, comm_size_t all_size);

  static void ReduceScatterRecursiveHalving(char* input, comm_size_t input_size, int type_size,
                                            const comm_size_t* block_start, const comm_size_t* block_len, char* output, comm_size_t output_size,
                                            const ReduceFunction& reducer);

  static void ReduceScatterRing(char* input, comm_size_t input_size, int type_size,
                                const comm_size_t* block_start, const comm_size_t* block_len, char* output, comm_size_t output_size,
                                const ReduceFunction& reducer);

Guolin Ke's avatar
Guolin Ke committed
273
  /*! \brief Number of all machines */
274
  static THREAD_LOCAL int num_machines_;
Guolin Ke's avatar
Guolin Ke committed
275
  /*! \brief Rank of local machine */
276
  static THREAD_LOCAL int rank_;
Guolin Ke's avatar
Guolin Ke committed
277
  /*! \brief The network interface, provide send/recv functions  */
278
  static THREAD_LOCAL std::unique_ptr<Linkers> linkers_;
Guolin Ke's avatar
Guolin Ke committed
279
  /*! \brief Bruck map for all gather algorithm*/
280
  static THREAD_LOCAL BruckMap bruck_map_;
Guolin Ke's avatar
Guolin Ke committed
281
  /*! \brief Recursive halving map for reduce scatter */
282
  static THREAD_LOCAL RecursiveHalvingMap recursive_halving_map_;
Guolin Ke's avatar
Guolin Ke committed
283
  /*! \brief Buffer to store block start index */
Guolin Ke's avatar
Guolin Ke committed
284
  static THREAD_LOCAL std::vector<comm_size_t> block_start_;
Guolin Ke's avatar
Guolin Ke committed
285
  /*! \brief Buffer to store block size */
Guolin Ke's avatar
Guolin Ke committed
286
  static THREAD_LOCAL std::vector<comm_size_t> block_len_;
Guolin Ke's avatar
Guolin Ke committed
287
  /*! \brief Buffer  */
288
  static THREAD_LOCAL std::vector<char> buffer_;
Guolin Ke's avatar
Guolin Ke committed
289
  /*! \brief Size of buffer_ */
Guolin Ke's avatar
Guolin Ke committed
290
  static THREAD_LOCAL comm_size_t buffer_size_;
ww's avatar
ww committed
291
  /*! \brief Funcs*/
292
293
  static THREAD_LOCAL ReduceScatterFunction reduce_scatter_ext_fun_;
  static THREAD_LOCAL AllgatherFunction allgather_ext_fun_;
Guolin Ke's avatar
Guolin Ke committed
294
295
296
297
298
299
300
301
302
303
304
305
};

/*! \brief Return the value stored in Network::rank_ */
inline int Network::rank() { return Network::rank_; }

/*! \brief Return the value stored in Network::num_machines_ */
inline int Network::num_machines() { return Network::num_machines_; }

}  // namespace LightGBM

#endif  // LIGHTGBM_NETWORK_H_