/*! \file network.h */
#ifndef LIGHTGBM_NETWORK_H_
#define LIGHTGBM_NETWORK_H_

#include <LightGBM/utils/log.h>

#include <LightGBM/meta.h>
#include <LightGBM/config.h>

#include <functional>
#include <vector>

namespace LightGBM {

/*! \brief forward declaration */
class Linkers;

/*! \brief The network structure for all_gather */
class BruckMap {
public:
  /*! \brief The number of communication steps for one all gather operation */
  int k;
  /*! \brief in_ranks[i] is the incoming rank on the i-th communication */
  std::vector<int> in_ranks;
  /*! \brief out_ranks[i] is the outgoing rank on the i-th communication */
  std::vector<int> out_ranks;
  /*! \brief Default constructor; defined outside this header */
  BruckMap();
  /*!
  * \brief Construct from an integer; defined outside this header
  * \param n NOTE(review): presumably the number of communication steps used
  *          to size in_ranks / out_ranks -- confirm against the definition
  */
  explicit BruckMap(int n);
  /*!
  * \brief Create the object of bruck map
  * \param rank Rank of this machine
  * \param num_machines The total number of machines
  * \return The object of bruck map
  */
  static BruckMap Construct(int rank, int num_machines);
};


/*!
* \brief Node type in the recursive halving algorithm.
*        When the number of machines is not a power of 2, the machines need to be
*        grouped so that the number of groups is a power of 2.
*        Each group can then have at most 2 machines.
*        If a group has only 1 machine, that machine is a normal node.
*        If a group has 2 machines, the group has two types of nodes, one of which is the leader.
*        The leader represents the group and communicates with the others.
*/
enum RecursiveHalvingNodeType {
  Normal,  // normal node: its group contains only this one machine
  GroupLeader,  // leader of a group when the number of machines in the group is 2
  Other  // non-leader machine in a group
};

/*! \brief Network structure for recursive halving algorithm */
class RecursiveHalvingMap {
public:
  /*! \brief Number of communication steps for one recursive halving operation */
  int k;
  /*! \brief Node type */
  RecursiveHalvingNodeType type;
  /*! \brief Neighbor, only used for non-normal node */
  int neighbor;
  /*! \brief ranks[i] is the machine to communicate with on the i-th communication */
  std::vector<int> ranks;
  /*! \brief send_block_start[i] is the send block start index on the i-th communication */
  std::vector<int> send_block_start;
  /*! \brief send_block_len[i] is the send block size on the i-th communication */
  std::vector<int> send_block_len;
  /*! \brief recv_block_start[i] is the recv block start index on the i-th communication */
  std::vector<int> recv_block_start;
  /*! \brief recv_block_len[i] is the recv block size on the i-th communication */
  std::vector<int> recv_block_len;

  /*! \brief Default constructor; defined outside this header */
  RecursiveHalvingMap();

  /*!
  * \brief Construct from a node type and an integer; defined outside this header
  * \param _type Node type
  * \param n NOTE(review): presumably the number of communication steps used
  *          to size the per-step vectors -- confirm against the definition
  */
  RecursiveHalvingMap(RecursiveHalvingNodeType _type, int n);

  /*!
  * \brief Create the object of recursive halving map
  * \param rank Rank of this machine
  * \param num_machines The total number of machines
  * \return The object of recursive halving map
  */
  static RecursiveHalvingMap Construct(int rank, int num_machines);
};

/*! \brief A static class that contains some collective communication algorithm */
class Network {
public:
  /*!
  * \brief Initialize
  * \param config Config of network setting
  */
  static void Init(NetworkConfig config);
  /*! \brief Free this static class */
  static void Dispose();
  /*! \brief Get rank of this machine */
  static inline int rank();
  /*! \brief Get total number of machines */
  static inline int num_machines();

  /*!
  * \brief Perform all_reduce. If the data size is small, this will perform
  *        AllreduceByAllGather; otherwise it will call ReduceScatter followed by Allgather.
  * \param input Input data
  * \param input_size The size of input data
  *        (NOTE(review): presumably in bytes, since the buffers are char* -- confirm)
  * \param type_size The size of one object in the reduce function
  * \param output Output result
  * \param reducer Reduce function
  */
  static void Allreduce(char* input, int input_size, int type_size,
    char* output, const ReduceFunction& reducer);

  /*!
  * \brief Perform all_reduce by using all_gather.
  *        It can be used to reduce communication time when the data is small.
  * \param input Input data
  * \param input_size The size of input data
  * \param output Output result
  * \param reducer Reduce function
  */
  static void AllreduceByAllGather(char* input, int input_size, char* output,
    const ReduceFunction& reducer);

  /*!
  * \brief Perform all_gather by using the Bruck algorithm.
  *        Communication times is O(log(n)), and communication cost is O(send_size * number_machine).
  *        It can be used when all nodes have the same input size.
  * \param input Input data
  * \param send_size The size of input data
  * \param output Output result
  */
  static void Allgather(char* input, int send_size, char* output);

  /*!
  * \brief Perform all_gather by using the Bruck algorithm.
  *        Communication times is O(log(n)), and communication cost is O(all_size).
  *        It can be used when nodes have different input sizes.
  * \param input Input data
  * \param all_size The size of input data
  * \param block_start The block start for different machines
  * \param block_len The block size for different machines
  * \param output Output result
  */
  static void Allgather(char* input, int all_size, int* block_start,
    int* block_len, char* output);

  /*!
  * \brief Perform reduce scatter by using the recursive halving algorithm.
  *        Communication times is O(log(n)), and communication cost is O(input_size).
  * \param input Input data
  * \param input_size The size of input data
  * \param block_start The block start for different machines
  * \param block_len The block size for different machines
  * \param output Output result
  * \param reducer Reduce function
  */
  static void ReduceScatter(char* input, int input_size,
    int* block_start, int* block_len, char* output,
    const ReduceFunction& reducer);

private:
  /*! \brief Number of all machines */
  static int num_machines_;
  /*! \brief Rank of local machine */
  static int rank_;
  /*! \brief The network interface, provides send/recv functions */
  static Linkers *linkers_;
  /*! \brief Bruck map for the all gather algorithm */
  static BruckMap bruck_map_;
  /*! \brief Recursive halving map for reduce scatter */
  static RecursiveHalvingMap recursive_halving_map_;
  /*! \brief Buffer to store block start index */
  static int* block_start_;
  /*! \brief Buffer to store block size */
  static int* block_len_;
  /*! \brief Buffer (NOTE(review): its use is not visible in this header -- see the implementation) */
  static char* buffer_;
  /*! \brief Size of buffer_ */
  static int buffer_size_;
};

/*! \brief Accessor for the static rank_ member (rank of the local machine). */
inline int Network::rank() { return Network::rank_; }

/*! \brief Accessor for the static num_machines_ member (number of all machines). */
inline int Network::num_machines() { return Network::num_machines_; }

}  // namespace LightGBM

#endif   // LIGHTGBM_NETWORK_H_