Unverified Commit 5afffa76 authored by Guolin Ke's avatar Guolin Ke Committed by GitHub
Browse files

Refine allgather (#1175)

* refine allgather.

* fix a bug.
parent 7d35beec
...@@ -129,6 +129,13 @@ public: ...@@ -129,6 +129,13 @@ public:
*/ */
static void Allgather(char* input, const comm_size_t* block_start, const comm_size_t* block_len, char* output, comm_size_t all_size); static void Allgather(char* input, const comm_size_t* block_start, const comm_size_t* block_len, char* output, comm_size_t all_size);
/*!
* \brief Allgather variant named after the Bruck algorithm.
*        Allgather falls back to it when recursive doubling is not chosen and
*        all_size is below its Bruck threshold.
*        NOTE(review): only part of the definition is visible here; it uses output
*        as the receive buffer and finishes by reordering output in place.
* \param input Local block to contribute (block_len[rank] bytes are read)
* \param block_start Per-rank byte offsets into output
* \param block_len Per-rank block lengths in bytes
* \param output Buffer that receives the gathered data
* \param all_size Total size of output in bytes
*/
static void AllgatherBruck(char* input, const comm_size_t* block_start, const comm_size_t* block_len, char* output, comm_size_t all_size);
/*!
* \brief Allgather variant that exchanges data pairwise with a partner whose
*        distance doubles every round (1, 2, 4, ...), sending and receiving a
*        run of consecutive blocks each time.
*        Allgather only selects it when the number of machines is a power of two.
* \param input Local block to contribute (block_len[rank] bytes are read)
* \param block_start Per-rank byte offsets into output
* \param block_len Per-rank block lengths in bytes
* \param output Buffer that receives the gathered data; also used as the send buffer
* \param all_size Total size of output in bytes (ignored by the implementation)
*/
static void AllgatherRecursiveDoubling(char* input, const comm_size_t* block_start, const comm_size_t* block_len, char* output, comm_size_t all_size);
/*!
* \brief Allgather variant that passes blocks around a ring: every round one block
*        is sent to rank + 1 and one block is received from rank - 1 (both modulo
*        the number of machines), for num_machines - 1 rounds in total.
* \param input Local block to contribute (block_len[rank] bytes are read)
* \param block_start Per-rank byte offsets into output
* \param block_len Per-rank block lengths in bytes
* \param output Buffer that receives the gathered data; also used as the send buffer
* \param all_size Total size of output in bytes (ignored by the implementation)
*/
static void AllgatherRing(char* input, const comm_size_t* block_start, const comm_size_t* block_len, char* output, comm_size_t all_size);
/*! /*!
* \brief Perform reduce scatter by using recursive halving algorithm. * \brief Perform reduce scatter by using recursive halving algorithm.
Communication times is O(log(n)), and communication cost is O(input_size) Communication times is O(log(n)), and communication cost is O(input_size)
......
...@@ -137,6 +137,19 @@ void Network::Allgather(char* input, const comm_size_t* block_start, const comm_ ...@@ -137,6 +137,19 @@ void Network::Allgather(char* input, const comm_size_t* block_start, const comm_
if (allgather_ext_fun_ != nullptr) { if (allgather_ext_fun_ != nullptr) {
return allgather_ext_fun_(input, block_len[rank_], block_start, block_len, num_machines_, output, all_size); return allgather_ext_fun_(input, block_len[rank_], block_start, block_len, num_machines_, output, all_size);
} }
const comm_size_t kRecursiveDoublingThreshold = 1024 * 1024; // 1MB
const comm_size_t kBruckThreshold = 512 * 1024; // 512KB
const bool is_power_of2 = (num_machines_ & (num_machines_ - 1)) == 0;
if (is_power_of2 && all_size < kRecursiveDoublingThreshold) {
AllgatherRecursiveDoubling(input, block_start, block_len, output, all_size);
} else if (all_size < kBruckThreshold) {
AllgatherBruck(input, block_start, block_len, output, all_size);
} else {
AllgatherRing(input, block_start, block_len, output, all_size);
}
}
void Network::AllgatherBruck(char* input, const comm_size_t* block_start, const comm_size_t* block_len, char* output, comm_size_t all_size) {
comm_size_t write_pos = 0; comm_size_t write_pos = 0;
// use output as receive buffer // use output as receive buffer
std::memcpy(output, input, block_len[rank_]); std::memcpy(output, input, block_len[rank_]);
...@@ -168,6 +181,50 @@ void Network::Allgather(char* input, const comm_size_t* block_start, const comm_ ...@@ -168,6 +181,50 @@ void Network::Allgather(char* input, const comm_size_t* block_start, const comm_
std::reverse<char*>(output + block_start[rank_], output + all_size); std::reverse<char*>(output + block_start[rank_], output + all_size);
} }
void Network::AllgatherRecursiveDoubling(char* input, const comm_size_t* block_start, const comm_size_t* block_len, char* output, comm_size_t) {
// use output as receive buffer
std::memcpy(output + block_start[rank_], input, block_len[rank_]);
for (int i = 0; i < bruck_map_.k; ++i) {
// get current local block size
int cur_step = 1 << i;
const int vgroup = rank_ / cur_step;
const int vrank = vgroup * cur_step;
int target = rank_ + cur_step;
int target_vrank = (vgroup + 1) * cur_step;
if (vgroup & 1) {
target = rank_ - cur_step;
target_vrank = (vgroup - 1) * cur_step;
}
// get send information
comm_size_t need_send_len = 0;
// get recv information
comm_size_t need_recv_len = 0;
for (int j = 0; j < cur_step; ++j) {
need_send_len += block_len[(vrank + j)];
need_recv_len += block_len[(target_vrank + j)];
}
// send and recv at same time
linkers_->SendRecv(target, output + block_start[vrank], need_send_len,
target, output + block_start[target_vrank], need_recv_len);
}
}
/*!
 * \brief Allgather by circulating blocks around a ring of machines.
 *
 * Every round this rank forwards one block to its successor (rank_ + 1) and receives one
 * block from its predecessor (rank_ - 1), both modulo num_machines_. The first block sent
 * is the local one; afterwards each round forwards the block received in the previous
 * round. After num_machines_ - 1 rounds every block has been received.
 *
 * \param input Local block; block_len[rank_] bytes are copied from it
 * \param block_start Per-rank byte offsets into output
 * \param block_len Per-rank block lengths in bytes
 * \param output Gather destination, used as both send and receive buffer
 */
void Network::AllgatherRing(char* input, const comm_size_t* block_start, const comm_size_t* block_len, char* output, comm_size_t) {
  // Put the local block at its final position; output doubles as the exchange buffer.
  std::memcpy(output + block_start[rank_], input, block_len[rank_]);

  const int successor = (rank_ + 1) % num_machines_;
  const int predecessor = (rank_ + num_machines_ - 1) % num_machines_;

  int send_block = rank_;
  int recv_block = predecessor;
  for (int round = 1; round < num_machines_; ++round) {
    // Send and receive in a single call.
    linkers_->SendRecv(successor, output + block_start[send_block], block_len[send_block],
                       predecessor, output + block_start[recv_block], block_len[recv_block]);
    // The block just received is the one to forward next; the next incoming block
    // is the one before it on the ring.
    send_block = recv_block;
    recv_block = (recv_block + num_machines_ - 1) % num_machines_;
  }
}
void Network::ReduceScatter(char* input, comm_size_t input_size, int type_size, const comm_size_t* block_start, const comm_size_t* block_len, char* output, comm_size_t output_size, const ReduceFunction& reducer) { void Network::ReduceScatter(char* input, comm_size_t input_size, int type_size, const comm_size_t* block_start, const comm_size_t* block_len, char* output, comm_size_t output_size, const ReduceFunction& reducer) {
if (num_machines_ <= 1) { if (num_machines_ <= 1) {
Log::Fatal("Please initilize the network interface first"); Log::Fatal("Please initilize the network interface first");
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment