network.h 8.14 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
3
4
5
6
7
8
9
10
#ifndef LIGHTGBM_NETWORK_H_
#define LIGHTGBM_NETWORK_H_

#include <LightGBM/utils/log.h>

#include <LightGBM/meta.h>
#include <LightGBM/config.h>

#include <cstring>
#include <functional>
#include <vector>
Guolin Ke's avatar
Guolin Ke committed
11
#include <memory>
Guolin Ke's avatar
Guolin Ke committed
12
13
14
15
16
17

namespace LightGBM {

/*! \brief forward declaration */
class Linkers;

Qiwei Ye's avatar
Qiwei Ye committed
18
/*! \brief The network structure for all_gather */
Guolin Ke's avatar
Guolin Ke committed
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
class BruckMap {
public:
  /*! \brief Number of communication steps for one all_gather operation */
  int k;
  /*! \brief in_ranks[i] is the rank this machine receives from on the i-th communication */
  std::vector<int> in_ranks;
  /*! \brief out_ranks[i] is the rank this machine sends to on the i-th communication */
  std::vector<int> out_ranks;
  /*! \brief Default constructor (defined outside this header) */
  BruckMap();
  /*!
  * \brief Construct from a single integer
  * \param n NOTE(review): presumably the number of communication steps used to
  *          size in_ranks/out_ranks; confirm against the implementation
  */
  explicit BruckMap(int n);
  /*!
  * \brief Create the bruck map for one machine
  * \param rank Rank of this machine
  * \param num_machines The total number of machines
  * \return The bruck map object
  */
  static BruckMap Construct(int rank, int num_machines);
};

/*! \brief Network structure for recursive halving algorithm */
class RecursiveHalvingMap {
public:
  /*!
  * \brief Flag stored alongside the map.
  *        NOTE(review): presumably tells whether a pairwise exchange step is
  *        required (e.g. when the machine count is not a power of two);
  *        confirm against the implementation.
  */
  bool need_pairwise;
  /*! \brief Number of communication steps for one recursive halving run */
  int k;
  /*! \brief ranks[i] is the machine this one communicates with on the i-th communication */
  std::vector<int> ranks;
  /*! \brief send_block_start[i] is the start index of the block sent on the i-th communication */
  std::vector<int> send_block_start;
  /*! \brief send_block_len[i] is the size of the block sent on the i-th communication */
  std::vector<int> send_block_len;
  /*! \brief recv_block_start[i] is the start index of the block received on the i-th communication */
  std::vector<int> recv_block_start;
  /*! \brief recv_block_len[i] is the size of the block received on the i-th communication */
  std::vector<int> recv_block_len;

  /*! \brief Default constructor (defined outside this header) */
  RecursiveHalvingMap();

  /*!
  * \brief Construct from a step count and the pairwise flag
  * \param k NOTE(review): presumably the value stored in the member k; confirm
  * \param in_need_pairwise NOTE(review): presumably the value stored in need_pairwise; confirm
  */
  RecursiveHalvingMap(int k, bool in_need_pairwise);

  /*!
  * \brief Create the recursive halving map for one machine
  * \param rank Rank of this machine
  * \param num_machines The total number of machines
  * \return The recursive halving map object
  */
  static RecursiveHalvingMap Construct(int rank, int num_machines);
};

/*! \brief A static class that contains some collective communication algorithm */
class Network {
public:
  /*!
  * \brief Initialize
  * \param config Config of network setting
  */
  static void Init(NetworkConfig config);
  /*! \brief Free this static class */
  static void Dispose();
  /*! \brief Get rank of this machine */
  static inline int rank();
  /*! \brief Get total number of machines */
  static inline int num_machines();

  /*!
Qiwei Ye's avatar
Qiwei Ye committed
84
  * \brief Perform all_reduce. if data size is small,
Guolin Ke's avatar
Guolin Ke committed
85
86
87
88
89
90
91
           will perform AllreduceByAllGather, else with call ReduceScatter followed allgather
  * \param input Input data
  * \param input_size The size of input data
  * \param type_size The size of one object in the reduce function
  * \param output Output result
  * \param reducer Reduce function
  */
Guolin Ke's avatar
Guolin Ke committed
92
93
  static void Allreduce(char* input, comm_size_t input_size, int type_size,
                        char* output, const ReduceFunction& reducer);
Guolin Ke's avatar
Guolin Ke committed
94
95

