network.h 8.53 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
3
4
5
6
7
8
9
10
#ifndef LIGHTGBM_NETWORK_H_
#define LIGHTGBM_NETWORK_H_

#include <LightGBM/utils/log.h>

#include <LightGBM/meta.h>
#include <LightGBM/config.h>

#include <cstring>
#include <functional>
#include <vector>
Guolin Ke's avatar
Guolin Ke committed
11
#include <memory>
Guolin Ke's avatar
Guolin Ke committed
12
13
14
15
16
17

namespace LightGBM {

/*!
* \brief Forward declaration of the transport layer used by Network.
*        This header only holds it through std::unique_ptr (Network::linkers_),
*        so the full definition is provided elsewhere.
*/
class Linkers;

Qiwei Ye's avatar
Qiwei Ye committed
18
/*! \brief The network structure for all_gather (Bruck algorithm) */
class BruckMap {
public:
  /*! \brief The communication times for one all gather operation */
  int k;
  /*! \brief in_ranks[i] means the incoming rank on the i-th communication */
  std::vector<int> in_ranks;
  /*! \brief out_ranks[i] means the outgoing rank on the i-th communication */
  std::vector<int> out_ranks;
  /*! \brief Default constructor */
  BruckMap();
  /*!
  * \brief Constructor
  * \param n Presumably the number of communication rounds (k) -- TODO confirm
  *        against the definition
  */
  explicit BruckMap(int n);
  /*!
  * \brief Create the object of bruck map
  * \param rank Rank of this machine
  * \param num_machines The total number of machines
  * \return The object of bruck map
  */
  static BruckMap Construct(int rank, int num_machines);
};

/*! \brief Network structure for recursive halving algorithm */
class RecursiveHalvingMap {
public:
  /*! \brief Whether the number of workers is a power of 2 */
  bool is_prof2;
  /*! \brief Communication times for one recursive halving operation */
  int k;
  /*!
  * \brief Number of workers minus a power of 2 (presumably the largest power
  *        of 2 not exceeding the worker count -- TODO confirm)
  */
  int num_remain;
  /*! \brief Virtual rank of this machine for the recursive halving algorithm */
  int virtual_rank;
  /*! \brief ranks[i] means the machine to communicate with on the i-th communication */
  std::vector<int> ranks;
  /*! \brief send_block_start[i] means the send block start index at the i-th communication */
  std::vector<int> send_block_start;
  /*! \brief send_block_len[i] means the send block size at the i-th communication */
  std::vector<int> send_block_len;
  /*! \brief recv_block_start[i] means the recv block start index at the i-th communication */
  std::vector<int> recv_block_start;
  /*! \brief recv_block_len[i] means the recv block size at the i-th communication */
  std::vector<int> recv_block_len;

  /*! \brief Default constructor */
  RecursiveHalvingMap();

  /*!
  * \brief Constructor; the arguments presumably initialize the members of the
  *        same name -- TODO confirm against the definition
  * \param k Communication times
  * \param num_remain Number of workers beyond the power of 2
  * \param virtual_rank Virtual rank of this machine
  * \param is_prof2 Whether the number of workers is a power of 2
  */
  RecursiveHalvingMap(int k, int num_remain, int virtual_rank, bool is_prof2);

  /*!
  * \brief Create the object of recursive halving map
  * \param rank Rank of this machine
  * \param num_machines The total number of machines
  * \return The object of recursive halving map
  */
  static RecursiveHalvingMap Construct(int rank, int num_machines);
};

/*! \brief A static class that contains some collective communication algorithm */
class Network {
public:
  /*!
  * \brief Initialize
  * \param config Config of network setting
  */
  static void Init(NetworkConfig config);
81
82
83
84
  /*!
  * \brief Initialize
  */
  static void Init(int num_machines, int rank, ReduceScatterFunction reduce_scatter_ext_fun, AllgatherFunction allgather_ext_fun);
Guolin Ke's avatar
Guolin Ke committed
85
86
87
88
89
90
91
92
  /*! \brief Free this static class */
  static void Dispose();
  /*! \brief Get rank of this machine */
  static inline int rank();
  /*! \brief Get total number of machines */
  static inline int num_machines();

  /*!
Qiwei Ye's avatar
Qiwei Ye committed
93
  * \brief Perform all_reduce. if data size is small,
Guolin Ke's avatar
Guolin Ke committed
94
95
96
97
98
99
100
           will perform AllreduceByAllGather, else with call ReduceScatter followed allgather
  * \param input Input data
  * \param input_size The size of input data
  * \param type_size The size of one object in the reduce function
  * \param output Output result
  * \param reducer Reduce function
  */
Guolin Ke's avatar
Guolin Ke committed
101
102
  static void Allreduce(char* input, comm_size_t input_size, int type_size,
                        char* output, const ReduceFunction& reducer);
Guolin Ke's avatar
Guolin Ke committed
103
104

  /*!
Qiwei Ye's avatar
Qiwei Ye committed
105
  * \brief Perform all_reduce by using all_gather. it can be use to reduce communication time when data is small
Guolin Ke's avatar
Guolin Ke committed
106
107
  * \param input Input data
  * \param input_size The size of input data
Guolin Ke's avatar
Guolin Ke committed
108
  * \param type_size The size of one object in the reduce function
Guolin Ke's avatar
Guolin Ke committed
109
110
111
  * \param output Output result
  * \param reducer Reduce function
  */
Guolin Ke's avatar
Guolin Ke committed
112
113
  static void AllreduceByAllGather(char* input, comm_size_t input_size, int type_size, char* output,
                                   const ReduceFunction& reducer);
Guolin Ke's avatar
Guolin Ke committed
114
115

  /*!
Qiwei Ye's avatar
Qiwei Ye committed
116
117
118
  * \brief Performing all_gather by using bruck algorithm. 
           Communication times is O(log(n)), and communication cost is O(send_size * number_machine)
  *        It can be used when all nodes have same input size.
Guolin Ke's avatar
Guolin Ke committed
119
120
121
122
  * \param input Input data
  * \param send_size The size of input data
  * \param output Output result
  */
Guolin Ke's avatar
Guolin Ke committed
123
  static void Allgather(char* input, comm_size_t send_size, char* output);
Guolin Ke's avatar
Guolin Ke committed
124
125

  /*!
Qiwei Ye's avatar
Qiwei Ye committed
126
127
128
  * \brief Performing all_gather by using bruck algorithm. 
           Communication times is O(log(n)), and communication cost is O(all_size)
  *        It can be used when nodes have different input size.
Guolin Ke's avatar
Guolin Ke committed
129
130
131
132
  * \param input Input data
  * \param block_start The block start for different machines
  * \param block_len The block size for different machines
  * \param output Output result
Guolin Ke's avatar
Guolin Ke committed
133
  * \param all_size The size of output data
Guolin Ke's avatar
Guolin Ke committed
134
  */
Guolin Ke's avatar
Guolin Ke committed
135
  static void Allgather(char* input, const comm_size_t* block_start, const comm_size_t* block_len, char* output, comm_size_t all_size);
Guolin Ke's avatar
Guolin Ke committed
136

Guolin Ke's avatar
Guolin Ke committed
137
138
139
140
141
142
143
  static void AllgatherBruck(char* input, const comm_size_t* block_start, const comm_size_t* block_len, char* output, comm_size_t all_size);

  static void AllgatherRecursiveDoubling(char* input, const comm_size_t* block_start, const comm_size_t* block_len, char* output, comm_size_t all_size);

  static void AllgatherRing(char* input, const comm_size_t* block_start, const comm_size_t* block_len, char* output, comm_size_t all_size);


Guolin Ke's avatar
Guolin Ke committed
144
  /*!
Qiwei Ye's avatar
Qiwei Ye committed
145
146
  * \brief Perform reduce scatter by using recursive halving algorithm. 
           Communication times is O(log(n)), and communication cost is O(input_size)
Guolin Ke's avatar
Guolin Ke committed
147
148
  * \param input Input data
  * \param input_size The size of input data
Guolin Ke's avatar
Guolin Ke committed
149
  * \param type_size The size of one object in the reduce function
Guolin Ke's avatar
Guolin Ke committed
150
151
152
  * \param block_start The block start for different machines
  * \param block_len The block size for different machines
  * \param output Output result
Guolin Ke's avatar
Guolin Ke committed
153
  * \param output_size size of output data
Guolin Ke's avatar
Guolin Ke committed
154
155
  * \param reducer Reduce function
  */
Guolin Ke's avatar
Guolin Ke committed
156
157
158
  static void ReduceScatter(char* input, comm_size_t input_size, int type_size,
                            const comm_size_t* block_start, const comm_size_t* block_len, char* output, comm_size_t output_size,
                            const ReduceFunction& reducer);
Guolin Ke's avatar
Guolin Ke committed
159

Guolin Ke's avatar
Guolin Ke committed
160
161
162
163
164
165
  template<class T>
  static T GlobalSyncUpByMin(T& local) {
    T global = local;
    Allreduce(reinterpret_cast<char*>(&local),
              sizeof(local), sizeof(local),
              reinterpret_cast<char*>(&global),
Guolin Ke's avatar
Guolin Ke committed
166
167
              [] (const char* src, char* dst, int type_size, comm_size_t len) {
      comm_size_t used_size = 0;
Guolin Ke's avatar
Guolin Ke committed
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
      const T *p1;
      T *p2;
      while (used_size < len) {
        p1 = reinterpret_cast<const T *>(src);
        p2 = reinterpret_cast<T *>(dst);
        if (*p1 < *p2) {
          std::memcpy(dst, src, type_size);
        }
        src += type_size;
        dst += type_size;
        used_size += type_size;
      }
    });
    return global;
  }

  template<class T>
  static T GlobalSyncUpByMax(T& local) {
    T global = local;
    Allreduce(reinterpret_cast<char*>(&local),
              sizeof(local), sizeof(local),
              reinterpret_cast<char*>(&global),
Guolin Ke's avatar
Guolin Ke committed
190
191
              [] (const char* src, char* dst, int type_size, comm_size_t len) {
      comm_size_t used_size = 0;
Guolin Ke's avatar
Guolin Ke committed
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
      const T *p1;
      T *p2;
      while (used_size < len) {
        p1 = reinterpret_cast<const T *>(src);
        p2 = reinterpret_cast<T *>(dst);
        if (*p1 > *p2) {
          std::memcpy(dst, src, type_size);
        }
        src += type_size;
        dst += type_size;
        used_size += type_size;
      }
    });
    return global;
  }

Guolin Ke's avatar
Guolin Ke committed
208
209
private:
  /*! \brief Number of all machines */
210
  static THREAD_LOCAL int num_machines_;
Guolin Ke's avatar
Guolin Ke committed
211
  /*! \brief Rank of local machine */
212
  static THREAD_LOCAL int rank_;
Guolin Ke's avatar
Guolin Ke committed
213
  /*! \brief The network interface, provide send/recv functions  */
214
  static THREAD_LOCAL std::unique_ptr<Linkers> linkers_;
Guolin Ke's avatar
Guolin Ke committed
215
  /*! \brief Bruck map for all gather algorithm*/
216
  static THREAD_LOCAL BruckMap bruck_map_;
Guolin Ke's avatar
Guolin Ke committed
217
  /*! \brief Recursive halving map for reduce scatter */
218
  static THREAD_LOCAL RecursiveHalvingMap recursive_halving_map_;
Guolin Ke's avatar
Guolin Ke committed
219
  /*! \brief Buffer to store block start index */
Guolin Ke's avatar
Guolin Ke committed
220
  static THREAD_LOCAL std::vector<comm_size_t> block_start_;
Guolin Ke's avatar
Guolin Ke committed
221
  /*! \brief Buffer to store block size */
Guolin Ke's avatar
Guolin Ke committed
222
  static THREAD_LOCAL std::vector<comm_size_t> block_len_;
Guolin Ke's avatar
Guolin Ke committed
223
  /*! \brief Buffer  */
224
  static THREAD_LOCAL std::vector<char> buffer_;
Guolin Ke's avatar
Guolin Ke committed
225
  /*! \brief Size of buffer_ */
Guolin Ke's avatar
Guolin Ke committed
226
  static THREAD_LOCAL comm_size_t buffer_size_;
ww's avatar
ww committed
227
  /*! \brief Funcs*/
228
229
  static THREAD_LOCAL ReduceScatterFunction reduce_scatter_ext_fun_;
  static THREAD_LOCAL AllgatherFunction allgather_ext_fun_;
Guolin Ke's avatar
Guolin Ke committed
230
231
232
233
234
235
236
237
238
239
240
241
};

// Out-of-line bodies of the inline accessors declared in Network; each one
// just returns a thread-local static member (presumably populated by Init).
inline int Network::rank() { return Network::rank_; }

inline int Network::num_machines() { return Network::num_machines_; }

}  // namespace LightGBM

Guolin Ke's avatar
Guolin Ke committed
242
#endif   // LIGHTGBM_NETWORK_H_