// network.cpp
// Socket-based collective communication primitives (Allreduce,
// ReduceScatter, Allgather) used by LightGBM distributed learning.
#include <LightGBM/network.h>

#include <LightGBM/utils/common.h>

#include "linkers.h"

#include <algorithm>
#include <cstdlib>
#include <cstring>
#include <memory>
#include <vector>

namespace LightGBM {

// Static member definitions for Network.
// THREAD_LOCAL lets each thread run an independent network session.
THREAD_LOCAL int Network::num_machines_ = 1;
THREAD_LOCAL int Network::rank_ = 0;
THREAD_LOCAL std::unique_ptr<Linkers> Network::linkers_;
THREAD_LOCAL BruckMap Network::bruck_map_;
THREAD_LOCAL RecursiveHalvingMap Network::recursive_halving_map_;
// Per-rank byte offsets / lengths of the data blocks used by the
// collective operations; sized to num_machines_ in Init().
THREAD_LOCAL std::vector<int> Network::block_start_;
THREAD_LOCAL std::vector<int> Network::block_len_;
THREAD_LOCAL int Network::buffer_size_ = 0;
THREAD_LOCAL std::vector<char> Network::buffer_;
// Optional user-registered overrides for the collective operations
// (presumably an external implementation such as MPI — confirm with
// callers). When non-null they replace the built-in socket algorithms.
THREAD_LOCAL AllreduceFunction Network::AllreduceFuncPtr_ = nullptr;
THREAD_LOCAL ReduceScatterFunction Network::ReduceScatterFuncPtr_ = nullptr;
THREAD_LOCAL AllgatherFunction Network::AllgatherFuncPtr_ = nullptr;

Guolin Ke's avatar
Guolin Ke committed
26
27

void Network::Init(NetworkConfig config) {
28
29
30
31
32
33
34
35
36
37
38
39
  if (config.num_machines > 1) {
    linkers_.reset(new Linkers(config));
    rank_ = linkers_->rank();
    num_machines_ = linkers_->num_machines();
    bruck_map_ = linkers_->bruck_map();
    recursive_halving_map_ = linkers_->recursive_halving_map();
    block_start_ = std::vector<int>(num_machines_);
    block_len_ = std::vector<int>(num_machines_);
    buffer_size_ = 1024 * 1024;
    buffer_.resize(buffer_size_);
    Log::Info("Local rank: %d, total number of machines: %d", rank_, num_machines_);
  }
Guolin Ke's avatar
Guolin Ke committed
40
41
42
}

/*!
 * \brief Tear down the network session and restore single-machine state.
 *
 * Resets rank/machine-count to their defaults and replaces the linkers
 * with a fresh default-constructed (unconnected) instance, closing the
 * previous connections via the old object's destructor.
 */
void Network::Dispose() {
  num_machines_ = 1;
  rank_ = 0;
  linkers_.reset(new Linkers());
}

void Network::Allreduce(char* input, int input_size, int type_size, char* output, const ReduceFunction& reducer) {
49
50
51
  if (num_machines_ <= 1) {
    Log::Fatal("Please initilize the network interface first");
  }
ww's avatar
ww committed
52
53
54
  if (AllreduceFuncPtr_ != NULL) {
    return AllreduceFuncPtr_(input, input_size, type_size, output, reducer);
  }
Guolin Ke's avatar
Guolin Ke committed
55
56
57
58
59
60
61
62
63
64
65
66
67
  int count = input_size / type_size;
  // if small package or small count , do it by all gather.(reduce the communication times.)
  if (count < num_machines_ || input_size < 4096) {
    AllreduceByAllGather(input, input_size, output, reducer);
    return;
  }
  // assign the blocks to every rank.
  int step = (count + num_machines_ - 1) / num_machines_;
  if (step < 1) {
    step = 1;
  }
  block_start_[0] = 0;
  for (int i = 0; i < num_machines_ - 1; ++i) {
Guolin Ke's avatar
Guolin Ke committed
68
    block_len_[i] = std::min(step * type_size, input_size - block_start_[i]);
Guolin Ke's avatar
Guolin Ke committed
69
70
71
72
    block_start_[i + 1] = block_start_[i] + block_len_[i];
  }
  block_len_[num_machines_ - 1] = input_size - block_start_[num_machines_ - 1];
  // do reduce scatter
Guolin Ke's avatar
Guolin Ke committed
73
  ReduceScatter(input, input_size, block_start_.data(), block_len_.data(), output, reducer);
Guolin Ke's avatar
Guolin Ke committed
74
  // do all gather
Guolin Ke's avatar
Guolin Ke committed
75
  Allgather(output, input_size, block_start_.data(), block_len_.data(), output);
Guolin Ke's avatar
Guolin Ke committed
76
77
78
}

void Network::AllreduceByAllGather(char* input, int input_size, char* output, const ReduceFunction& reducer) {
79
80
81
  if (num_machines_ <= 1) {
    Log::Fatal("Please initilize the network interface first");
  }
Guolin Ke's avatar
Guolin Ke committed
82
83
84
85
86
87
88
89
90
91
92
  // assign blocks
  int all_size = input_size * num_machines_;
  block_start_[0] = 0;
  block_len_[0] = input_size;
  for (int i = 1; i < num_machines_; ++i) {
    block_start_[i] = block_start_[i - 1] + block_len_[i - 1];
    block_len_[i] = input_size;
  }
  // need use buffer here, since size of "output" is smaller than size after all gather
  if (input_size*num_machines_ > buffer_size_) {
    buffer_size_ = input_size*num_machines_;
Guolin Ke's avatar
Guolin Ke committed
93
    buffer_.resize(buffer_size_);
Guolin Ke's avatar
Guolin Ke committed
94
95
  }

Guolin Ke's avatar
Guolin Ke committed
96
  Allgather(input, all_size, block_start_.data(), block_len_.data(), buffer_.data());
Guolin Ke's avatar
Guolin Ke committed
97
  for (int i = 1; i < num_machines_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
98
    reducer(buffer_.data() + block_start_[i], buffer_.data() + block_start_[0], input_size);
Guolin Ke's avatar
Guolin Ke committed
99
100
  }
  // copy back
Guolin Ke's avatar
Guolin Ke committed
101
  std::memcpy(output, buffer_.data(), input_size);
Guolin Ke's avatar
Guolin Ke committed
102
103
104
}

void Network::Allgather(char* input, int send_size, char* output) {
105
106
107
108
  if (num_machines_ <= 1) {
    Log::Fatal("Please initilize the network interface first");
  }
  if (num_machines_ <= 1) { return; }
ww's avatar
ww committed
109
110
111
  if (AllgatherFuncPtr_ != NULL) {
    return AllgatherFuncPtr_(input, send_size, output);
  }
Guolin Ke's avatar
Guolin Ke committed
112
113
114
115
116
117
118
119
  // assign blocks
  block_start_[0] = 0;
  block_len_[0] = send_size;
  for (int i = 1; i < num_machines_; ++i) {
    block_start_[i] = block_start_[i - 1] + block_len_[i - 1];
    block_len_[i] = send_size;
  }
  // start all gather
Guolin Ke's avatar
Guolin Ke committed
120
  Allgather(input, send_size * num_machines_, block_start_.data(), block_len_.data(), output);
Guolin Ke's avatar
Guolin Ke committed
121
122
}

Guolin Ke's avatar
Guolin Ke committed
123
void Network::Allgather(char* input, int all_size, const int* block_start, const int* block_len, char* output) {
124
125
126
  if (num_machines_ <= 1) {
    Log::Fatal("Please initilize the network interface first");
  }
Guolin Ke's avatar
Guolin Ke committed
127
128
129
130
131
132
133
  int write_pos = 0;
  // use output as receive buffer
  std::memcpy(output, input, block_len[rank_]);
  write_pos += block_len[rank_];
  int accumulated_block = 1;
  for (int i = 0; i < bruck_map_.k; ++i) {
    // get current local block size
Guolin Ke's avatar
Guolin Ke committed
134
    int cur_block_size = std::min(1 << i, num_machines_ - accumulated_block);
Guolin Ke's avatar
Guolin Ke committed
135
136
137
138
    // get out rank
    int out_rank = bruck_map_.out_ranks[i];
    // get in rank
    int in_rank = bruck_map_.in_ranks[i];
Guolin Ke's avatar
Guolin Ke committed
139
140
    // get send information
    int need_send_len = 0;
Guolin Ke's avatar
Guolin Ke committed
141
    // get recv information
Guolin Ke's avatar
Guolin Ke committed
142
    int need_recv_len = 0;
Guolin Ke's avatar
Guolin Ke committed
143
    for (int j = 0; j < cur_block_size; ++j) {
Guolin Ke's avatar
Guolin Ke committed
144
145
      need_send_len += block_len[(rank_ + j) % num_machines_];
      need_recv_len += block_len[(rank_ + accumulated_block + j) % num_machines_];
Guolin Ke's avatar
Guolin Ke committed
146
147
    }
    // send and recv at same time
Guolin Ke's avatar
Guolin Ke committed
148
149
    linkers_->SendRecv(out_rank, output, need_send_len, in_rank, output + write_pos, need_recv_len);
    write_pos += need_recv_len;
Guolin Ke's avatar
Guolin Ke committed
150
151
152
153
154
155
156
157
    accumulated_block += cur_block_size;
  }
  // rotate in-place
  std::reverse<char*>(output, output + all_size);
  std::reverse<char*>(output, output + block_start[rank_]);
  std::reverse<char*>(output + block_start[rank_], output + all_size);
}

ww's avatar
ww committed
158
void Network::ReduceScatter(char* input, int input_size, const int* block_start, const int* block_len, char* output, const ReduceFunction& reducer) {
159
160
161
  if (num_machines_ <= 1) {
    Log::Fatal("Please initilize the network interface first");
  }
ww's avatar
ww committed
162
163
164
  if (ReduceScatterFuncPtr_ != NULL) {
    return ReduceScatterFuncPtr_(input, input_size, block_start, block_len, output, reducer);
  }
Guolin Ke's avatar
Guolin Ke committed
165
166
167
168
169
170
  if (recursive_halving_map_.need_pairwise) {
    for (int i = 1; i < num_machines_; ++i) {
      int out_rank = (rank_ + i) % num_machines_;
      int in_rank = (rank_ - i + num_machines_) % num_machines_;
      linkers_->SendRecv(out_rank, input + block_start[out_rank], block_len[out_rank], in_rank, output, block_len[rank_]);
      reducer(output, input + block_start[rank_], block_len[rank_]);
Guolin Ke's avatar
Guolin Ke committed
171
    }
Guolin Ke's avatar
Guolin Ke committed
172
  } else {
Guolin Ke's avatar
Guolin Ke committed
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
    for (int i = 0; i < recursive_halving_map_.k; ++i) {
      // get target
      int target = recursive_halving_map_.ranks[i];
      int send_block_start = recursive_halving_map_.send_block_start[i];
      int recv_block_start = recursive_halving_map_.recv_block_start[i];
      // get send information
      int send_size = 0;
      for (int j = 0; j < recursive_halving_map_.send_block_len[i]; ++j) {
        send_size += block_len[send_block_start + j];
      }
      // get recv information
      int need_recv_cnt = 0;
      for (int j = 0; j < recursive_halving_map_.recv_block_len[i]; ++j) {
        need_recv_cnt += block_len[recv_block_start + j];
      }
      // send and recv at same time
      linkers_->SendRecv(target, input + block_start[send_block_start], send_size, target, output, need_recv_cnt);
      // reduce
      reducer(output, input + block_start[recv_block_start], need_recv_cnt);
    }
  }
  // copy result
  std::memcpy(output, input + block_start[rank_], block_len[rank_]);
}

}  // namespace LightGBM