  /*!
Qiwei Ye's avatar
Qiwei Ye committed
96
  * \brief Perform all_reduce by using all_gather. it can be use to reduce communication time when data is small
Guolin Ke's avatar
Guolin Ke committed
97
98
  * \param input Input data
  * \param input_size The size of input data
Guolin Ke's avatar
Guolin Ke committed
99
  * \param type_size The size of one object in the reduce function
Guolin Ke's avatar
Guolin Ke committed
100
101
102
  * \param output Output result
  * \param reducer Reduce function
  */
Guolin Ke's avatar
Guolin Ke committed
103
104
  static void AllreduceByAllGather(char* input, comm_size_t input_size, int type_size, char* output,
                                   const ReduceFunction& reducer);
Guolin Ke's avatar
Guolin Ke committed
105
106

  /*!
Qiwei Ye's avatar
Qiwei Ye committed
107
108
109
  * \brief Performing all_gather by using bruck algorithm. 
           Communication times is O(log(n)), and communication cost is O(send_size * number_machine)
  *        It can be used when all nodes have same input size.
Guolin Ke's avatar
Guolin Ke committed
110
111
112
113
  * \param input Input data
  * \param send_size The size of input data
  * \param output Output result
  */
Guolin Ke's avatar
Guolin Ke committed
114
  static void Allgather(char* input, comm_size_t send_size, char* output);
Guolin Ke's avatar
Guolin Ke committed
115
116

  /*!
Qiwei Ye's avatar
Qiwei Ye committed
117
118
119
  * \brief Performing all_gather by using bruck algorithm. 
           Communication times is O(log(n)), and communication cost is O(all_size)
  *        It can be used when nodes have different input size.
Guolin Ke's avatar
Guolin Ke committed
120
121
122
123
  * \param input Input data
  * \param block_start The block start for different machines
  * \param block_len The block size for different machines
  * \param output Output result
Guolin Ke's avatar
Guolin Ke committed
124
  * \param all_size The size of output data
Guolin Ke's avatar
Guolin Ke committed
125
  */
Guolin Ke's avatar
Guolin Ke committed
126
  static void Allgather(char* input, const comm_size_t* block_start, const comm_size_t* block_len, char* output, comm_size_t all_size);
Guolin Ke's avatar
Guolin Ke committed
127
128

  /*!
Qiwei Ye's avatar
Qiwei Ye committed
129
130
  * \brief Perform reduce scatter by using recursive halving algorithm. 
           Communication times is O(log(n)), and communication cost is O(input_size)
Guolin Ke's avatar
Guolin Ke committed
131
132
  * \param input Input data
  * \param input_size The size of input data
Guolin Ke's avatar
Guolin Ke committed
133
  * \param type_size The size of one object in the reduce function
Guolin Ke's avatar
Guolin Ke committed
134
135
136
  * \param block_start The block start for different machines
  * \param block_len The block size for different machines
  * \param output Output result
Guolin Ke's avatar
Guolin Ke committed
137
  * \param output_size size of output data
Guolin Ke's avatar
Guolin Ke committed
138
139
  * \param reducer Reduce function
  */
Guolin Ke's avatar
Guolin Ke committed
140
141
142
  static void ReduceScatter(char* input, comm_size_t input_size, int type_size,
                            const comm_size_t* block_start, const comm_size_t* block_len, char* output, comm_size_t output_size,
                            const ReduceFunction& reducer);
Guolin Ke's avatar
Guolin Ke committed
143

Guolin Ke's avatar
Guolin Ke committed
144
145
146
147
148
149
  template<class T>
  static T GlobalSyncUpByMin(T& local) {
    T global = local;
    Allreduce(reinterpret_cast<char*>(&local),
              sizeof(local), sizeof(local),
              reinterpret_cast<char*>(&global),
Guolin Ke's avatar
Guolin Ke committed
150
151
              [] (const char* src, char* dst, int type_size, comm_size_t len) {
      comm_size_t used_size = 0;
Guolin Ke's avatar
Guolin Ke committed
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
      const T *p1;
      T *p2;
      while (used_size < len) {
        p1 = reinterpret_cast<const T *>(src);
        p2 = reinterpret_cast<T *>(dst);
        if (*p1 < *p2) {
          std::memcpy(dst, src, type_size);
        }
        src += type_size;
        dst += type_size;
        used_size += type_size;
      }
    });
    return global;
  }

  template<class T>
  static T GlobalSyncUpByMax(T& local) {
    T global = local;
    Allreduce(reinterpret_cast<char*>(&local),
              sizeof(local), sizeof(local),
              reinterpret_cast<char*>(&global),
Guolin Ke's avatar
Guolin Ke committed
174
175
              [] (const char* src, char* dst, int type_size, comm_size_t len) {
      comm_size_t used_size = 0;
Guolin Ke's avatar
Guolin Ke committed
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
      const T *p1;
      T *p2;
      while (used_size < len) {
        p1 = reinterpret_cast<const T *>(src);
        p2 = reinterpret_cast<T *>(dst);
        if (*p1 > *p2) {
          std::memcpy(dst, src, type_size);
        }
        src += type_size;
        dst += type_size;
        used_size += type_size;
      }
    });
    return global;
  }
ww's avatar
ww committed
191
192
193
  /*! \brief set variables and function ptrs */
  static void SetRank(int rank) { rank_ = rank;}
  static void SetNumMachines(int num_machines) { num_machines_ = num_machines; }
194
195
  static void SetReduceScatterFunction(ReduceScatterFunction reduce_scatter_ext_fun) { reduce_scatter_ext_fun_ = reduce_scatter_ext_fun; }
  static void SetAllgatherFunction(AllgatherFunction allgather_ext_fun) { allgather_ext_fun_ = allgather_ext_fun; }
Guolin Ke's avatar
Guolin Ke committed
196

Guolin Ke's avatar
Guolin Ke committed
197
198
private:
  /*! \brief Number of all machines */
199
  static THREAD_LOCAL int num_machines_;
Guolin Ke's avatar
Guolin Ke committed
200
  /*! \brief Rank of local machine */
201
  static THREAD_LOCAL int rank_;
Guolin Ke's avatar
Guolin Ke committed
202
  /*! \brief The network interface, provide send/recv functions  */
203
  static THREAD_LOCAL std::unique_ptr<Linkers> linkers_;
Guolin Ke's avatar
Guolin Ke committed
204
  /*! \brief Bruck map for all gather algorithm*/
205
  static THREAD_LOCAL BruckMap bruck_map_;
Guolin Ke's avatar
Guolin Ke committed
206
  /*! \brief Recursive halving map for reduce scatter */
207
  static THREAD_LOCAL RecursiveHalvingMap recursive_halving_map_;
Guolin Ke's avatar
Guolin Ke committed
208
  /*! \brief Buffer to store block start index */
Guolin Ke's avatar
Guolin Ke committed
209
  static THREAD_LOCAL std::vector<comm_size_t> block_start_;
Guolin Ke's avatar
Guolin Ke committed
210
  /*! \brief Buffer to store block size */
Guolin Ke's avatar
Guolin Ke committed
211
  static THREAD_LOCAL std::vector<comm_size_t> block_len_;
Guolin Ke's avatar
Guolin Ke committed
212
  /*! \brief Buffer  */
213
  static THREAD_LOCAL std::vector<char> buffer_;
Guolin Ke's avatar
Guolin Ke committed
214
  /*! \brief Size of buffer_ */
Guolin Ke's avatar
Guolin Ke committed
215
  static THREAD_LOCAL comm_size_t buffer_size_;
ww's avatar
ww committed
216
  /*! \brief Funcs*/
217
218
  static THREAD_LOCAL ReduceScatterFunction reduce_scatter_ext_fun_;
  static THREAD_LOCAL AllgatherFunction allgather_ext_fun_;
Guolin Ke's avatar
Guolin Ke committed
219
220
221
222
223
224
225
226
227
228
229
230
};

/*! \brief Return the value currently stored in rank_ */
inline int Network::rank() {
  return rank_;
}

/*! \brief Return the value currently stored in num_machines_ */
inline int Network::num_machines() {
  return num_machines_;
}

}  // namespace LightGBM

Guolin Ke's avatar
Guolin Ke committed
231
#endif   // LIGHTGBM_NETWORK_H_