"...git@developer.sourcefind.cn:tianlh/lightgbm-dcu.git" did not exist on "796ba8033266931814fe27cdf729fcf0d79fff5d"
Commit 7737b45e authored by Guolin Ke's avatar Guolin Ke
Browse files

refine reduce scatter.

parent dc226c27
...@@ -35,30 +35,12 @@ public: ...@@ -35,30 +35,12 @@ public:
static BruckMap Construct(int rank, int num_machines); static BruckMap Construct(int rank, int num_machines);
}; };
/*!
* \brief node type on recursive halving algorithm
* When the number of machines is not a power of 2, machines need to be grouped so that the number of groups is a power of 2.
* Each group has at most 2 machines.
* If a group has only 1 machine, that machine is a normal node.
* If a group has 2 machines, the group has two types of nodes, one of which is the leader.
* The leader represents the group and communicates with the other groups.
*/
enum RecursiveHalvingNodeType {
Normal, // normal node, 1 group only have 1 machine
GroupLeader, // leader of group when number of machines in this group is 2.
Other // non-leader machines in group
};
/*! \brief Network structure for recursive halving algorithm */ /*! \brief Network structure for recursive halving algorithm */
class RecursiveHalvingMap { class RecursiveHalvingMap {
public: public:
bool need_pairwise;
/*! \brief Communication times for one recursive halving algorithm */ /*! \brief Communication times for one recursive halving algorithm */
int k; int k;
/*! \brief Node type */
RecursiveHalvingNodeType type;
/*! \brief Neighbor, only used for non-normal node*/
int neighbor;
/*! \brief ranks[i] is the machine to communicate with in the i-th communication*/ /*! \brief ranks[i] is the machine to communicate with in the i-th communication*/
std::vector<int> ranks; std::vector<int> ranks;
/*! \brief send_block_start[i] means send block start index at i-th communication*/ /*! \brief send_block_start[i] means send block start index at i-th communication*/
...@@ -72,7 +54,7 @@ public: ...@@ -72,7 +54,7 @@ public:
RecursiveHalvingMap(); RecursiveHalvingMap();
RecursiveHalvingMap(RecursiveHalvingNodeType _type, int n); RecursiveHalvingMap(int k, bool in_need_pairwise);
/*! /*!
* \brief Create the object of recursive halving map * \brief Create the object of recursive halving map
......
...@@ -42,15 +42,16 @@ BruckMap BruckMap::Construct(int rank, int num_machines) { ...@@ -42,15 +42,16 @@ BruckMap BruckMap::Construct(int rank, int num_machines) {
return bruckMap; return bruckMap;
} }
RecursiveHalvingMap::RecursiveHalvingMap() { RecursiveHalvingMap::RecursiveHalvingMap() {
k = 0; k = 0;
need_pairwise = true;
} }
RecursiveHalvingMap::RecursiveHalvingMap(RecursiveHalvingNodeType _type, int n) {
type = _type; RecursiveHalvingMap::RecursiveHalvingMap(int in_k, bool in_need_pairwise) {
k = n; need_pairwise = in_need_pairwise;
if (type != RecursiveHalvingNodeType::Other) { k = in_k;
for (int i = 0; i < n; ++i) { if (!need_pairwise) {
for (int i = 0; i < k; ++i) {
// default set as -1 // default set as -1
ranks.push_back(-1); ranks.push_back(-1);
send_block_start.push_back(-1); send_block_start.push_back(-1);
...@@ -74,7 +75,7 @@ RecursiveHalvingMap RecursiveHalvingMap::Construct(int rank, int num_machines) { ...@@ -74,7 +75,7 @@ RecursiveHalvingMap RecursiveHalvingMap::Construct(int rank, int num_machines) {
} }
if ((1 << k) == num_machines) { if ((1 << k) == num_machines) {
RecursiveHalvingMap rec_map(RecursiveHalvingNodeType::Normal, k); RecursiveHalvingMap rec_map(k, false);
// if num_machines = 2^k, don't need to group machines // if num_machines = 2^k, don't need to group machines
for (int i = 0; i < k; ++i) { for (int i = 0; i < k; ++i) {
// communication direction, %2 == 0 is positive // communication direction, %2 == 0 is positive
...@@ -93,82 +94,7 @@ RecursiveHalvingMap RecursiveHalvingMap::Construct(int rank, int num_machines) { ...@@ -93,82 +94,7 @@ RecursiveHalvingMap RecursiveHalvingMap::Construct(int rank, int num_machines) {
} }
return rec_map; return rec_map;
} else { } else {
// if num_machines != 2^k, need to group machines return RecursiveHalvingMap(k, true);
int lower_power_of_2 = 1 << k;
int rest = num_machines - lower_power_of_2;
std::vector<RecursiveHalvingNodeType> node_type;
for (int i = 0; i < num_machines; ++i) {
node_type.push_back(RecursiveHalvingNodeType::Normal);
}
// group, two machine in one group, total "rest" groups will have 2 machines.
for (int i = 0; i < rest; ++i) {
int right = num_machines - i * 2 - 1;
int left = num_machines - i * 2 - 2;
// let left machine as group leader
node_type[left] = RecursiveHalvingNodeType::GroupLeader;
node_type[right] = RecursiveHalvingNodeType::Other;
}
int group_cnt = 0;
// cache block information for groups, group with 2 machines will have double block size
std::vector<int> group_block_start(lower_power_of_2);
std::vector<int> group_block_len(lower_power_of_2, 0);
// convert from group to node leader
std::vector<int> group_to_node(lower_power_of_2);
// convert from node to group
std::vector<int> node_to_group(num_machines);
for (int i = 0; i < num_machines; ++i) {
// meet new group
if (node_type[i] == RecursiveHalvingNodeType::Normal || node_type[i] == RecursiveHalvingNodeType::GroupLeader) {
group_to_node[group_cnt++] = i;
}
node_to_group[i] = group_cnt - 1;
// add block len for this group
group_block_len[group_cnt - 1]++;
}
// calculate the group block start
group_block_start[0] = 0;
for (int i = 1; i < lower_power_of_2; ++i) {
group_block_start[i] = group_block_start[i - 1] + group_block_len[i - 1];
}
RecursiveHalvingMap rec_map(node_type[rank], k);
if (node_type[rank] == RecursiveHalvingNodeType::Other) {
rec_map.neighbor = rank - 1;
// not need to construct
return rec_map;
}
if (node_type[rank] == RecursiveHalvingNodeType::GroupLeader) {
rec_map.neighbor = rank + 1;
}
const int cur_group_idx = node_to_group[rank];
for (int i = 0; i < k; ++i) {
const int dir = ((cur_group_idx / distance[i]) % 2 == 0) ? 1 : -1;
const int next_node_idx = group_to_node[(cur_group_idx + dir * distance[i])];
rec_map.ranks[i] = next_node_idx;
// get receive block information
const int recv_block_start = cur_group_idx / distance[i];
rec_map.recv_block_start[i] = group_block_start[recv_block_start * distance[i]];
int recv_block_len = 0;
// accumulate block len
for (int j = 0; j < distance[i]; ++j) {
recv_block_len += group_block_len[recv_block_start * distance[i] + j];
}
rec_map.recv_block_len[i] = recv_block_len;
// get send block information
const int send_block_start = (cur_group_idx + dir * distance[i]) / distance[i];
rec_map.send_block_start[i] = group_block_start[send_block_start * distance[i]];
int send_block_len = 0;
// accumulate block len
for (int j = 0; j < distance[i]; ++j) {
send_block_len += group_block_len[send_block_start * distance[i] + j];
}
rec_map.send_block_len[i] = send_block_len;
}
return rec_map;
} }
} }
......
...@@ -151,14 +151,17 @@ void Linkers::ListenThread(int incoming_cnt) { ...@@ -151,14 +151,17 @@ void Linkers::ListenThread(int incoming_cnt) {
void Linkers::Construct() { void Linkers::Construct() {
// save ranks that need to connect with // save ranks that need to connect with
std::unordered_map<int, int> need_connect; std::unordered_map<int, int> need_connect;
for (int i = 0; i < bruck_map_.k; ++i) { if (recursive_halving_map_.need_pairwise) {
need_connect[bruck_map_.out_ranks[i]] = 1; for (int i = 0; i < num_machines_; ++i) {
need_connect[bruck_map_.in_ranks[i]] = 1; if (i != rank_) {
} need_connect[i] = 1;
if (recursive_halving_map_.type != RecursiveHalvingNodeType::Normal) { }
need_connect[recursive_halving_map_.neighbor] = 1; }
} } else {
if (recursive_halving_map_.type != RecursiveHalvingNodeType::Other) { for (int i = 0; i < bruck_map_.k; ++i) {
need_connect[bruck_map_.out_ranks[i]] = 1;
need_connect[bruck_map_.in_ranks[i]] = 1;
}
for (int i = 0; i < recursive_halving_map_.k; ++i) { for (int i = 0; i < recursive_halving_map_.k; ++i) {
need_connect[recursive_halving_map_.ranks[i]] = 1; need_connect[recursive_halving_map_.ranks[i]] = 1;
} }
...@@ -175,6 +178,7 @@ void Linkers::Construct() { ...@@ -175,6 +178,7 @@ void Linkers::Construct() {
++incoming_cnt; ++incoming_cnt;
} }
} }
// start listener // start listener
listener_->SetTimeout(socket_timeout_); listener_->SetTimeout(socket_timeout_);
listener_->Listen(incoming_cnt); listener_->Listen(incoming_cnt);
......
...@@ -107,21 +107,19 @@ void Network::Allgather(char* input, int all_size, const int* block_start, const ...@@ -107,21 +107,19 @@ void Network::Allgather(char* input, int all_size, const int* block_start, const
int cur_block_size = std::min(1 << i, num_machines_ - accumulated_block); int cur_block_size = std::min(1 << i, num_machines_ - accumulated_block);
// get out rank // get out rank
int out_rank = bruck_map_.out_ranks[i]; int out_rank = bruck_map_.out_ranks[i];
// get send information
int send_len = 0;
for (int j = 0; j < cur_block_size; ++j) {
send_len += block_len[(rank_ + j) % num_machines_];
}
// get in rank // get in rank
int in_rank = bruck_map_.in_ranks[i]; int in_rank = bruck_map_.in_ranks[i];
// get send information
int need_send_len = 0;
// get recv information // get recv information
int need_recv_cnt = 0; int need_recv_len = 0;
for (int j = 0; j < cur_block_size; ++j) { for (int j = 0; j < cur_block_size; ++j) {
need_recv_cnt += block_len[(rank_ + accumulated_block + j) % num_machines_]; need_send_len += block_len[(rank_ + j) % num_machines_];
need_recv_len += block_len[(rank_ + accumulated_block + j) % num_machines_];
} }
// send and recv at same time // send and recv at same time
linkers_->SendRecv(out_rank, output, send_len, in_rank, output + write_pos, need_recv_cnt); linkers_->SendRecv(out_rank, output, need_send_len, in_rank, output + write_pos, need_recv_len);
write_pos += need_recv_cnt; write_pos += need_recv_len;
accumulated_block += cur_block_size; accumulated_block += cur_block_size;
} }
// rotate in-place // rotate in-place
...@@ -130,22 +128,15 @@ void Network::Allgather(char* input, int all_size, const int* block_start, const ...@@ -130,22 +128,15 @@ void Network::Allgather(char* input, int all_size, const int* block_start, const
std::reverse<char*>(output + block_start[rank_], output + all_size); std::reverse<char*>(output + block_start[rank_], output + all_size);
} }
void Network::ReduceScatter(char* input, int input_size, const int* block_start, const int* block_len, char* output, const ReduceFunction& reducer) { void Network::ReduceScatter(char* input, int, const int* block_start, const int* block_len, char* output, const ReduceFunction& reducer) {
bool is_powerof_2 = (num_machines_ & (num_machines_ - 1)) == 0; if (recursive_halving_map_.need_pairwise) {
if (!is_powerof_2) { for (int i = 1; i < num_machines_; ++i) {
if (recursive_halving_map_.type == RecursiveHalvingNodeType::Other) { int out_rank = (rank_ + i) % num_machines_;
// send local data to neighbor first int in_rank = (rank_ - i + num_machines_) % num_machines_;
linkers_->Send(recursive_halving_map_.neighbor, input, input_size); linkers_->SendRecv(out_rank, input + block_start[out_rank], block_len[out_rank], in_rank, output, block_len[rank_]);
} else if (recursive_halving_map_.type == RecursiveHalvingNodeType::GroupLeader) { reducer(output, input + block_start[rank_], block_len[rank_]);
// receive neighbor data first
int need_recv_cnt = input_size;
linkers_->Recv(recursive_halving_map_.neighbor, output, need_recv_cnt);
// reduce
reducer(output, input, input_size);
} }
} } else {
// start recursive halving
if (recursive_halving_map_.type != RecursiveHalvingNodeType::Other) {
for (int i = 0; i < recursive_halving_map_.k; ++i) { for (int i = 0; i < recursive_halving_map_.k; ++i) {
// get target // get target
int target = recursive_halving_map_.ranks[i]; int target = recursive_halving_map_.ranks[i];
...@@ -167,19 +158,6 @@ void Network::ReduceScatter(char* input, int input_size, const int* block_start, ...@@ -167,19 +158,6 @@ void Network::ReduceScatter(char* input, int input_size, const int* block_start,
reducer(output, input + block_start[recv_block_start], need_recv_cnt); reducer(output, input + block_start[recv_block_start], need_recv_cnt);
} }
} }
if (!is_powerof_2) {
if (recursive_halving_map_.type == RecursiveHalvingNodeType::GroupLeader) {
// send result to neighbor
linkers_->Send(recursive_halving_map_.neighbor,
input + block_start[recursive_halving_map_.neighbor],
block_len[recursive_halving_map_.neighbor]);
} else if (recursive_halving_map_.type == RecursiveHalvingNodeType::Other) {
// receive result from neighbor
int need_recv_cnt = block_len[rank_];
linkers_->Recv(recursive_halving_map_.neighbor, output, need_recv_cnt);
return;
}
}
// copy result // copy result
std::memcpy(output, input + block_start[rank_], block_len[rank_]); std::memcpy(output, input + block_start[rank_], block_len[rank_]);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment