#ifndef LIGHTGBM_NETWORK_H_
#define LIGHTGBM_NETWORK_H_

#include <LightGBM/utils/log.h>

#include <LightGBM/meta.h>
#include <LightGBM/config.h>

#include <cstring>
#include <functional>
#include <memory>
#include <vector>

namespace LightGBM {

/*!
* \brief Forward declaration of Linkers, the network interface that provides
*        send/recv functions (held by Network::linkers_ via std::unique_ptr).
*        Its definition is not included by this header.
*/
class Linkers;

/*! \brief The network structure for all_gather (Bruck algorithm) */
class BruckMap {
public:
  /*! \brief The number of communications for one all gather operation */
  int k;
  /*! \brief in_ranks[i] is the incoming rank on the i-th communication */
  std::vector<int> in_ranks;
  /*! \brief out_ranks[i] is the outgoing rank on the i-th communication */
  std::vector<int> out_ranks;
  /*! \brief Default constructor */
  BruckMap();
  /*!
  * \brief Constructor
  * \param n Number of communications (presumably used for k and to size
  *          in_ranks/out_ranks -- TODO confirm against the definition)
  */
  explicit BruckMap(int n);
  /*!
  * \brief Create the object of bruck map
  * \param rank Rank of this machine
  * \param num_machines The total number of machines
  * \return The object of bruck map
  */
  static BruckMap Construct(int rank, int num_machines);
};

/*! \brief Network structure for recursive halving algorithm */
class RecursiveHalvingMap {
public:
  /*!
  * \brief Flag passed to the constructor as in_need_pairwise
  *        (presumably: whether an extra pairwise exchange is required,
  *        e.g. for a non-power-of-two machine count -- TODO confirm)
  */
  bool need_pairwise;
  /*! \brief Number of communications for one recursive halving algorithm */
  int k;
  /*! \brief ranks[i] is the machine to communicate with on the i-th communication */
  std::vector<int> ranks;
  /*! \brief send_block_start[i] is the send block start index at the i-th communication */
  std::vector<int> send_block_start;
  /*! \brief send_block_len[i] is the send block size at the i-th communication */
  std::vector<int> send_block_len;
  /*! \brief recv_block_start[i] is the recv block start index at the i-th communication */
  std::vector<int> recv_block_start;
  /*! \brief recv_block_len[i] is the recv block size at the i-th communication */
  std::vector<int> recv_block_len;

  /*! \brief Default constructor */
  RecursiveHalvingMap();

  /*!
  * \brief Constructor
  * \param k Number of communications (presumably also sizes the per-communication
  *          vectors -- TODO confirm against the definition)
  * \param in_need_pairwise Value for need_pairwise (TODO confirm)
  */
  RecursiveHalvingMap(int k, bool in_need_pairwise);

  /*!
  * \brief Create the object of recursive halving map
  * \param rank Rank of this machine
  * \param num_machines The total number of machines
  * \return The object of recursive halving map
  */
  static RecursiveHalvingMap Construct(int rank, int num_machines);
};

/*! \brief A static class that contains some collective communication algorithm */
class Network {
public:
  /*!
  * \brief Initialize
  * \param config Config of network setting
  */
  static void Init(NetworkConfig config);
76
77
78
79
  /*!
  * \brief Initialize
  */
  static void Init(int num_machines, int rank, ReduceScatterFunction reduce_scatter_ext_fun, AllgatherFunction allgather_ext_fun);
Guolin Ke's avatar
Guolin Ke committed
80
81
82
83
84
85
86
87
  /*! \brief Free this static class */
  static void Dispose();
  /*! \brief Get rank of this machine */
  static inline int rank();
  /*! \brief Get total number of machines */
  static inline int num_machines();

  /*!
Qiwei Ye's avatar
Qiwei Ye committed
88
  * \brief Perform all_reduce. if data size is small,
Guolin Ke's avatar
Guolin Ke committed
89
90
91
92
93
94
95
           will perform AllreduceByAllGather, else with call ReduceScatter followed allgather
  * \param input Input data
  * \param input_size The size of input data
  * \param type_size The size of one object in the reduce function
  * \param output Output result
  * \param reducer Reduce function
  */
Guolin Ke's avatar
Guolin Ke committed
96
97
  static void Allreduce(char* input, comm_size_t input_size, int type_size,
                        char* output, const ReduceFunction& reducer);
Guolin Ke's avatar
Guolin Ke committed
98
99

  /*!
Qiwei Ye's avatar
Qiwei Ye committed
100
  * \brief Perform all_reduce by using all_gather. it can be use to reduce communication time when data is small
Guolin Ke's avatar
Guolin Ke committed
101
102
  * \param input Input data
  * \param input_size The size of input data
Guolin Ke's avatar
Guolin Ke committed
103
  * \param type_size The size of one object in the reduce function
Guolin Ke's avatar
Guolin Ke committed
104
105
106
  * \param output Output result
  * \param reducer Reduce function
  */
Guolin Ke's avatar
Guolin Ke committed
107
108
  static void AllreduceByAllGather(char* input, comm_size_t input_size, int type_size, char* output,
                                   const ReduceFunction& reducer);
Guolin Ke's avatar
Guolin Ke committed
109
110

  /*!
Qiwei Ye's avatar
Qiwei Ye committed
111
112
113
  * \brief Performing all_gather by using bruck algorithm. 
           Communication times is O(log(n)), and communication cost is O(send_size * number_machine)
  *        It can be used when all nodes have same input size.
Guolin Ke's avatar
Guolin Ke committed
114
115
116
117
  * \param input Input data
  * \param send_size The size of input data
  * \param output Output result
  */
Guolin Ke's avatar
Guolin Ke committed
118
  static void Allgather(char* input, comm_size_t send_size, char* output);
Guolin Ke's avatar
Guolin Ke committed
119
120

  /*!
Qiwei Ye's avatar
Qiwei Ye committed
121
122
123
  * \brief Performing all_gather by using bruck algorithm. 
           Communication times is O(log(n)), and communication cost is O(all_size)
  *        It can be used when nodes have different input size.
Guolin Ke's avatar
Guolin Ke committed
124
125
126
127
  * \param input Input data
  * \param block_start The block start for different machines
  * \param block_len The block size for different machines
  * \param output Output result
Guolin Ke's avatar
Guolin Ke committed
128
  * \param all_size The size of output data
Guolin Ke's avatar
Guolin Ke committed
129
  */
Guolin Ke's avatar
Guolin Ke committed
130
  static void Allgather(char* input, const comm_size_t* block_start, const comm_size_t* block_len, char* output, comm_size_t all_size);
Guolin Ke's avatar
Guolin Ke committed
131
132

  /*!
Qiwei Ye's avatar
Qiwei Ye committed
133
134
  * \brief Perform reduce scatter by using recursive halving algorithm. 
           Communication times is O(log(n)), and communication cost is O(input_size)
Guolin Ke's avatar
Guolin Ke committed
135
136
  * \param input Input data
  * \param input_size The size of input data
Guolin Ke's avatar
Guolin Ke committed
137
  * \param type_size The size of one object in the reduce function
Guolin Ke's avatar
Guolin Ke committed
138
139
140
  * \param block_start The block start for different machines
  * \param block_len The block size for different machines
  * \param output Output result
Guolin Ke's avatar
Guolin Ke committed
141
  * \param output_size size of output data
Guolin Ke's avatar
Guolin Ke committed
142
143
  * \param reducer Reduce function
  */
Guolin Ke's avatar
Guolin Ke committed
144
145
146
  static void ReduceScatter(char* input, comm_size_t input_size, int type_size,
                            const comm_size_t* block_start, const comm_size_t* block_len, char* output, comm_size_t output_size,
                            const ReduceFunction& reducer);
Guolin Ke's avatar
Guolin Ke committed
147

Guolin Ke's avatar
Guolin Ke committed
148
149
150
151
152
153
  template<class T>
  static T GlobalSyncUpByMin(T& local) {
    T global = local;
    Allreduce(reinterpret_cast<char*>(&local),
              sizeof(local), sizeof(local),
              reinterpret_cast<char*>(&global),
Guolin Ke's avatar
Guolin Ke committed
154
155
              [] (const char* src, char* dst, int type_size, comm_size_t len) {
      comm_size_t used_size = 0;
Guolin Ke's avatar
Guolin Ke committed
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
      const T *p1;
      T *p2;
      while (used_size < len) {
        p1 = reinterpret_cast<const T *>(src);
        p2 = reinterpret_cast<T *>(dst);
        if (*p1 < *p2) {
          std::memcpy(dst, src, type_size);
        }
        src += type_size;
        dst += type_size;
        used_size += type_size;
      }
    });
    return global;
  }

  template<class T>
  static T GlobalSyncUpByMax(T& local) {
    T global = local;
    Allreduce(reinterpret_cast<char*>(&local),
              sizeof(local), sizeof(local),
              reinterpret_cast<char*>(&global),
Guolin Ke's avatar
Guolin Ke committed
178
179
              [] (const char* src, char* dst, int type_size, comm_size_t len) {
      comm_size_t used_size = 0;
Guolin Ke's avatar
Guolin Ke committed
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
      const T *p1;
      T *p2;
      while (used_size < len) {
        p1 = reinterpret_cast<const T *>(src);
        p2 = reinterpret_cast<T *>(dst);
        if (*p1 > *p2) {
          std::memcpy(dst, src, type_size);
        }
        src += type_size;
        dst += type_size;
        used_size += type_size;
      }
    });
    return global;
  }

Guolin Ke's avatar
Guolin Ke committed
196
197
private:
  /*! \brief Number of all machines */
198
  static THREAD_LOCAL int num_machines_;
Guolin Ke's avatar
Guolin Ke committed
199
  /*! \brief Rank of local machine */
200
  static THREAD_LOCAL int rank_;
Guolin Ke's avatar
Guolin Ke committed
201
  /*! \brief The network interface, provide send/recv functions  */
202
  static THREAD_LOCAL std::unique_ptr<Linkers> linkers_;
Guolin Ke's avatar
Guolin Ke committed
203
  /*! \brief Bruck map for all gather algorithm*/
204
  static THREAD_LOCAL BruckMap bruck_map_;
Guolin Ke's avatar
Guolin Ke committed
205
  /*! \brief Recursive halving map for reduce scatter */
206
  static THREAD_LOCAL RecursiveHalvingMap recursive_halving_map_;
Guolin Ke's avatar
Guolin Ke committed
207
  /*! \brief Buffer to store block start index */
Guolin Ke's avatar
Guolin Ke committed
208
  static THREAD_LOCAL std::vector<comm_size_t> block_start_;
Guolin Ke's avatar
Guolin Ke committed
209
  /*! \brief Buffer to store block size */
Guolin Ke's avatar
Guolin Ke committed
210
  static THREAD_LOCAL std::vector<comm_size_t> block_len_;
Guolin Ke's avatar
Guolin Ke committed
211
  /*! \brief Buffer  */
212
  static THREAD_LOCAL std::vector<char> buffer_;
Guolin Ke's avatar
Guolin Ke committed
213
  /*! \brief Size of buffer_ */
Guolin Ke's avatar
Guolin Ke committed
214
  static THREAD_LOCAL comm_size_t buffer_size_;
ww's avatar
ww committed
215
  /*! \brief Funcs*/
216
217
  static THREAD_LOCAL ReduceScatterFunction reduce_scatter_ext_fun_;
  static THREAD_LOCAL AllgatherFunction allgather_ext_fun_;
Guolin Ke's avatar
Guolin Ke committed
218
219
220
221
222
223
224
225
226
227
228
229
};

/*! \brief Returns Network::rank_, the rank of the local machine */
inline int Network::rank() {
  return Network::rank_;
}

/*! \brief Returns Network::num_machines_, the number of all machines */
inline int Network::num_machines() {
  return Network::num_machines_;
}

}  // namespace LightGBM

#endif   // LightGBM_NETWORK_H_