network.h 11 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
3
4
5
6
7
8
9
10
#ifndef LIGHTGBM_NETWORK_H_
#define LIGHTGBM_NETWORK_H_

#include <LightGBM/utils/log.h>

#include <LightGBM/meta.h>
#include <LightGBM/config.h>

#include <functional>
#include <vector>
Guolin Ke's avatar
Guolin Ke committed
11
#include <memory>
Guolin Ke's avatar
Guolin Ke committed
12
13
14
15
16
17

namespace LightGBM {

/*! \brief forward declaration */
class Linkers;

Qiwei Ye's avatar
Qiwei Ye committed
18
/*! \brief The network structure for all_gather */
Guolin Ke's avatar
Guolin Ke committed
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
class BruckMap {
public:
  /*! \brief The communication times for one all gather operation */
  int k;
  /*! \brief in_ranks[i] means the incoming rank on i-th communication */
  std::vector<int> in_ranks;
  /*! \brief out_ranks[i] means the outgoing rank on i-th communication */
  std::vector<int> out_ranks;
  /*! \brief Default constructor; defined outside of this header */
  BruckMap();
  /*!
  * \brief Constructor taking one integer; defined outside of this header
  * \param n NOTE(review): meaning is not visible here, presumably the number of
  *          communication steps used to size in_ranks/out_ranks - confirm against the definition
  */
  explicit BruckMap(int n);
  /*!
  * \brief Create the object of bruck map
  * \param rank Rank of this machine
  * \param num_machines The total number of machines
  * \return The object of bruck map
  */
  static BruckMap Construct(int rank, int num_machines);
};

Guolin Ke's avatar
Guolin Ke committed
38
39
40
41
42
43
44
45
46
47
48
49
50
51
/*!
* \brief Node type in the recursive halving algorithm.
*        When the number of machines is not a power of 2, machines need to be grouped
*        so that the number of groups is a power of 2.
*        Each group can have at most 2 machines.
*        If a group has only 1 machine, that machine is a normal node.
*        If a group has 2 machines, the group has two types of nodes, one of which is the leader.
*        The leader represents its group and communicates with the other groups.
*/
enum RecursiveHalvingNodeType {
  Normal,  // normal node, its group has only 1 machine
  GroupLeader,  // leader of a group when the number of machines in this group is 2
  Other  // non-leader machine in a group
};

Guolin Ke's avatar
Guolin Ke committed
52
53
54
55
56
/*! \brief Network structure for recursive halving algorithm */
class RecursiveHalvingMap {
public:
  /*! \brief Communication times for one recursize halving algorithm  */
  int k;
Guolin Ke's avatar
Guolin Ke committed
57
58
59
60
  /*! \brief Node type */
  RecursiveHalvingNodeType type;
  bool is_power_of_2;
  int neighbor;
Guolin Ke's avatar
Guolin Ke committed
61
62
63
64
65
66
67
68
69
70
71
72
73
  /*! \brief ranks[i] means the machines that will communicate with on i-th communication*/
  std::vector<int> ranks;
  /*! \brief  send_block_start[i] means send block start index at i-th communication*/
  std::vector<int> send_block_start;
  /*! \brief  send_block_start[i] means send block size at i-th communication*/
  std::vector<int> send_block_len;
  /*! \brief  send_block_start[i] means recv block start index at i-th communication*/
  std::vector<int> recv_block_start;
  /*! \brief  send_block_start[i] means recv block size  at i-th communication*/
  std::vector<int> recv_block_len;

  RecursiveHalvingMap();

Guolin Ke's avatar
Guolin Ke committed
74
  RecursiveHalvingMap(int k, RecursiveHalvingNodeType _type, bool _is_power_of_2);
Guolin Ke's avatar
Guolin Ke committed
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92

  /*!
  * \brief Create the object of recursive halving map
  * \param rank Rank of this machine
  * \param num_machines The total number of machines
  * \return The object of recursive halving map
  */
  static RecursiveHalvingMap Construct(int rank, int num_machines);
};

/*! \brief A static class that contains some collective communication algorithm */
class Network {
public:
  /*!
  * \brief Initialize
  * \param config Config of network setting
  */
  static void Init(NetworkConfig config);
93
94
95
96
  /*!
  * \brief Initialize
  */
  static void Init(int num_machines, int rank, ReduceScatterFunction reduce_scatter_ext_fun, AllgatherFunction allgather_ext_fun);
Guolin Ke's avatar
Guolin Ke committed
97
98
99
100
101
102
103
104
  /*! \brief Free this static class */
  static void Dispose();
  /*! \brief Get rank of this machine */
  static inline int rank();
  /*! \brief Get total number of machines */
  static inline int num_machines();

  /*!
Qiwei Ye's avatar
Qiwei Ye committed
105
  * \brief Perform all_reduce. if data size is small,
Guolin Ke's avatar
Guolin Ke committed
106
107
108
109
110
111
112
  *        will perform AllreduceByAllGather, otherwise will call ReduceScatter followed by Allgather
  * \param input Input data
  * \param input_size The size of input data
  * \param type_size The size of one object in the reduce function
  * \param output Output result
  * \param reducer Reduce function
  */
Guolin Ke's avatar
Guolin Ke committed
113
114
  static void Allreduce(char* input, comm_size_t input_size, int type_size,
                        char* output, const ReduceFunction& reducer);
Guolin Ke's avatar
Guolin Ke committed
115
116

  /*!
Qiwei Ye's avatar
Qiwei Ye committed
117
  * \brief Perform all_reduce by using all_gather. It can be used to reduce communication time when data is small
Guolin Ke's avatar
Guolin Ke committed
118
119
  * \param input Input data
  * \param input_size The size of input data
Guolin Ke's avatar
Guolin Ke committed
120
  * \param type_size The size of one object in the reduce function
Guolin Ke's avatar
Guolin Ke committed
121
122
123
  * \param output Output result
  * \param reducer Reduce function
  */
Guolin Ke's avatar
Guolin Ke committed
124
125
  static void AllreduceByAllGather(char* input, comm_size_t input_size, int type_size, char* output,
                                   const ReduceFunction& reducer);
Guolin Ke's avatar
Guolin Ke committed
126
127

  /*!
Qiwei Ye's avatar
Qiwei Ye committed
128
129
130
  * \brief Perform all_gather by using bruck algorithm.
  *        Communication times is O(log(n)), and communication cost is O(send_size * num_machines).
  *        It can be used when all nodes have the same input size.
Guolin Ke's avatar
Guolin Ke committed
131
132
133
134
  * \param input Input data
  * \param send_size The size of input data
  * \param output Output result
  */
Guolin Ke's avatar
Guolin Ke committed
135
  static void Allgather(char* input, comm_size_t send_size, char* output);
Guolin Ke's avatar
Guolin Ke committed
136
137

  /*!
Qiwei Ye's avatar
Qiwei Ye committed
138
139
140
  * \brief Perform all_gather by using bruck algorithm.
  *        Communication times is O(log(n)), and communication cost is O(all_size).
  *        It can be used when nodes have different input sizes.
Guolin Ke's avatar
Guolin Ke committed
141
142
143
144
  * \param input Input data
  * \param block_start The block start for different machines
  * \param block_len The block size for different machines
  * \param output Output result
Guolin Ke's avatar
Guolin Ke committed
145
  * \param all_size The size of output data
Guolin Ke's avatar
Guolin Ke committed
146
  */
Guolin Ke's avatar
Guolin Ke committed
147
  static void Allgather(char* input, const comm_size_t* block_start, const comm_size_t* block_len, char* output, comm_size_t all_size);
Guolin Ke's avatar
Guolin Ke committed
148
149

  /*!
Qiwei Ye's avatar
Qiwei Ye committed
150
151
  * \brief Perform reduce scatter by using recursive halving algorithm.
  *        Communication times is O(log(n)), and communication cost is O(input_size).
Guolin Ke's avatar
Guolin Ke committed
152
153
  * \param input Input data
  * \param input_size The size of input data
Guolin Ke's avatar
Guolin Ke committed
154
  * \param type_size The size of one object in the reduce function
Guolin Ke's avatar
Guolin Ke committed
155
156
157
  * \param block_start The block start for different machines
  * \param block_len The block size for different machines
  * \param output Output result
Guolin Ke's avatar
Guolin Ke committed
158
  * \param output_size size of output data
Guolin Ke's avatar
Guolin Ke committed
159
160
  * \param reducer Reduce function
  */
Guolin Ke's avatar
Guolin Ke committed
161
162
163
  static void ReduceScatter(char* input, comm_size_t input_size, int type_size,
                            const comm_size_t* block_start, const comm_size_t* block_len, char* output, comm_size_t output_size,
                            const ReduceFunction& reducer);
Guolin Ke's avatar
Guolin Ke committed
164

Guolin Ke's avatar
Guolin Ke committed
165
166
167
168
169
170
  template<class T>
  static T GlobalSyncUpByMin(T& local) {
    T global = local;
    Allreduce(reinterpret_cast<char*>(&local),
              sizeof(local), sizeof(local),
              reinterpret_cast<char*>(&global),
Guolin Ke's avatar
Guolin Ke committed
171
172
              [] (const char* src, char* dst, int type_size, comm_size_t len) {
      comm_size_t used_size = 0;
Guolin Ke's avatar
Guolin Ke committed
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
      const T *p1;
      T *p2;
      while (used_size < len) {
        p1 = reinterpret_cast<const T *>(src);
        p2 = reinterpret_cast<T *>(dst);
        if (*p1 < *p2) {
          std::memcpy(dst, src, type_size);
        }
        src += type_size;
        dst += type_size;
        used_size += type_size;
      }
    });
    return global;
  }

  template<class T>
  static T GlobalSyncUpByMax(T& local) {
    T global = local;
    Allreduce(reinterpret_cast<char*>(&local),
              sizeof(local), sizeof(local),
              reinterpret_cast<char*>(&global),
Guolin Ke's avatar
Guolin Ke committed
195
196
              [] (const char* src, char* dst, int type_size, comm_size_t len) {
      comm_size_t used_size = 0;
Guolin Ke's avatar
Guolin Ke committed
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
      const T *p1;
      T *p2;
      while (used_size < len) {
        p1 = reinterpret_cast<const T *>(src);
        p2 = reinterpret_cast<T *>(dst);
        if (*p1 > *p2) {
          std::memcpy(dst, src, type_size);
        }
        src += type_size;
        dst += type_size;
        used_size += type_size;
      }
    });
    return global;
  }

213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
  template<class T>
  static T GlobalSyncUpByMean(T& local) {
    T global = (T)0;
    Allreduce(reinterpret_cast<char*>(&local),
              sizeof(local), sizeof(local),
              reinterpret_cast<char*>(&global),
              [](const char* src, char* dst, int type_size, comm_size_t len) {
      comm_size_t used_size = 0;
      const T *p1;
      T *p2;
      while (used_size < len) {
        p1 = reinterpret_cast<const T *>(src);
        p2 = reinterpret_cast<T *>(dst);
        *p2 += *p1;
        src += type_size;
        dst += type_size;
        used_size += type_size;
      }
    });
    return static_cast<T>(global / num_machines_);
  }

  /*!
  * \brief Allreduce a vector with an element-wise additive reducer and store the result back in place
  * \param local Values of this machine; overwritten with the reduced values on return
  */
  template<class T>
  static void GlobalSum(std::vector<T>& local) {
    // The receive buffer must hold as many elements as the input: Allreduce writes the
    // result through the raw output pointer and the loop below reads local.size() elements.
    // (An empty vector here would hand Allreduce a buffer with no storage.)
    std::vector<T> global(local.size());
    Allreduce(reinterpret_cast<char*>(local.data()),
              static_cast<comm_size_t>(sizeof(T) * local.size()), sizeof(T),
              reinterpret_cast<char*>(global.data()),
              [](const char* src, char* dst, int type_size, comm_size_t len) {
      // src/dst are raw byte buffers of len bytes holding objects of type_size bytes each
      comm_size_t used_size = 0;
      const T *p1;
      T *p2;
      while (used_size < len) {
        p1 = reinterpret_cast<const T *>(src);
        p2 = reinterpret_cast<T *>(dst);
        *p2 += *p1;
        src += type_size;
        dst += type_size;
        used_size += type_size;
      }
    });
    // Copy element by element so that local keeps its existing storage.
    for (size_t i = 0; i < local.size(); ++i) {
      local[i] = global[i];
    }
  }

Guolin Ke's avatar
Guolin Ke committed
259
private:
Guolin Ke's avatar
Guolin Ke committed
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274

  // Private variants sharing the parameter list of the public block-based Allgather.
  // NOTE(review): the names suggest Bruck / recursive-doubling / ring implementations;
  // the definitions and the rule for choosing between them are not visible in this header.
  static void AllgatherBruck(char* input, const comm_size_t* block_start, const comm_size_t* block_len, char* output, comm_size_t all_size);

  static void AllgatherRecursiveDoubling(char* input, const comm_size_t* block_start, const comm_size_t* block_len, char* output, comm_size_t all_size);

  static void AllgatherRing(char* input, const comm_size_t* block_start, const comm_size_t* block_len, char* output, comm_size_t all_size);

  // Private variants sharing the parameter list of the public ReduceScatter.
  // NOTE(review): presumably recursive-halving and ring implementations - confirm in the source file.
  static void ReduceScatterRecursiveHalving(char* input, comm_size_t input_size, int type_size,
                                            const comm_size_t* block_start, const comm_size_t* block_len, char* output, comm_size_t output_size,
                                            const ReduceFunction& reducer);

  static void ReduceScatterRing(char* input, comm_size_t input_size, int type_size,
                                const comm_size_t* block_start, const comm_size_t* block_len, char* output, comm_size_t output_size,
                                const ReduceFunction& reducer);

Guolin Ke's avatar
Guolin Ke committed
275
  /*! \brief Number of all machines */
276
  static THREAD_LOCAL int num_machines_;
Guolin Ke's avatar
Guolin Ke committed
277
  /*! \brief Rank of local machine */
278
  static THREAD_LOCAL int rank_;
Guolin Ke's avatar
Guolin Ke committed
279
  /*! \brief The network interface, provide send/recv functions  */
280
  static THREAD_LOCAL std::unique_ptr<Linkers> linkers_;
Guolin Ke's avatar
Guolin Ke committed
281
  /*! \brief Bruck map for all gather algorithm*/
282
  static THREAD_LOCAL BruckMap bruck_map_;
Guolin Ke's avatar
Guolin Ke committed
283
  /*! \brief Recursive halving map for reduce scatter */
284
  static THREAD_LOCAL RecursiveHalvingMap recursive_halving_map_;
Guolin Ke's avatar
Guolin Ke committed
285
  /*! \brief Buffer to store block start index */
Guolin Ke's avatar
Guolin Ke committed
286
  static THREAD_LOCAL std::vector<comm_size_t> block_start_;
Guolin Ke's avatar
Guolin Ke committed
287
  /*! \brief Buffer to store block size */
Guolin Ke's avatar
Guolin Ke committed
288
  static THREAD_LOCAL std::vector<comm_size_t> block_len_;
Guolin Ke's avatar
Guolin Ke committed
289
  /*! \brief Buffer  */
290
  static THREAD_LOCAL std::vector<char> buffer_;
Guolin Ke's avatar
Guolin Ke committed
291
  /*! \brief Size of buffer_ */
Guolin Ke's avatar
Guolin Ke committed
292
  static THREAD_LOCAL comm_size_t buffer_size_;
ww's avatar
ww committed
293
  /*! \brief Funcs*/
294
295
  static THREAD_LOCAL ReduceScatterFunction reduce_scatter_ext_fun_;
  static THREAD_LOCAL AllgatherFunction allgather_ext_fun_;
Guolin Ke's avatar
Guolin Ke committed
296
297
298
299
300
301
302
303
304
305
306
307
};

/*! \brief Return the rank of this machine as stored in the static member rank_ */
inline int Network::rank() {
  return rank_;
}

/*! \brief Return the total number of machines as stored in the static member num_machines_ */
inline int Network::num_machines() {
  return num_machines_;
}

}  // namespace LightGBM

Guolin Ke's avatar
Guolin Ke committed
308
#endif   // LIGHTGBM_NETWORK_H_