network.cpp 10.8 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
3
4
5
6
7
8
9
10
11
#include <LightGBM/network.h>

#include <LightGBM/utils/common.h>

#include <algorithm>
#include <cstdlib>
#include <cstring>
#include <memory>
#include <vector>

#include "linkers.h"

namespace LightGBM {

// Static member definitions for Network.
// All state is THREAD_LOCAL so each thread can run an independent
// distributed-learning session.

// Total number of machines in the cluster (1 == standalone mode).
THREAD_LOCAL int Network::num_machines_ = 1;
// 0-based rank of this machine within the cluster.
THREAD_LOCAL int Network::rank_ = 0;
// Socket linkers used for point-to-point communication between ranks.
THREAD_LOCAL std::unique_ptr<Linkers> Network::linkers_;
// Communication topology for the Bruck all-gather algorithm.
THREAD_LOCAL BruckMap Network::bruck_map_;
// Communication topology for recursive-halving reduce-scatter.
THREAD_LOCAL RecursiveHalvingMap Network::recursive_halving_map_;
// Per-rank block byte offsets used by the collective operations.
THREAD_LOCAL std::vector<comm_size_t> Network::block_start_;
// Per-rank block byte lengths used by the collective operations.
THREAD_LOCAL std::vector<comm_size_t> Network::block_len_;
// Current capacity (in bytes) of the scratch buffer below.
THREAD_LOCAL comm_size_t Network::buffer_size_ = 0;
// Scratch buffer for collectives that need temporary storage.
THREAD_LOCAL std::vector<char> Network::buffer_;
// Optional externally supplied collective implementations (e.g. MPI-backed).
// When non-null they override the built-in socket algorithms.
THREAD_LOCAL ReduceScatterFunction Network::reduce_scatter_ext_fun_ = nullptr;
THREAD_LOCAL AllgatherFunction Network::allgather_ext_fun_ = nullptr;
Guolin Ke's avatar
Guolin Ke committed
25
26

void Network::Init(NetworkConfig config) {
27
28
29
30
31
32
  if (config.num_machines > 1) {
    linkers_.reset(new Linkers(config));
    rank_ = linkers_->rank();
    num_machines_ = linkers_->num_machines();
    bruck_map_ = linkers_->bruck_map();
    recursive_halving_map_ = linkers_->recursive_halving_map();
Guolin Ke's avatar
Guolin Ke committed
33
34
    block_start_ = std::vector<comm_size_t>(num_machines_);
    block_len_ = std::vector<comm_size_t>(num_machines_);
35
36
37
38
    buffer_size_ = 1024 * 1024;
    buffer_.resize(buffer_size_);
    Log::Info("Local rank: %d, total number of machines: %d", rank_, num_machines_);
  }
Guolin Ke's avatar
Guolin Ke committed
39
40
}

41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
void Network::Init(int num_machines, int rank,
                   ReduceScatterFunction reduce_scatter_ext_fun, AllgatherFunction allgather_ext_fun) {
  if (num_machines > 1) {
    rank_ = rank;
    num_machines_ = num_machines;
    block_start_ = std::vector<comm_size_t>(num_machines_);
    block_len_ = std::vector<comm_size_t>(num_machines_);
    buffer_size_ = 1024 * 1024;
    buffer_.resize(buffer_size_);
    reduce_scatter_ext_fun_ = reduce_scatter_ext_fun;
    allgather_ext_fun_ = allgather_ext_fun;
    Log::Info("Local rank: %d, total number of machines: %d", rank_, num_machines_);
  }
}

Guolin Ke's avatar
Guolin Ke committed
56
void Network::Dispose() {
57
58
59
  num_machines_ = 1;
  rank_ = 0;
  linkers_.reset(new Linkers());
60
61
  reduce_scatter_ext_fun_ = nullptr;
  allgather_ext_fun_ = nullptr;
Guolin Ke's avatar
Guolin Ke committed
62
63
}

Guolin Ke's avatar
Guolin Ke committed
64
void Network::Allreduce(char* input, comm_size_t input_size, int type_size, char* output, const ReduceFunction& reducer) {
65
66
67
  if (num_machines_ <= 1) {
    Log::Fatal("Please initilize the network interface first");
  }
Guolin Ke's avatar
Guolin Ke committed
68
  comm_size_t count = input_size / type_size;
Guolin Ke's avatar
Guolin Ke committed
69
70
  // if small package or small count , do it by all gather.(reduce the communication times.)
  if (count < num_machines_ || input_size < 4096) {
Guolin Ke's avatar
Guolin Ke committed
71
    AllreduceByAllGather(input, input_size, type_size, output, reducer);
Guolin Ke's avatar
Guolin Ke committed
72
73
74
    return;
  }
  // assign the blocks to every rank.
Guolin Ke's avatar
Guolin Ke committed
75
  comm_size_t step = (count + num_machines_ - 1) / num_machines_;
Guolin Ke's avatar
Guolin Ke committed
76
77
78
79
80
  if (step < 1) {
    step = 1;
  }
  block_start_[0] = 0;
  for (int i = 0; i < num_machines_ - 1; ++i) {
Guolin Ke's avatar
Guolin Ke committed
81
    block_len_[i] = std::min<comm_size_t>(step * type_size, input_size - block_start_[i]);
Guolin Ke's avatar
Guolin Ke committed
82
83
84
85
    block_start_[i + 1] = block_start_[i] + block_len_[i];
  }
  block_len_[num_machines_ - 1] = input_size - block_start_[num_machines_ - 1];
  // do reduce scatter
Guolin Ke's avatar
Guolin Ke committed
86
  ReduceScatter(input, input_size, type_size, block_start_.data(), block_len_.data(), output, input_size, reducer);
Guolin Ke's avatar
Guolin Ke committed
87
  // do all gather
Guolin Ke's avatar
Guolin Ke committed
88
  Allgather(output, block_start_.data(), block_len_.data(), output, input_size);
Guolin Ke's avatar
Guolin Ke committed
89
90
}

Guolin Ke's avatar
Guolin Ke committed
91
void Network::AllreduceByAllGather(char* input, comm_size_t input_size, int type_size, char* output, const ReduceFunction& reducer) {
92
93
94
  if (num_machines_ <= 1) {
    Log::Fatal("Please initilize the network interface first");
  }
Guolin Ke's avatar
Guolin Ke committed
95
  // assign blocks
Guolin Ke's avatar
Guolin Ke committed
96
  comm_size_t all_size = input_size * num_machines_;
Guolin Ke's avatar
Guolin Ke committed
97
98
99
100
101
102
103
104
105
  block_start_[0] = 0;
  block_len_[0] = input_size;
  for (int i = 1; i < num_machines_; ++i) {
    block_start_[i] = block_start_[i - 1] + block_len_[i - 1];
    block_len_[i] = input_size;
  }
  // need use buffer here, since size of "output" is smaller than size after all gather
  if (input_size*num_machines_ > buffer_size_) {
    buffer_size_ = input_size*num_machines_;
Guolin Ke's avatar
Guolin Ke committed
106
    buffer_.resize(buffer_size_);
Guolin Ke's avatar
Guolin Ke committed
107
108
  }

Guolin Ke's avatar
Guolin Ke committed
109
  Allgather(input, block_start_.data(), block_len_.data(), buffer_.data(), all_size);
Guolin Ke's avatar
Guolin Ke committed
110
  for (int i = 1; i < num_machines_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
111
    reducer(buffer_.data() + block_start_[i], buffer_.data() + block_start_[0], type_size, input_size);
Guolin Ke's avatar
Guolin Ke committed
112
113
  }
  // copy back
Guolin Ke's avatar
Guolin Ke committed
114
  std::memcpy(output, buffer_.data(), input_size);
Guolin Ke's avatar
Guolin Ke committed
115
116
}

Guolin Ke's avatar
Guolin Ke committed
117
void Network::Allgather(char* input, comm_size_t send_size, char* output) {
118
119
120
121
  if (num_machines_ <= 1) {
    Log::Fatal("Please initilize the network interface first");
  }
  if (num_machines_ <= 1) { return; }
Guolin Ke's avatar
Guolin Ke committed
122
123
124
125
126
127
128
129
  // assign blocks
  block_start_[0] = 0;
  block_len_[0] = send_size;
  for (int i = 1; i < num_machines_; ++i) {
    block_start_[i] = block_start_[i - 1] + block_len_[i - 1];
    block_len_[i] = send_size;
  }
  // start all gather
Guolin Ke's avatar
Guolin Ke committed
130
  Allgather(input, block_start_.data(), block_len_.data(), output, send_size * num_machines_);
Guolin Ke's avatar
Guolin Ke committed
131
132
}

Guolin Ke's avatar
Guolin Ke committed
133
void Network::Allgather(char* input, const comm_size_t* block_start, const comm_size_t* block_len, char* output, comm_size_t all_size) {
134
135
136
  if (num_machines_ <= 1) {
    Log::Fatal("Please initilize the network interface first");
  }
137
  if (allgather_ext_fun_ != nullptr) {
Guolin Ke's avatar
Guolin Ke committed
138
    return allgather_ext_fun_(input, block_len[rank_], block_start, block_len, num_machines_, output, all_size);
139
  }
Guolin Ke's avatar
Guolin Ke committed
140
141
142
143
144
145
146
147
148
149
150
151
152
  const comm_size_t kRecursiveDoublingThreshold = 1024 * 1024; // 1MB
  const comm_size_t kBruckThreshold = 512 * 1024; // 512KB
  const bool is_power_of2 = (num_machines_ & (num_machines_ - 1)) == 0;
  if (is_power_of2 && all_size < kRecursiveDoublingThreshold) {
    AllgatherRecursiveDoubling(input, block_start, block_len, output, all_size);
  } else if (all_size < kBruckThreshold) {
    AllgatherBruck(input, block_start, block_len, output, all_size);
  } else {
    AllgatherRing(input, block_start, block_len, output, all_size);
  }
}

void Network::AllgatherBruck(char* input, const comm_size_t* block_start, const comm_size_t* block_len, char* output, comm_size_t all_size) {
Guolin Ke's avatar
Guolin Ke committed
153
  comm_size_t write_pos = 0;
Guolin Ke's avatar
Guolin Ke committed
154
155
156
157
158
159
  // use output as receive buffer
  std::memcpy(output, input, block_len[rank_]);
  write_pos += block_len[rank_];
  int accumulated_block = 1;
  for (int i = 0; i < bruck_map_.k; ++i) {
    // get current local block size
Guolin Ke's avatar
Guolin Ke committed
160
    int cur_block_size = std::min(1 << i, num_machines_ - accumulated_block);
Guolin Ke's avatar
Guolin Ke committed
161
162
163
164
    // get out rank
    int out_rank = bruck_map_.out_ranks[i];
    // get in rank
    int in_rank = bruck_map_.in_ranks[i];
Guolin Ke's avatar
Guolin Ke committed
165
    // get send information
Guolin Ke's avatar
Guolin Ke committed
166
    comm_size_t need_send_len = 0;
Guolin Ke's avatar
Guolin Ke committed
167
    // get recv information
Guolin Ke's avatar
Guolin Ke committed
168
    comm_size_t need_recv_len = 0;
Guolin Ke's avatar
Guolin Ke committed
169
    for (int j = 0; j < cur_block_size; ++j) {
Guolin Ke's avatar
Guolin Ke committed
170
171
      need_send_len += block_len[(rank_ + j) % num_machines_];
      need_recv_len += block_len[(rank_ + accumulated_block + j) % num_machines_];
Guolin Ke's avatar
Guolin Ke committed
172
173
    }
    // send and recv at same time
Guolin Ke's avatar
Guolin Ke committed
174
175
    linkers_->SendRecv(out_rank, output, need_send_len, in_rank, output + write_pos, need_recv_len);
    write_pos += need_recv_len;
Guolin Ke's avatar
Guolin Ke committed
176
177
178
179
180
181
182
183
    accumulated_block += cur_block_size;
  }
  // rotate in-place
  std::reverse<char*>(output, output + all_size);
  std::reverse<char*>(output, output + block_start[rank_]);
  std::reverse<char*>(output + block_start[rank_], output + all_size);
}

Guolin Ke's avatar
Guolin Ke committed
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
void Network::AllgatherRecursiveDoubling(char* input, const comm_size_t* block_start, const comm_size_t* block_len, char* output, comm_size_t) {
  // use output as receive buffer
  std::memcpy(output + block_start[rank_], input, block_len[rank_]);
  for (int i = 0; i < bruck_map_.k; ++i) {
    // get current local block size
    int cur_step = 1 << i;
    const int vgroup = rank_ / cur_step;
    const int vrank = vgroup * cur_step;
    int target = rank_ + cur_step;
    int target_vrank = (vgroup + 1) * cur_step;
    if (vgroup & 1) {
      target = rank_ - cur_step;
      target_vrank = (vgroup - 1) * cur_step;
    }
    // get send information
    comm_size_t need_send_len = 0;
    // get recv information
    comm_size_t need_recv_len = 0;
    for (int j = 0; j < cur_step; ++j) {
      need_send_len += block_len[(vrank + j)];
      need_recv_len += block_len[(target_vrank + j)];
    }
    // send and recv at same time
    linkers_->SendRecv(target, output + block_start[vrank], need_send_len, 
                       target, output + block_start[target_vrank], need_recv_len);
  }
}

/*!
* \brief Ring all-gather: each rank forwards one block per step to its
*        successor while receiving one from its predecessor, so after
*        num_machines_ - 1 steps every rank holds every block.
* \param input Local input buffer (block_len[rank_] bytes)
* \param block_start Per-rank byte offsets into output
* \param block_len Per-rank byte lengths
* \param output Output buffer, also used as receive buffer
*/
void Network::AllgatherRing(char* input, const comm_size_t* block_start, const comm_size_t* block_len, char* output, comm_size_t) {
  // place this rank's own block directly at its final position
  std::memcpy(output + block_start[rank_], input, block_len[rank_]);
  const int send_to = (rank_ + 1) % num_machines_;
  const int recv_from = (rank_ - 1 + num_machines_) % num_machines_;
  int send_block = rank_;
  int recv_block = recv_from;
  for (int step = 1; step < num_machines_; ++step) {
    // forward the most recently completed block while receiving the next one
    linkers_->SendRecv(send_to, output + block_start[send_block], block_len[send_block],
                       recv_from, output + block_start[recv_block], block_len[recv_block]);
    send_block = (send_block - 1 + num_machines_) % num_machines_;
    recv_block = (recv_block - 1 + num_machines_) % num_machines_;
  }
}

Guolin Ke's avatar
Guolin Ke committed
228
void Network::ReduceScatter(char* input, comm_size_t input_size, int type_size, const comm_size_t* block_start, const comm_size_t* block_len, char* output, comm_size_t output_size, const ReduceFunction& reducer) {
229
230
231
  if (num_machines_ <= 1) {
    Log::Fatal("Please initilize the network interface first");
  }
232
  if (reduce_scatter_ext_fun_ != nullptr) {
Guolin Ke's avatar
Guolin Ke committed
233
    return reduce_scatter_ext_fun_(input, input_size, type_size, block_start, block_len, num_machines_, output, output_size, reducer);
ww's avatar
ww committed
234
  }
Guolin Ke's avatar
Guolin Ke committed
235
236
237
238
239
  if (recursive_halving_map_.need_pairwise) {
    for (int i = 1; i < num_machines_; ++i) {
      int out_rank = (rank_ + i) % num_machines_;
      int in_rank = (rank_ - i + num_machines_) % num_machines_;
      linkers_->SendRecv(out_rank, input + block_start[out_rank], block_len[out_rank], in_rank, output, block_len[rank_]);
Guolin Ke's avatar
Guolin Ke committed
240
      reducer(output, input + block_start[rank_], type_size, block_len[rank_]);
Guolin Ke's avatar
Guolin Ke committed
241
    }
Guolin Ke's avatar
Guolin Ke committed
242
  } else {
Guolin Ke's avatar
Guolin Ke committed
243
244
245
    for (int i = 0; i < recursive_halving_map_.k; ++i) {
      // get target
      int target = recursive_halving_map_.ranks[i];
Guolin Ke's avatar
Guolin Ke committed
246
247
      comm_size_t send_block_start = recursive_halving_map_.send_block_start[i];
      comm_size_t recv_block_start = recursive_halving_map_.recv_block_start[i];
Guolin Ke's avatar
Guolin Ke committed
248
      // get send information
Guolin Ke's avatar
Guolin Ke committed
249
      comm_size_t send_size = 0;
Guolin Ke's avatar
Guolin Ke committed
250
251
252
253
      for (int j = 0; j < recursive_halving_map_.send_block_len[i]; ++j) {
        send_size += block_len[send_block_start + j];
      }
      // get recv information
Guolin Ke's avatar
Guolin Ke committed
254
      comm_size_t need_recv_cnt = 0;
Guolin Ke's avatar
Guolin Ke committed
255
256
257
258
259
260
      for (int j = 0; j < recursive_halving_map_.recv_block_len[i]; ++j) {
        need_recv_cnt += block_len[recv_block_start + j];
      }
      // send and recv at same time
      linkers_->SendRecv(target, input + block_start[send_block_start], send_size, target, output, need_recv_cnt);
      // reduce
Guolin Ke's avatar
Guolin Ke committed
261
      reducer(output, input + block_start[recv_block_start], type_size, need_recv_cnt);
Guolin Ke's avatar
Guolin Ke committed
262
263
264
265
266
267
268
    }
  }
  // copy result
  std::memcpy(output, input + block_start[rank_], block_len[rank_]);
}

}  // namespace LightGBM