Commit ee8a65ae authored by zhangjin, committed by Guolin Ke
Browse files

update network for non-power-of-2 numbers of workers (#1178)

* update network.h

Improve training speed in parallel learning when the number of workers is not a power of 2

* update linkers_socket.cpp

Improve training speed in parallel learning when the number of workers is not a power of 2

* update linker_topo.cpp

Improve training speed in parallel learning when the number of workers is not a power of 2

* update network.cpp

Improve training speed in parallel learning when the number of workers is not a power of 2

* update linker_topo.cpp

fix a bug
parent 5afffa76
...@@ -38,9 +38,14 @@ public: ...@@ -38,9 +38,14 @@ public:
/*! \brief Network structure for recursive halving algorithm */ /*! \brief Network structure for recursive halving algorithm */
class RecursiveHalvingMap { class RecursiveHalvingMap {
public: public:
bool need_pairwise; /*! \brief If number workers is powers of 2 */
bool is_prof2;
/*! \brief Communication times for one recursize halving algorithm */ /*! \brief Communication times for one recursize halving algorithm */
int k; int k;
/*! \brief Number workers subtract powers of 2 */
int num_remain;
/*! \brief Virtual rank for recursize halving algorithm */
int virtual_rank;
/*! \brief ranks[i] means the machines that will communicate with on i-th communication*/ /*! \brief ranks[i] means the machines that will communicate with on i-th communication*/
std::vector<int> ranks; std::vector<int> ranks;
/*! \brief send_block_start[i] means send block start index at i-th communication*/ /*! \brief send_block_start[i] means send block start index at i-th communication*/
...@@ -54,7 +59,7 @@ public: ...@@ -54,7 +59,7 @@ public:
RecursiveHalvingMap(); RecursiveHalvingMap();
RecursiveHalvingMap(int k, bool in_need_pairwise); RecursiveHalvingMap(int k, int num_remain, int virtual_rank, bool is_prof2);
/*! /*!
* \brief Create the object of recursive halving map * \brief Create the object of recursive halving map
......
...@@ -44,21 +44,22 @@ BruckMap BruckMap::Construct(int rank, int num_machines) { ...@@ -44,21 +44,22 @@ BruckMap BruckMap::Construct(int rank, int num_machines) {
RecursiveHalvingMap::RecursiveHalvingMap() { RecursiveHalvingMap::RecursiveHalvingMap() {
k = 0; k = 0;
need_pairwise = true; is_prof2 = true;
num_remain = 0;
} }
RecursiveHalvingMap::RecursiveHalvingMap(int in_k, bool in_need_pairwise) { RecursiveHalvingMap::RecursiveHalvingMap(int in_k, int in_remain, int in_rank, bool is_power_of2) {
need_pairwise = in_need_pairwise;
k = in_k; k = in_k;
if (!need_pairwise) { is_prof2 = is_power_of2;
for (int i = 0; i < k; ++i) { num_remain = in_remain;
// defalut set as -1 virtual_rank = in_rank;
ranks.push_back(-1); for (int i = 0; i < k; ++i) {
send_block_start.push_back(-1); // defalut set as -1
send_block_len.push_back(-1); ranks.push_back(-1);
recv_block_start.push_back(-1); send_block_start.push_back(-1);
recv_block_len.push_back(-1); send_block_len.push_back(-1);
} recv_block_start.push_back(-1);
recv_block_len.push_back(-1);
} }
} }
...@@ -74,17 +75,31 @@ RecursiveHalvingMap RecursiveHalvingMap::Construct(int rank, int num_machines) { ...@@ -74,17 +75,31 @@ RecursiveHalvingMap RecursiveHalvingMap::Construct(int rank, int num_machines) {
distance.push_back(1 << (k - 1 - i)); distance.push_back(1 << (k - 1 - i));
} }
if ((1 << k) == num_machines) { int remain = num_machines - (1 << k);
RecursiveHalvingMap rec_map(k, false); int virtual_rank = rank;
// if num_machines = 2^k, don't need to group machines // if virtual_rank not -1 will not excute recursize halving algorithm
if (rank < 2 * remain) {
if (rank % 2 == 0) {
virtual_rank = -1;
} else {
virtual_rank = rank / 2;
}
} else {
virtual_rank = rank - remain;
}
bool is_power_of2 = false;
if ((1 << k) == num_machines) { is_power_of2 = true; }
RecursiveHalvingMap rec_map(k, remain, virtual_rank, is_power_of2);
if (virtual_rank != -1) {
for (int i = 0; i < k; ++i) { for (int i = 0; i < k; ++i) {
// communication direction, %2 == 0 is positive // communication direction, %2 == 0 is positive
const int dir = ((rank / distance[i]) % 2 == 0) ? 1 : -1; const int dir = ((virtual_rank / distance[i]) % 2 == 0) ? 1 : -1;
// neighbor at k-th communication // neighbor at k-th communication
const int next_node_idx = rank + dir * distance[i]; const int next_node_idx = virtual_rank + dir * distance[i];
rec_map.ranks[i] = next_node_idx; rec_map.ranks[i] = next_node_idx;
// receive data block at k-th communication // receive data block at k-th communication
const int recv_block_start = rank / distance[i]; const int recv_block_start = virtual_rank / distance[i];
rec_map.recv_block_start[i] = recv_block_start * distance[i]; rec_map.recv_block_start[i] = recv_block_start * distance[i];
rec_map.recv_block_len[i] = distance[i]; rec_map.recv_block_len[i] = distance[i];
// send data block at k-th communication // send data block at k-th communication
...@@ -92,10 +107,8 @@ RecursiveHalvingMap RecursiveHalvingMap::Construct(int rank, int num_machines) { ...@@ -92,10 +107,8 @@ RecursiveHalvingMap RecursiveHalvingMap::Construct(int rank, int num_machines) {
rec_map.send_block_start[i] = send_block_start * distance[i]; rec_map.send_block_start[i] = send_block_start * distance[i];
rec_map.send_block_len[i] = distance[i]; rec_map.send_block_len[i] = distance[i];
} }
return rec_map;
} else {
return RecursiveHalvingMap(k, true);
} }
return rec_map;
} }
} // namespace LightGBM } // namespace LightGBM
......
...@@ -163,7 +163,7 @@ void Linkers::ListenThread(int incoming_cnt) { ...@@ -163,7 +163,7 @@ void Linkers::ListenThread(int incoming_cnt) {
void Linkers::Construct() { void Linkers::Construct() {
// save ranks that need to connect with // save ranks that need to connect with
std::unordered_map<int, int> need_connect; std::unordered_map<int, int> need_connect;
if (recursive_halving_map_.need_pairwise) { if (!recursive_halving_map_.is_prof2) {
for (int i = 0; i < num_machines_; ++i) { for (int i = 0; i < num_machines_; ++i) {
if (i != rank_) { if (i != rank_) {
need_connect[i] = 1; need_connect[i] = 1;
......
...@@ -225,19 +225,73 @@ void Network::AllgatherRing(char* input, const comm_size_t* block_start, const c ...@@ -225,19 +225,73 @@ void Network::AllgatherRing(char* input, const comm_size_t* block_start, const c
} }
} }
void Network::ReduceScatter(char* input, comm_size_t input_size, int type_size, const comm_size_t* block_start, const comm_size_t* block_len, char* output, comm_size_t output_size, const ReduceFunction& reducer) { void Network::ReduceScatter(char* input, comm_size_t input_size, int type_size, const comm_size_t* block_start,
const comm_size_t* block_len, char* output, comm_size_t output_size, const ReduceFunction& reducer) {
if (num_machines_ <= 1) { if (num_machines_ <= 1) {
Log::Fatal("Please initilize the network interface first"); Log::Fatal("Please initilize the network interface first");
} }
if (reduce_scatter_ext_fun_ != nullptr) { if (reduce_scatter_ext_fun_ != nullptr) {
return reduce_scatter_ext_fun_(input, input_size, type_size, block_start, block_len, num_machines_, output, output_size, reducer); return reduce_scatter_ext_fun_(input, input_size, type_size, block_start, block_len, num_machines_, output, output_size, reducer);
} }
if (recursive_halving_map_.need_pairwise) { if (!recursive_halving_map_.is_prof2) {
for (int i = 1; i < num_machines_; ++i) { int remain = recursive_halving_map_.num_remain;
int out_rank = (rank_ + i) % num_machines_; std::vector<int> rcsv_block_start(1 << recursive_halving_map_.k);
int in_rank = (rank_ - i + num_machines_) % num_machines_; std::vector<int> rcsv_block_len(1 << recursive_halving_map_.k);
linkers_->SendRecv(out_rank, input + block_start[out_rank], block_len[out_rank], in_rank, output, block_len[rank_]); std::vector<int> real_ranks;
reducer(output, input + block_start[rank_], type_size, block_len[rank_]); int brush = 0;
// build block_start and block_len for remain powers of 2 workers
for (int i = 0; i < num_machines_; ++i) {
if ((i < 2 * remain) && (i % 2 != 0)) {
real_ranks.push_back(i);
rcsv_block_start[i - 1 - brush] = block_start[i - 1];
rcsv_block_len[i - 1 - brush] = block_len[i] + block_len[i - 1];
brush++;
}
if (i >= 2 * remain) {
real_ranks.push_back(i);
rcsv_block_start[i - remain] = block_start[i];
rcsv_block_len[i - remain] = block_len[i];
}
}
// if local rank is remain, send local data to rank+1
if (rank_ < 2 * remain) {
if (rank_ % 2 == 0) {
linkers_->Send(rank_ + 1, input, input_size);
} else {
linkers_->Recv(rank_ - 1, output, input_size);
reducer(output, input, type_size, input_size);
}
}
// excute recursize halving algorithm for powers of 2 workers
if (recursive_halving_map_.virtual_rank != -1) {
for (int i = 0; i < recursive_halving_map_.k; ++i) {
int virtual_rank = recursive_halving_map_.ranks[i];
int target = real_ranks[virtual_rank];
int send_block_start = recursive_halving_map_.send_block_start[i];
int recv_block_start = recursive_halving_map_.recv_block_start[i];
// get send information
int send_size = 0;
for (int j = 0; j < recursive_halving_map_.send_block_len[i]; ++j) {
send_size += rcsv_block_len[send_block_start + j];
}
// get recv information
int need_recv_cnt = 0;
for (int j = 0; j < recursive_halving_map_.recv_block_len[i]; ++j) {
need_recv_cnt += rcsv_block_len[recv_block_start + j];
}
// send and recv at same time
linkers_->SendRecv(target, input + rcsv_block_start[send_block_start], send_size, target, output, need_recv_cnt);
// reduce
reducer(output, input + rcsv_block_start[recv_block_start], type_size, need_recv_cnt);
}
}
// send result back to remain workers
if (rank_ < 2 * remain) {
if (rank_ % 2 != 0) {
linkers_->Send(rank_ - 1, input + block_start[rank_ - 1], block_len[rank_ - 1]);
} else {
linkers_->Recv(rank_ + 1, input + block_start[rank_], block_len[rank_]);
}
} }
} else { } else {
for (int i = 0; i < recursive_halving_map_.k; ++i) { for (int i = 0; i < recursive_halving_map_.k; ++i) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment