#include #include #include "linkers.h" #include #include namespace LightGBM { // static member definition THREAD_LOCAL int Network::num_machines_ = 1; THREAD_LOCAL int Network::rank_ = 0; THREAD_LOCAL std::unique_ptr Network::linkers_; THREAD_LOCAL BruckMap Network::bruck_map_; THREAD_LOCAL RecursiveHalvingMap Network::recursive_halving_map_; THREAD_LOCAL std::vector Network::block_start_; THREAD_LOCAL std::vector Network::block_len_; THREAD_LOCAL comm_size_t Network::buffer_size_ = 0; THREAD_LOCAL std::vector Network::buffer_; THREAD_LOCAL ReduceScatterFunction Network::reduce_scatter_ext_fun_ = nullptr; THREAD_LOCAL AllgatherFunction Network::allgather_ext_fun_ = nullptr; void Network::Init(NetworkConfig config) { if (config.num_machines > 1) { linkers_.reset(new Linkers(config)); rank_ = linkers_->rank(); num_machines_ = linkers_->num_machines(); bruck_map_ = linkers_->bruck_map(); recursive_halving_map_ = linkers_->recursive_halving_map(); block_start_ = std::vector(num_machines_); block_len_ = std::vector(num_machines_); buffer_size_ = 1024 * 1024; buffer_.resize(buffer_size_); Log::Info("Local rank: %d, total number of machines: %d", rank_, num_machines_); } } void Network::Init(int num_machines, int rank, ReduceScatterFunction reduce_scatter_ext_fun, AllgatherFunction allgather_ext_fun) { if (num_machines > 1) { rank_ = rank; num_machines_ = num_machines; block_start_ = std::vector(num_machines_); block_len_ = std::vector(num_machines_); buffer_size_ = 1024 * 1024; buffer_.resize(buffer_size_); reduce_scatter_ext_fun_ = reduce_scatter_ext_fun; allgather_ext_fun_ = allgather_ext_fun; Log::Info("Local rank: %d, total number of machines: %d", rank_, num_machines_); } } void Network::Dispose() { num_machines_ = 1; rank_ = 0; linkers_.reset(new Linkers()); reduce_scatter_ext_fun_ = nullptr; allgather_ext_fun_ = nullptr; } void Network::Allreduce(char* input, comm_size_t input_size, int type_size, char* output, const ReduceFunction& reducer) { if (num_machines_ <= 1) { Log::Fatal("Please initilize the network interface first"); } comm_size_t count = input_size / type_size; // if small package or small count , do it by all gather.(reduce the communication times.) if (count < num_machines_ || input_size < 4096) { AllreduceByAllGather(input, input_size, type_size, output, reducer); return; } // assign the blocks to every rank. comm_size_t step = (count + num_machines_ - 1) / num_machines_; if (step < 1) { step = 1; } block_start_[0] = 0; for (int i = 0; i < num_machines_ - 1; ++i) { block_len_[i] = std::min(step * type_size, input_size - block_start_[i]); block_start_[i + 1] = block_start_[i] + block_len_[i]; } block_len_[num_machines_ - 1] = input_size - block_start_[num_machines_ - 1]; // do reduce scatter ReduceScatter(input, input_size, type_size, block_start_.data(), block_len_.data(), output, input_size, reducer); // do all gather Allgather(output, block_start_.data(), block_len_.data(), output, input_size); } void Network::AllreduceByAllGather(char* input, comm_size_t input_size, int type_size, char* output, const ReduceFunction& reducer) { if (num_machines_ <= 1) { Log::Fatal("Please initilize the network interface first"); } // assign blocks comm_size_t all_size = input_size * num_machines_; block_start_[0] = 0; block_len_[0] = input_size; for (int i = 1; i < num_machines_; ++i) { block_start_[i] = block_start_[i - 1] + block_len_[i - 1]; block_len_[i] = input_size; } // need use buffer here, since size of "output" is smaller than size after all gather if (input_size*num_machines_ > buffer_size_) { buffer_size_ = input_size*num_machines_; buffer_.resize(buffer_size_); } Allgather(input, block_start_.data(), block_len_.data(), buffer_.data(), all_size); for (int i = 1; i < num_machines_; ++i) { reducer(buffer_.data() + block_start_[i], buffer_.data() + block_start_[0], type_size, input_size); } // copy back std::memcpy(output, buffer_.data(), input_size); } void Network::Allgather(char* input, comm_size_t send_size, char* output) { if (num_machines_ <= 1) { Log::Fatal("Please initilize the network interface first"); } if (num_machines_ <= 1) { return; } // assign blocks block_start_[0] = 0; block_len_[0] = send_size; for (int i = 1; i < num_machines_; ++i) { block_start_[i] = block_start_[i - 1] + block_len_[i - 1]; block_len_[i] = send_size; } // start all gather Allgather(input, block_start_.data(), block_len_.data(), output, send_size * num_machines_); } void Network::Allgather(char* input, const comm_size_t* block_start, const comm_size_t* block_len, char* output, comm_size_t all_size) { if (num_machines_ <= 1) { Log::Fatal("Please initilize the network interface first"); } if (allgather_ext_fun_ != nullptr) { return allgather_ext_fun_(input, block_len[rank_], block_start, block_len, num_machines_, output, all_size); } const comm_size_t kRecursiveDoublingThreshold = 1024 * 1024; // 1MB const comm_size_t kBruckThreshold = 512 * 1024; // 512KB const bool is_power_of2 = (num_machines_ & (num_machines_ - 1)) == 0; if (is_power_of2 && all_size < kRecursiveDoublingThreshold) { AllgatherRecursiveDoubling(input, block_start, block_len, output, all_size); } else if (all_size < kBruckThreshold) { AllgatherBruck(input, block_start, block_len, output, all_size); } else { AllgatherRing(input, block_start, block_len, output, all_size); } } void Network::AllgatherBruck(char* input, const comm_size_t* block_start, const comm_size_t* block_len, char* output, comm_size_t all_size) { comm_size_t write_pos = 0; // use output as receive buffer std::memcpy(output, input, block_len[rank_]); write_pos += block_len[rank_]; int accumulated_block = 1; for (int i = 0; i < bruck_map_.k; ++i) { // get current local block size int cur_block_size = std::min(1 << i, num_machines_ - accumulated_block); // get out rank int out_rank = bruck_map_.out_ranks[i]; // get in rank int in_rank = bruck_map_.in_ranks[i]; // get send information comm_size_t need_send_len = 0; // get recv information comm_size_t need_recv_len = 0; for (int j = 0; j < cur_block_size; ++j) { need_send_len += block_len[(rank_ + j) % num_machines_]; need_recv_len += block_len[(rank_ + accumulated_block + j) % num_machines_]; } // send and recv at same time linkers_->SendRecv(out_rank, output, need_send_len, in_rank, output + write_pos, need_recv_len); write_pos += need_recv_len; accumulated_block += cur_block_size; } // rotate in-place std::reverse(output, output + all_size); std::reverse(output, output + block_start[rank_]); std::reverse(output + block_start[rank_], output + all_size); } void Network::AllgatherRecursiveDoubling(char* input, const comm_size_t* block_start, const comm_size_t* block_len, char* output, comm_size_t) { // use output as receive buffer std::memcpy(output + block_start[rank_], input, block_len[rank_]); for (int i = 0; i < bruck_map_.k; ++i) { // get current local block size int cur_step = 1 << i; const int vgroup = rank_ / cur_step; const int vrank = vgroup * cur_step; int target = rank_ + cur_step; int target_vrank = (vgroup + 1) * cur_step; if (vgroup & 1) { target = rank_ - cur_step; target_vrank = (vgroup - 1) * cur_step; } // get send information comm_size_t need_send_len = 0; // get recv information comm_size_t need_recv_len = 0; for (int j = 0; j < cur_step; ++j) { need_send_len += block_len[(vrank + j)]; need_recv_len += block_len[(target_vrank + j)]; } // send and recv at same time linkers_->SendRecv(target, output + block_start[vrank], need_send_len, target, output + block_start[target_vrank], need_recv_len); } } void Network::AllgatherRing(char* input, const comm_size_t* block_start, const comm_size_t* block_len, char* output, comm_size_t) { // use output as receive buffer std::memcpy(output + block_start[rank_], input, block_len[rank_]); int out_rank = (rank_ + 1) % num_machines_; int in_rank = (rank_ - 1 + num_machines_) % num_machines_; int out_place = rank_; int in_place = in_rank; for (int i = 1; i < num_machines_; ++i) { // send and recv at same time linkers_->SendRecv(out_rank, output + block_start[out_place], block_len[out_place], in_rank, output + block_start[in_place], block_len[in_place]); out_place = (out_place - 1 + num_machines_) % num_machines_; in_place = (in_place - 1 + num_machines_) % num_machines_; } } void Network::ReduceScatter(char* input, comm_size_t input_size, int type_size, const comm_size_t* block_start, const comm_size_t* block_len, char* output, comm_size_t output_size, const ReduceFunction& reducer) { if (num_machines_ <= 1) { Log::Fatal("Please initilize the network interface first"); } if (reduce_scatter_ext_fun_ != nullptr) { return reduce_scatter_ext_fun_(input, input_size, type_size, block_start, block_len, num_machines_, output, output_size, reducer); } if (!recursive_halving_map_.is_prof2) { int remain = recursive_halving_map_.num_remain; std::vector rcsv_block_start(1 << recursive_halving_map_.k); std::vector rcsv_block_len(1 << recursive_halving_map_.k); std::vector real_ranks; int brush = 0; // build block_start and block_len for remain powers of 2 workers for (int i = 0; i < num_machines_; ++i) { if ((i < 2 * remain) && (i % 2 != 0)) { real_ranks.push_back(i); rcsv_block_start[i - 1 - brush] = block_start[i - 1]; rcsv_block_len[i - 1 - brush] = block_len[i] + block_len[i - 1]; brush++; } if (i >= 2 * remain) { real_ranks.push_back(i); rcsv_block_start[i - remain] = block_start[i]; rcsv_block_len[i - remain] = block_len[i]; } } // if local rank is remain, send local data to rank+1 if (rank_ < 2 * remain) { if (rank_ % 2 == 0) { linkers_->Send(rank_ + 1, input, input_size); } else { linkers_->Recv(rank_ - 1, output, input_size); reducer(output, input, type_size, input_size); } } // excute recursize halving algorithm for powers of 2 workers if (recursive_halving_map_.virtual_rank != -1) { for (int i = 0; i < recursive_halving_map_.k; ++i) { int virtual_rank = recursive_halving_map_.ranks[i]; int target = real_ranks[virtual_rank]; int send_block_start = recursive_halving_map_.send_block_start[i]; int recv_block_start = recursive_halving_map_.recv_block_start[i]; // get send information int send_size = 0; for (int j = 0; j < recursive_halving_map_.send_block_len[i]; ++j) { send_size += rcsv_block_len[send_block_start + j]; } // get recv information int need_recv_cnt = 0; for (int j = 0; j < recursive_halving_map_.recv_block_len[i]; ++j) { need_recv_cnt += rcsv_block_len[recv_block_start + j]; } // send and recv at same time linkers_->SendRecv(target, input + rcsv_block_start[send_block_start], send_size, target, output, need_recv_cnt); // reduce reducer(output, input + rcsv_block_start[recv_block_start], type_size, need_recv_cnt); } } // send result back to remain workers if (rank_ < 2 * remain) { if (rank_ % 2 != 0) { linkers_->Send(rank_ - 1, input + block_start[rank_ - 1], block_len[rank_ - 1]); } else { linkers_->Recv(rank_ + 1, input + block_start[rank_], block_len[rank_]); } } } else { for (int i = 0; i < recursive_halving_map_.k; ++i) { // get target int target = recursive_halving_map_.ranks[i]; comm_size_t send_block_start = recursive_halving_map_.send_block_start[i]; comm_size_t recv_block_start = recursive_halving_map_.recv_block_start[i]; // get send information comm_size_t send_size = 0; for (int j = 0; j < recursive_halving_map_.send_block_len[i]; ++j) { send_size += block_len[send_block_start + j]; } // get recv information comm_size_t need_recv_cnt = 0; for (int j = 0; j < recursive_halving_map_.recv_block_len[i]; ++j) { need_recv_cnt += block_len[recv_block_start + j]; } // send and recv at same time linkers_->SendRecv(target, input + block_start[send_block_start], send_size, target, output, need_recv_cnt); // reduce reducer(output, input + block_start[recv_block_start], type_size, need_recv_cnt); } } // copy result std::memcpy(output, input + block_start[rank_], block_len[rank_]); } } // namespace LightGBM