network.h 11 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
3
4
5
6
7
8
9
10
#ifndef LIGHTGBM_NETWORK_H_
#define LIGHTGBM_NETWORK_H_

#include <LightGBM/utils/log.h>

#include <LightGBM/meta.h>
#include <LightGBM/config.h>

#include <cstring>
#include <functional>
#include <vector>
Guolin Ke's avatar
Guolin Ke committed
11
#include <memory>
Guolin Ke's avatar
Guolin Ke committed
12
13
14
15
16
17

namespace LightGBM {

/*! \brief forward declaration */
class Linkers;

Qiwei Ye's avatar
Qiwei Ye committed
18
/*! \brief The network structure for all_gather */
Guolin Ke's avatar
Guolin Ke committed
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
// Communication schedule for the all-gather operation: for every communication
// round it records the rank data comes in from and the rank data goes out to.
class BruckMap {
public:
  /*! \brief The communication times for one all gather operation (0 until set) */
  int k = 0;
  /*! \brief in_ranks[i] means the incoming rank on i-th communication */
  std::vector<int> in_ranks;
  /*! \brief out_ranks[i] means the out rank on i-th communication */
  std::vector<int> out_ranks;
  /*! \brief Default constructor (defined outside this header) */
  BruckMap();
  /*!
  * \brief Construct from a single integer (defined outside this header)
  * \param n NOTE(review): presumably the number of communication rounds - confirm against the definition
  */
  explicit BruckMap(int n);
  /*!
  * \brief Create the object of bruck map
  * \param rank Rank of this machine
  * \param num_machines The total number of machines
  * \return The object of bruck map
  */
  static BruckMap Construct(int rank, int num_machines);
};

Guolin Ke's avatar
Guolin Ke committed
38
39
40
41
42
43
44
45
46
47
48
49
50
51
/*!
* \brief Node type in the recursive halving algorithm.
*        When the number of machines is not a power of 2, the machines need to be grouped
*        into a power-of-2 number of groups, and each group can have at most 2 machines.
*        If a group has only 1 machine, that machine is a normal node.
*        If a group has 2 machines, the group has two types of nodes, one of which is the leader.
*        The leader represents its group and communicates with the others.
*/
enum RecursiveHalvingNodeType {
  Normal,  // normal node: its group has only 1 machine
  GroupLeader,  // leader of a group when the number of machines in the group is 2
  Other  // non-leader machine in a group
};

Guolin Ke's avatar
Guolin Ke committed
52
53
54
55
56
/*! \brief Network structure for recursive halving algorithm */
class RecursiveHalvingMap {
public:
  /*! \brief Communication times for one recursize halving algorithm  */
  int k;
Guolin Ke's avatar
Guolin Ke committed
57
58
59
60
  /*! \brief Node type */
  RecursiveHalvingNodeType type;
  bool is_power_of_2;
  int neighbor;
Guolin Ke's avatar
Guolin Ke committed
61
62
63
64
65
66
67
68
69
70
71
72
73
  /*! \brief ranks[i] means the machines that will communicate with on i-th communication*/
  std::vector<int> ranks;
  /*! \brief  send_block_start[i] means send block start index at i-th communication*/
  std::vector<int> send_block_start;
  /*! \brief  send_block_start[i] means send block size at i-th communication*/
  std::vector<int> send_block_len;
  /*! \brief  send_block_start[i] means recv block start index at i-th communication*/
  std::vector<int> recv_block_start;
  /*! \brief  send_block_start[i] means recv block size  at i-th communication*/
  std::vector<int> recv_block_len;

  RecursiveHalvingMap();

Guolin Ke's avatar
Guolin Ke committed
74
  RecursiveHalvingMap(int k, RecursiveHalvingNodeType _type, bool _is_power_of_2);
Guolin Ke's avatar
Guolin Ke committed
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91

  /*!
  * \brief Create the object of recursive halving map
  * \param rank Rank of this machine
  * \param num_machines The total number of machines
  * \return The object of recursive halving map
  */
  static RecursiveHalvingMap Construct(int rank, int num_machines);
};

/*! \brief A static class that contains some collective communication algorithm */
class Network {
public:
  /*!
  * \brief Initialize
  * \param config Config of network setting
  */
Guolin Ke's avatar
Guolin Ke committed
92
  static void Init(Config config);
93
94
95
96
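  /*!
  * \brief Initialize with externally supplied collective functions
  * \param num_machines The total number of machines
  * \param rank Rank of this machine
  * \param reduce_scatter_ext_fun External function used for reduce scatter
  * \param allgather_ext_fun External function used for all gather
  */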
  static void Init(int num_machines, int rank, ReduceScatterFunction reduce_scatter_ext_fun, AllgatherFunction allgather_ext_fun);
Guolin Ke's avatar
Guolin Ke committed
97
98
99
100
101
102
103
104
  /*! \brief Free this static class */
  static void Dispose();
  /*! \brief Get rank of this machine */
  static inline int rank();
  /*! \brief Get total number of machines */
  static inline int num_machines();

  /*!
Qiwei Ye's avatar
Qiwei Ye committed
105
  * \brief Perform all_reduce. If the data size is small,
Guolin Ke's avatar
Guolin Ke committed
106
107
108
109
110
111
112
  *        will perform AllreduceByAllGather; otherwise will call ReduceScatter followed by Allgather
  * \param input Input data
  * \param input_size The size of input data
  * \param type_size The size of one object in the reduce function
  * \param output Output result
  * \param reducer Reduce function
  */
Guolin Ke's avatar
Guolin Ke committed
113
114
  static void Allreduce(char* input, comm_size_t input_size, int type_size,
                        char* output, const ReduceFunction& reducer);
Guolin Ke's avatar
Guolin Ke committed
115
116

  /*!
Qiwei Ye's avatar
Qiwei Ye committed
117
  * \brief Perform all_reduce by using all_gather. It can be used to reduce communication time when data is small
Guolin Ke's avatar
Guolin Ke committed
118
119
  * \param input Input data
  * \param input_size The size of input data
Guolin Ke's avatar
Guolin Ke committed
120
  * \param type_size The size of one object in the reduce function
Guolin Ke's avatar
Guolin Ke committed
121
122
123
  * \param output Output result
  * \param reducer Reduce function
  */
Guolin Ke's avatar
Guolin Ke committed
124
125
  static void AllreduceByAllGather(char* input, comm_size_t input_size, int type_size, char* output,
                                   const ReduceFunction& reducer);
Guolin Ke's avatar
Guolin Ke committed
126
127

  /*!
Qiwei Ye's avatar
Qiwei Ye committed
128
129
130
  * \brief Performing all_gather by using bruck algorithm. 
           Communication times is O(log(n)), and communication cost is O(send_size * number_machine)
  *        It can be used when all nodes have same input size.
Guolin Ke's avatar
Guolin Ke committed
131
132
133
134
  * \param input Input data
  * \param send_size The size of input data
  * \param output Output result
  */
Guolin Ke's avatar
Guolin Ke committed
135
  static void Allgather(char* input, comm_size_t send_size, char* output);
Guolin Ke's avatar
Guolin Ke committed
136
137

  /*!
Qiwei Ye's avatar
Qiwei Ye committed
138
139
140
  * \brief Performing all_gather by using bruck algorithm. 
           Communication times is O(log(n)), and communication cost is O(all_size)
  *        It can be used when nodes have different input size.
Guolin Ke's avatar
Guolin Ke committed
141
142
143
144
  * \param input Input data
  * \param block_start The block start for different machines
  * \param block_len The block size for different machines
  * \param output Output result
Guolin Ke's avatar
Guolin Ke committed
145
  * \param all_size The size of output data
Guolin Ke's avatar
Guolin Ke committed
146
  */
Guolin Ke's avatar
Guolin Ke committed
147
  static void Allgather(char* input, const comm_size_t* block_start, const comm_size_t* block_len, char* output, comm_size_t all_size);
Guolin Ke's avatar
Guolin Ke committed
148
149

  /*!
Qiwei Ye's avatar
Qiwei Ye committed
150
151
  * \brief Perform reduce scatter by using recursive halving algorithm. 
           Communication times is O(log(n)), and communication cost is O(input_size)
Guolin Ke's avatar
Guolin Ke committed
152
153
  * \param input Input data
  * \param input_size The size of input data
Guolin Ke's avatar
Guolin Ke committed
154
  * \param type_size The size of one object in the reduce function
Guolin Ke's avatar
Guolin Ke committed
155
156
157
  * \param block_start The block start for different machines
  * \param block_len The block size for different machines
  * \param output Output result
Guolin Ke's avatar
Guolin Ke committed
158
  * \param output_size size of output data
Guolin Ke's avatar
Guolin Ke committed
159
160
  * \param reducer Reduce function
  */
Guolin Ke's avatar
Guolin Ke committed
161
162
163
  static void ReduceScatter(char* input, comm_size_t input_size, int type_size,
                            const comm_size_t* block_start, const comm_size_t* block_len, char* output, comm_size_t output_size,
                            const ReduceFunction& reducer);
Guolin Ke's avatar
Guolin Ke committed
164

Guolin Ke's avatar
Guolin Ke committed
165
166
167
168
169
170
  template<class T>
  static T GlobalSyncUpByMin(T& local) {
    T global = local;
    Allreduce(reinterpret_cast<char*>(&local),
              sizeof(local), sizeof(local),
              reinterpret_cast<char*>(&global),
Guolin Ke's avatar
Guolin Ke committed
171
172
              [] (const char* src, char* dst, int type_size, comm_size_t len) {
      comm_size_t used_size = 0;
Guolin Ke's avatar
Guolin Ke committed
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
      const T *p1;
      T *p2;
      while (used_size < len) {
        p1 = reinterpret_cast<const T *>(src);
        p2 = reinterpret_cast<T *>(dst);
        if (*p1 < *p2) {
          std::memcpy(dst, src, type_size);
        }
        src += type_size;
        dst += type_size;
        used_size += type_size;
      }
    });
    return global;
  }

  template<class T>
  static T GlobalSyncUpByMax(T& local) {
    T global = local;
    Allreduce(reinterpret_cast<char*>(&local),
              sizeof(local), sizeof(local),
              reinterpret_cast<char*>(&global),
Guolin Ke's avatar
Guolin Ke committed
195
196
              [] (const char* src, char* dst, int type_size, comm_size_t len) {
      comm_size_t used_size = 0;
Guolin Ke's avatar
Guolin Ke committed
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
      const T *p1;
      T *p2;
      while (used_size < len) {
        p1 = reinterpret_cast<const T *>(src);
        p2 = reinterpret_cast<T *>(dst);
        if (*p1 > *p2) {
          std::memcpy(dst, src, type_size);
        }
        src += type_size;
        dst += type_size;
        used_size += type_size;
      }
    });
    return global;
  }

213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
  /*!
  * \brief Run Allreduce on a single value with a summing reducer, then divide
  *        the result by the number of machines
  * \param local Value of this machine; it is only read here, but is taken by
  *        non-const reference because Allreduce expects a mutable char* input
  * \return The Allreduce output divided by num_machines_, converted back to T
  *         (for integral T the division truncates)
  */
  template<class T>
  static T GlobalSyncUpByMean(T& local) {
    // Named cast instead of a C-style cast: same conversion, but greppable and
    // unable to silently degrade into a reinterpret_cast.
    T global = static_cast<T>(0);
    Allreduce(reinterpret_cast<char*>(&local),
              sizeof(local), sizeof(local),
              reinterpret_cast<char*>(&global),
              [](const char* src, char* dst, int type_size, comm_size_t len) {
      // Add every source element onto the matching destination element.
      for (comm_size_t used_size = 0; used_size < len; used_size += type_size) {
        const T* p1 = reinterpret_cast<const T*>(src + used_size);
        T* p2 = reinterpret_cast<T*>(dst + used_size);
        *p2 += *p1;
      }
    });
    return static_cast<T>(global / num_machines_);
  }

  template<class T>
  static void GlobalSum(std::vector<T>& local) {
237
    std::vector<T> global(local.size(), 0);
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
    Allreduce(reinterpret_cast<char*>(local.data()),
              static_cast<comm_size_t>(sizeof(T) * local.size()), sizeof(T),
              reinterpret_cast<char*>(global.data()),
              [](const char* src, char* dst, int type_size, comm_size_t len) {
      comm_size_t used_size = 0;
      const T *p1;
      T *p2;
      while (used_size < len) {
        p1 = reinterpret_cast<const T *>(src);
        p2 = reinterpret_cast<T *>(dst);
        *p2 += *p1;
        src += type_size;
        dst += type_size;
        used_size += type_size;
      }
    });
    for (size_t i = 0; i < local.size(); ++i) {
      local[i] = global[i];
    }
  }

Guolin Ke's avatar
Guolin Ke committed
259
private:
Guolin Ke's avatar
Guolin Ke committed
260
261
262
263
264
265
266
267
268
269
270
271
272
273
  static void AllgatherBruck(char* input, const comm_size_t* block_start, const comm_size_t* block_len, char* output, comm_size_t all_size);

  static void AllgatherRecursiveDoubling(char* input, const comm_size_t* block_start, const comm_size_t* block_len, char* output, comm_size_t all_size);

  static void AllgatherRing(char* input, const comm_size_t* block_start, const comm_size_t* block_len, char* output, comm_size_t all_size);

  static void ReduceScatterRecursiveHalving(char* input, comm_size_t input_size, int type_size,
                                            const comm_size_t* block_start, const comm_size_t* block_len, char* output, comm_size_t output_size,
                                            const ReduceFunction& reducer);

  static void ReduceScatterRing(char* input, comm_size_t input_size, int type_size,
                                const comm_size_t* block_start, const comm_size_t* block_len, char* output, comm_size_t output_size,
                                const ReduceFunction& reducer);

Guolin Ke's avatar
Guolin Ke committed
274
  /*! \brief Number of all machines */
275
  static THREAD_LOCAL int num_machines_;
Guolin Ke's avatar
Guolin Ke committed
276
  /*! \brief Rank of local machine */
277
  static THREAD_LOCAL int rank_;
Guolin Ke's avatar
Guolin Ke committed
278
  /*! \brief The network interface, provide send/recv functions  */
279
  static THREAD_LOCAL std::unique_ptr<Linkers> linkers_;
Guolin Ke's avatar
Guolin Ke committed
280
  /*! \brief Bruck map for all gather algorithm*/
281
  static THREAD_LOCAL BruckMap bruck_map_;
Guolin Ke's avatar
Guolin Ke committed
282
  /*! \brief Recursive halving map for reduce scatter */
283
  static THREAD_LOCAL RecursiveHalvingMap recursive_halving_map_;
Guolin Ke's avatar
Guolin Ke committed
284
  /*! \brief Buffer to store block start index */
Guolin Ke's avatar
Guolin Ke committed
285
  static THREAD_LOCAL std::vector<comm_size_t> block_start_;
Guolin Ke's avatar
Guolin Ke committed
286
  /*! \brief Buffer to store block size */
Guolin Ke's avatar
Guolin Ke committed
287
  static THREAD_LOCAL std::vector<comm_size_t> block_len_;
Guolin Ke's avatar
Guolin Ke committed
288
  /*! \brief Buffer  */
289
  static THREAD_LOCAL std::vector<char> buffer_;
Guolin Ke's avatar
Guolin Ke committed
290
  /*! \brief Size of buffer_ */
Guolin Ke's avatar
Guolin Ke committed
291
  static THREAD_LOCAL comm_size_t buffer_size_;
ww's avatar
ww committed
292
  /*! \brief Externally supplied reduce-scatter and all-gather functions */
293
294
  static THREAD_LOCAL ReduceScatterFunction reduce_scatter_ext_fun_;
  static THREAD_LOCAL AllgatherFunction allgather_ext_fun_;
Guolin Ke's avatar
Guolin Ke committed
295
296
297
298
299
300
301
302
303
304
305
306
};

/*! \brief Accessor returning the stored rank_ value */
inline int Network::rank() { return rank_; }

/*! \brief Accessor returning the stored num_machines_ value */
inline int Network::num_machines() { return num_machines_; }

}  // namespace LightGBM

Guolin Ke's avatar
Guolin Ke committed
307
#endif   // LIGHTGBM_NETWORK_H_