Commit 7737b45e authored by Guolin Ke's avatar Guolin Ke
Browse files

refine reduce scatter.

parent dc226c27
......@@ -35,30 +35,12 @@ public:
static BruckMap Construct(int rank, int num_machines);
};
/*!
* \brief node type on recursive halving algorithm
* When the number of machines is not a power of 2, the machines need to be grouped so that the number of groups is a power of 2.
* Each group has at most 2 machines.
* If a group has only 1 machine, that machine is a normal node.
* If a group has 2 machines, the group has two types of nodes, one of which is the leader.
* The leader represents its group and communicates with the others.
*/
enum RecursiveHalvingNodeType {
Normal, // normal node, 1 group only have 1 machine
GroupLeader, // leader of group when number of machines in this group is 2.
Other // non-leader machines in group
};
/*! \brief Network structure for recursive halving algorithm */
class RecursiveHalvingMap {
public:
bool need_pairwise;
/*! \brief Communication times for one recursive halving algorithm */
int k;
/*! \brief Node type */
RecursiveHalvingNodeType type;
/*! \brief Neighbor, only used for non-normal node*/
int neighbor;
/*! \brief ranks[i] is the machine to communicate with in the i-th communication */
std::vector<int> ranks;
/*! \brief send_block_start[i] means send block start index at i-th communication*/
......@@ -72,7 +54,7 @@ public:
RecursiveHalvingMap();
RecursiveHalvingMap(RecursiveHalvingNodeType _type, int n);
RecursiveHalvingMap(int k, bool in_need_pairwise);
/*!
* \brief Create the object of recursive halving map
......
......@@ -42,15 +42,16 @@ BruckMap BruckMap::Construct(int rank, int num_machines) {
return bruckMap;
}
RecursiveHalvingMap::RecursiveHalvingMap() {
k = 0;
need_pairwise = true;
}
RecursiveHalvingMap::RecursiveHalvingMap(RecursiveHalvingNodeType _type, int n) {
type = _type;
k = n;
if (type != RecursiveHalvingNodeType::Other) {
for (int i = 0; i < n; ++i) {
RecursiveHalvingMap::RecursiveHalvingMap(int in_k, bool in_need_pairwise) {
need_pairwise = in_need_pairwise;
k = in_k;
if (!need_pairwise) {
for (int i = 0; i < k; ++i) {
// set to -1 by default
ranks.push_back(-1);
send_block_start.push_back(-1);
......@@ -74,7 +75,7 @@ RecursiveHalvingMap RecursiveHalvingMap::Construct(int rank, int num_machines) {
}
if ((1 << k) == num_machines) {
RecursiveHalvingMap rec_map(RecursiveHalvingNodeType::Normal, k);
RecursiveHalvingMap rec_map(k, false);
// if num_machines = 2^k, don't need to group machines
for (int i = 0; i < k; ++i) {
// communication direction, %2 == 0 is positive
......@@ -93,82 +94,7 @@ RecursiveHalvingMap RecursiveHalvingMap::Construct(int rank, int num_machines) {
}
return rec_map;
} else {
// if num_machines != 2^k, need to group machines
int lower_power_of_2 = 1 << k;
int rest = num_machines - lower_power_of_2;
std::vector<RecursiveHalvingNodeType> node_type;
for (int i = 0; i < num_machines; ++i) {
node_type.push_back(RecursiveHalvingNodeType::Normal);
}
// group, two machine in one group, total "rest" groups will have 2 machines.
for (int i = 0; i < rest; ++i) {
int right = num_machines - i * 2 - 1;
int left = num_machines - i * 2 - 2;
// make the left machine the group leader
node_type[left] = RecursiveHalvingNodeType::GroupLeader;
node_type[right] = RecursiveHalvingNodeType::Other;
}
int group_cnt = 0;
// cache block information for groups, group with 2 machines will have double block size
std::vector<int> group_block_start(lower_power_of_2);
std::vector<int> group_block_len(lower_power_of_2, 0);
// convert from group to node leader
std::vector<int> group_to_node(lower_power_of_2);
// convert from node to group
std::vector<int> node_to_group(num_machines);
for (int i = 0; i < num_machines; ++i) {
// meet new group
if (node_type[i] == RecursiveHalvingNodeType::Normal || node_type[i] == RecursiveHalvingNodeType::GroupLeader) {
group_to_node[group_cnt++] = i;
}
node_to_group[i] = group_cnt - 1;
// add block len for this group
group_block_len[group_cnt - 1]++;
}
// calculate the group block start
group_block_start[0] = 0;
for (int i = 1; i < lower_power_of_2; ++i) {
group_block_start[i] = group_block_start[i - 1] + group_block_len[i - 1];
}
RecursiveHalvingMap rec_map(node_type[rank], k);
if (node_type[rank] == RecursiveHalvingNodeType::Other) {
rec_map.neighbor = rank - 1;
// no need to construct
return rec_map;
}
if (node_type[rank] == RecursiveHalvingNodeType::GroupLeader) {
rec_map.neighbor = rank + 1;
}
const int cur_group_idx = node_to_group[rank];
for (int i = 0; i < k; ++i) {
const int dir = ((cur_group_idx / distance[i]) % 2 == 0) ? 1 : -1;
const int next_node_idx = group_to_node[(cur_group_idx + dir * distance[i])];
rec_map.ranks[i] = next_node_idx;
// get receive block information
const int recv_block_start = cur_group_idx / distance[i];
rec_map.recv_block_start[i] = group_block_start[recv_block_start * distance[i]];
int recv_block_len = 0;
// accumulate block len
for (int j = 0; j < distance[i]; ++j) {
recv_block_len += group_block_len[recv_block_start * distance[i] + j];
}
rec_map.recv_block_len[i] = recv_block_len;
// get send block information
const int send_block_start = (cur_group_idx + dir * distance[i]) / distance[i];
rec_map.send_block_start[i] = group_block_start[send_block_start * distance[i]];
int send_block_len = 0;
// accumulate block len
for (int j = 0; j < distance[i]; ++j) {
send_block_len += group_block_len[send_block_start * distance[i] + j];
}
rec_map.send_block_len[i] = send_block_len;
}
return rec_map;
return RecursiveHalvingMap(k, true);
}
}
......
......@@ -151,14 +151,17 @@ void Linkers::ListenThread(int incoming_cnt) {
void Linkers::Construct() {
// save ranks that need to connect with
std::unordered_map<int, int> need_connect;
for (int i = 0; i < bruck_map_.k; ++i) {
need_connect[bruck_map_.out_ranks[i]] = 1;
need_connect[bruck_map_.in_ranks[i]] = 1;
}
if (recursive_halving_map_.type != RecursiveHalvingNodeType::Normal) {
need_connect[recursive_halving_map_.neighbor] = 1;
}
if (recursive_halving_map_.type != RecursiveHalvingNodeType::Other) {
if (recursive_halving_map_.need_pairwise) {
for (int i = 0; i < num_machines_; ++i) {
if (i != rank_) {
need_connect[i] = 1;
}
}
} else {
for (int i = 0; i < bruck_map_.k; ++i) {
need_connect[bruck_map_.out_ranks[i]] = 1;
need_connect[bruck_map_.in_ranks[i]] = 1;
}
for (int i = 0; i < recursive_halving_map_.k; ++i) {
need_connect[recursive_halving_map_.ranks[i]] = 1;
}
......@@ -175,6 +178,7 @@ void Linkers::Construct() {
++incoming_cnt;
}
}
// start listener
listener_->SetTimeout(socket_timeout_);
listener_->Listen(incoming_cnt);
......
......@@ -107,21 +107,19 @@ void Network::Allgather(char* input, int all_size, const int* block_start, const
int cur_block_size = std::min(1 << i, num_machines_ - accumulated_block);
// get out rank
int out_rank = bruck_map_.out_ranks[i];
// get send information
int send_len = 0;
for (int j = 0; j < cur_block_size; ++j) {
send_len += block_len[(rank_ + j) % num_machines_];
}
// get in rank
int in_rank = bruck_map_.in_ranks[i];
// get send information
int need_send_len = 0;
// get recv information
int need_recv_cnt = 0;
int need_recv_len = 0;
for (int j = 0; j < cur_block_size; ++j) {
need_recv_cnt += block_len[(rank_ + accumulated_block + j) % num_machines_];
need_send_len += block_len[(rank_ + j) % num_machines_];
need_recv_len += block_len[(rank_ + accumulated_block + j) % num_machines_];
}
// send and recv at same time
linkers_->SendRecv(out_rank, output, send_len, in_rank, output + write_pos, need_recv_cnt);
write_pos += need_recv_cnt;
linkers_->SendRecv(out_rank, output, need_send_len, in_rank, output + write_pos, need_recv_len);
write_pos += need_recv_len;
accumulated_block += cur_block_size;
}
// rotate in-place
......@@ -130,22 +128,15 @@ void Network::Allgather(char* input, int all_size, const int* block_start, const
std::reverse<char*>(output + block_start[rank_], output + all_size);
}
void Network::ReduceScatter(char* input, int input_size, const int* block_start, const int* block_len, char* output, const ReduceFunction& reducer) {
bool is_powerof_2 = (num_machines_ & (num_machines_ - 1)) == 0;
if (!is_powerof_2) {
if (recursive_halving_map_.type == RecursiveHalvingNodeType::Other) {
// send local data to neighbor first
linkers_->Send(recursive_halving_map_.neighbor, input, input_size);
} else if (recursive_halving_map_.type == RecursiveHalvingNodeType::GroupLeader) {
// receive neighbor data first
int need_recv_cnt = input_size;
linkers_->Recv(recursive_halving_map_.neighbor, output, need_recv_cnt);
// reduce
reducer(output, input, input_size);
void Network::ReduceScatter(char* input, int, const int* block_start, const int* block_len, char* output, const ReduceFunction& reducer) {
if (recursive_halving_map_.need_pairwise) {
for (int i = 1; i < num_machines_; ++i) {
int out_rank = (rank_ + i) % num_machines_;
int in_rank = (rank_ - i + num_machines_) % num_machines_;
linkers_->SendRecv(out_rank, input + block_start[out_rank], block_len[out_rank], in_rank, output, block_len[rank_]);
reducer(output, input + block_start[rank_], block_len[rank_]);
}
}
// start recursive halving
if (recursive_halving_map_.type != RecursiveHalvingNodeType::Other) {
} else {
for (int i = 0; i < recursive_halving_map_.k; ++i) {
// get target
int target = recursive_halving_map_.ranks[i];
......@@ -167,19 +158,6 @@ void Network::ReduceScatter(char* input, int input_size, const int* block_start,
reducer(output, input + block_start[recv_block_start], need_recv_cnt);
}
}
if (!is_powerof_2) {
if (recursive_halving_map_.type == RecursiveHalvingNodeType::GroupLeader) {
// send result to neighbor
linkers_->Send(recursive_halving_map_.neighbor,
input + block_start[recursive_halving_map_.neighbor],
block_len[recursive_halving_map_.neighbor]);
} else if (recursive_halving_map_.type == RecursiveHalvingNodeType::Other) {
// receive result from neighbor
int need_recv_cnt = block_len[rank_];
linkers_->Recv(recursive_halving_map_.neighbor, output, need_recv_cnt);
return;
}
}
// copy result
std::memcpy(output, input + block_start[rank_], block_len[rank_]);
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment