network.cpp 7.61 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
3
4
5
6
7
8
9
10
11
#include <LightGBM/network.h>

#include <LightGBM/utils/common.h>

#include "linkers.h"

#include <cstring>
#include <cstdlib>

namespace LightGBM {

Hui Xue's avatar
Hui Xue committed
12
// Static member definitions.
// NOTE(review): THREAD_LOCAL is a project macro, presumably expanding to a
// thread-local storage specifier so each thread has its own network state --
// confirm in the LightGBM headers.

// Topology: the defaults describe a single-machine (non-distributed) run.
THREAD_LOCAL int Network::num_machines_ = 1;
THREAD_LOCAL int Network::rank_ = 0;

// Connection object plus the communication schedules copied from it in Init().
THREAD_LOCAL std::unique_ptr<Linkers> Network::linkers_;
THREAD_LOCAL BruckMap Network::bruck_map_;
THREAD_LOCAL RecursiveHalvingMap Network::recursive_halving_map_;

// Per-rank block offsets/lengths, sized to num_machines_ in Init().
THREAD_LOCAL std::vector<comm_size_t> Network::block_start_;
THREAD_LOCAL std::vector<comm_size_t> Network::block_len_;
// Staging buffer used by AllreduceByAllGather; allocated in Init() and grown
// on demand. Zero-initialised explicitly (it would be anyway, having static
// or thread storage duration).
THREAD_LOCAL comm_size_t Network::buffer_size_ = 0;
THREAD_LOCAL std::vector<char> Network::buffer_;

// Optional externally supplied collectives; when set, ReduceScatter/Allgather
// delegate to them instead of running the built-in algorithms.
THREAD_LOCAL ReduceScatterFunction Network::reduce_scatter_ext_fun_ = nullptr;
THREAD_LOCAL AllgatherFunction Network::allgather_ext_fun_ = nullptr;
ww's avatar
ww committed
24

Guolin Ke's avatar
Guolin Ke committed
25
26

// Set up this thread's network state from `config`.
// A configuration with num_machines <= 1 is a no-op: existing state is left
// as is. Otherwise a Linkers instance is created and the rank, machine count
// and communication maps are copied from it. The block-layout vectors get one
// zeroed slot per machine, and a 1 MiB staging buffer is allocated.
void Network::Init(NetworkConfig config) {
  if (config.num_machines <= 1) {
    return;  // single machine: nothing to connect
  }

  linkers_.reset(new Linkers(config));
  rank_ = linkers_->rank();
  num_machines_ = linkers_->num_machines();
  bruck_map_ = linkers_->bruck_map();
  recursive_halving_map_ = linkers_->recursive_halving_map();

  // One offset/length slot per machine for the collective operations.
  block_start_ = std::vector<comm_size_t>(num_machines_, 0);
  block_len_ = std::vector<comm_size_t>(num_machines_, 0);

  const comm_size_t kInitialBufferBytes = 1024 * 1024;
  buffer_size_ = kInitialBufferBytes;
  buffer_.resize(buffer_size_);

  Log::Info("Local rank: %d, total number of machines: %d", rank_, num_machines_);
}

void Network::Dispose() {
42
43
44
  num_machines_ = 1;
  rank_ = 0;
  linkers_.reset(new Linkers());
Guolin Ke's avatar
Guolin Ke committed
45
46
}

Guolin Ke's avatar
Guolin Ke committed
47
void Network::Allreduce(char* input, comm_size_t input_size, int type_size, char* output, const ReduceFunction& reducer) {
48
49
50
  if (num_machines_ <= 1) {
    Log::Fatal("Please initilize the network interface first");
  }
Guolin Ke's avatar
Guolin Ke committed
51
  comm_size_t count = input_size / type_size;
Guolin Ke's avatar
Guolin Ke committed
52
53
  // if small package or small count , do it by all gather.(reduce the communication times.)
  if (count < num_machines_ || input_size < 4096) {
Guolin Ke's avatar
Guolin Ke committed
54
    AllreduceByAllGather(input, input_size, type_size, output, reducer);
Guolin Ke's avatar
Guolin Ke committed
55
56
57
    return;
  }
  // assign the blocks to every rank.
Guolin Ke's avatar
Guolin Ke committed
58
  comm_size_t step = (count + num_machines_ - 1) / num_machines_;
Guolin Ke's avatar
Guolin Ke committed
59
60
61
62
63
  if (step < 1) {
    step = 1;
  }
  block_start_[0] = 0;
  for (int i = 0; i < num_machines_ - 1; ++i) {
Guolin Ke's avatar
Guolin Ke committed
64
    block_len_[i] = std::min<comm_size_t>(step * type_size, input_size - block_start_[i]);
Guolin Ke's avatar
Guolin Ke committed
65
66
67
68
    block_start_[i + 1] = block_start_[i] + block_len_[i];
  }
  block_len_[num_machines_ - 1] = input_size - block_start_[num_machines_ - 1];
  // do reduce scatter
Guolin Ke's avatar
Guolin Ke committed
69
  ReduceScatter(input, input_size, type_size, block_start_.data(), block_len_.data(), output, input_size, reducer);
Guolin Ke's avatar
Guolin Ke committed
70
  // do all gather
Guolin Ke's avatar
Guolin Ke committed
71
  Allgather(output, block_start_.data(), block_len_.data(), output, input_size);
Guolin Ke's avatar
Guolin Ke committed
72
73
}

Guolin Ke's avatar
Guolin Ke committed
74
void Network::AllreduceByAllGather(char* input, comm_size_t input_size, int type_size, char* output, const ReduceFunction& reducer) {
75
76
77
  if (num_machines_ <= 1) {
    Log::Fatal("Please initilize the network interface first");
  }
Guolin Ke's avatar
Guolin Ke committed
78
  // assign blocks
Guolin Ke's avatar
Guolin Ke committed
79
  comm_size_t all_size = input_size * num_machines_;
Guolin Ke's avatar
Guolin Ke committed
80
81
82
83
84
85
86
87
88
  block_start_[0] = 0;
  block_len_[0] = input_size;
  for (int i = 1; i < num_machines_; ++i) {
    block_start_[i] = block_start_[i - 1] + block_len_[i - 1];
    block_len_[i] = input_size;
  }
  // need use buffer here, since size of "output" is smaller than size after all gather
  if (input_size*num_machines_ > buffer_size_) {
    buffer_size_ = input_size*num_machines_;
Guolin Ke's avatar
Guolin Ke committed
89
    buffer_.resize(buffer_size_);
Guolin Ke's avatar
Guolin Ke committed
90
91
  }

Guolin Ke's avatar
Guolin Ke committed
92
  Allgather(input, block_start_.data(), block_len_.data(), buffer_.data(), all_size);
Guolin Ke's avatar
Guolin Ke committed
93
  for (int i = 1; i < num_machines_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
94
    reducer(buffer_.data() + block_start_[i], buffer_.data() + block_start_[0], type_size, input_size);
Guolin Ke's avatar
Guolin Ke committed
95
96
  }
  // copy back
Guolin Ke's avatar
Guolin Ke committed
97
  std::memcpy(output, buffer_.data(), input_size);
Guolin Ke's avatar
Guolin Ke committed
98
99
}

Guolin Ke's avatar
Guolin Ke committed
100
void Network::Allgather(char* input, comm_size_t send_size, char* output) {
101
102
103
104
  if (num_machines_ <= 1) {
    Log::Fatal("Please initilize the network interface first");
  }
  if (num_machines_ <= 1) { return; }
Guolin Ke's avatar
Guolin Ke committed
105
106
107
108
109
110
111
112
  // assign blocks
  block_start_[0] = 0;
  block_len_[0] = send_size;
  for (int i = 1; i < num_machines_; ++i) {
    block_start_[i] = block_start_[i - 1] + block_len_[i - 1];
    block_len_[i] = send_size;
  }
  // start all gather
Guolin Ke's avatar
Guolin Ke committed
113
  Allgather(input, block_start_.data(), block_len_.data(), output, send_size * num_machines_);
Guolin Ke's avatar
Guolin Ke committed
114
115
}

Guolin Ke's avatar
Guolin Ke committed
116
void Network::Allgather(char* input, const comm_size_t* block_start, const comm_size_t* block_len, char* output, comm_size_t all_size) {
117
118
119
  if (num_machines_ <= 1) {
    Log::Fatal("Please initilize the network interface first");
  }
120
  if (allgather_ext_fun_ != NULL) {
Guolin Ke's avatar
Guolin Ke committed
121
    return allgather_ext_fun_(input, block_len[rank_], block_start, block_len, num_machines_, output, all_size);
122
  }
Guolin Ke's avatar
Guolin Ke committed
123
  comm_size_t write_pos = 0;
Guolin Ke's avatar
Guolin Ke committed
124
125
126
127
128
129
  // use output as receive buffer
  std::memcpy(output, input, block_len[rank_]);
  write_pos += block_len[rank_];
  int accumulated_block = 1;
  for (int i = 0; i < bruck_map_.k; ++i) {
    // get current local block size
Guolin Ke's avatar
Guolin Ke committed
130
    int cur_block_size = std::min(1 << i, num_machines_ - accumulated_block);
Guolin Ke's avatar
Guolin Ke committed
131
132
133
134
    // get out rank
    int out_rank = bruck_map_.out_ranks[i];
    // get in rank
    int in_rank = bruck_map_.in_ranks[i];
Guolin Ke's avatar
Guolin Ke committed
135
    // get send information
Guolin Ke's avatar
Guolin Ke committed
136
    comm_size_t need_send_len = 0;
Guolin Ke's avatar
Guolin Ke committed
137
    // get recv information
Guolin Ke's avatar
Guolin Ke committed
138
    comm_size_t need_recv_len = 0;
Guolin Ke's avatar
Guolin Ke committed
139
    for (int j = 0; j < cur_block_size; ++j) {
Guolin Ke's avatar
Guolin Ke committed
140
141
      need_send_len += block_len[(rank_ + j) % num_machines_];
      need_recv_len += block_len[(rank_ + accumulated_block + j) % num_machines_];
Guolin Ke's avatar
Guolin Ke committed
142
143
    }
    // send and recv at same time
Guolin Ke's avatar
Guolin Ke committed
144
145
    linkers_->SendRecv(out_rank, output, need_send_len, in_rank, output + write_pos, need_recv_len);
    write_pos += need_recv_len;
Guolin Ke's avatar
Guolin Ke committed
146
147
148
149
150
151
152
153
    accumulated_block += cur_block_size;
  }
  // rotate in-place
  std::reverse<char*>(output, output + all_size);
  std::reverse<char*>(output, output + block_start[rank_]);
  std::reverse<char*>(output + block_start[rank_], output + all_size);
}

Guolin Ke's avatar
Guolin Ke committed
154
void Network::ReduceScatter(char* input, comm_size_t input_size, int type_size, const comm_size_t* block_start, const comm_size_t* block_len, char* output, comm_size_t output_size, const ReduceFunction& reducer) {
155
156
157
  if (num_machines_ <= 1) {
    Log::Fatal("Please initilize the network interface first");
  }
158
  if (reduce_scatter_ext_fun_ != NULL) {
Guolin Ke's avatar
Guolin Ke committed
159
    return reduce_scatter_ext_fun_(input, input_size, type_size, block_start, block_len, num_machines_, output, output_size, reducer);
ww's avatar
ww committed
160
  }
Guolin Ke's avatar
Guolin Ke committed
161
162
163
164
165
  if (recursive_halving_map_.need_pairwise) {
    for (int i = 1; i < num_machines_; ++i) {
      int out_rank = (rank_ + i) % num_machines_;
      int in_rank = (rank_ - i + num_machines_) % num_machines_;
      linkers_->SendRecv(out_rank, input + block_start[out_rank], block_len[out_rank], in_rank, output, block_len[rank_]);
Guolin Ke's avatar
Guolin Ke committed
166
      reducer(output, input + block_start[rank_], type_size, block_len[rank_]);
Guolin Ke's avatar
Guolin Ke committed
167
    }
Guolin Ke's avatar
Guolin Ke committed
168
  } else {
Guolin Ke's avatar
Guolin Ke committed
169
170
171
    for (int i = 0; i < recursive_halving_map_.k; ++i) {
      // get target
      int target = recursive_halving_map_.ranks[i];
Guolin Ke's avatar
Guolin Ke committed
172
173
      comm_size_t send_block_start = recursive_halving_map_.send_block_start[i];
      comm_size_t recv_block_start = recursive_halving_map_.recv_block_start[i];
Guolin Ke's avatar
Guolin Ke committed
174
      // get send information
Guolin Ke's avatar
Guolin Ke committed
175
      comm_size_t send_size = 0;
Guolin Ke's avatar
Guolin Ke committed
176
177
178
179
      for (int j = 0; j < recursive_halving_map_.send_block_len[i]; ++j) {
        send_size += block_len[send_block_start + j];
      }
      // get recv information
Guolin Ke's avatar
Guolin Ke committed
180
      comm_size_t need_recv_cnt = 0;
Guolin Ke's avatar
Guolin Ke committed
181
182
183
184
185
186
      for (int j = 0; j < recursive_halving_map_.recv_block_len[i]; ++j) {
        need_recv_cnt += block_len[recv_block_start + j];
      }
      // send and recv at same time
      linkers_->SendRecv(target, input + block_start[send_block_start], send_size, target, output, need_recv_cnt);
      // reduce
Guolin Ke's avatar
Guolin Ke committed
187
      reducer(output, input + block_start[recv_block_start], type_size, need_recv_cnt);
Guolin Ke's avatar
Guolin Ke committed
188
189
190
191
192
193
194
    }
  }
  // copy result
  std::memcpy(output, input + block_start[rank_], block_len[rank_]);
}

}  // namespace LightGBM