Unverified Commit 7d3206e0 authored by Guolin Ke's avatar Guolin Ke Committed by GitHub
Browse files

Refine reduce scatter (#1179)

* unfinished

* some refines.

* fix a bug

* some refines

* remove useless files

* some small fixes

* clean code

* clean code

* fix connect

* fix a broken link

* fix a broken link
parent ee8a65ae
...@@ -762,7 +762,7 @@ You can specific query/group id in data file now. Please refer to parameter ``gr ...@@ -762,7 +762,7 @@ You can specific query/group id in data file now. Please refer to parameter ``gr
.. _AUC: https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve .. _AUC: https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve
.. _log loss: https://www.kaggle.com/wiki/LogLoss .. _log loss: https://en.wikipedia.org/wiki/Cross_entropy
.. _softmax: https://en.wikipedia.org/wiki/Softmax_function .. _softmax: https://en.wikipedia.org/wiki/Softmax_function
......
...@@ -234,7 +234,7 @@ Examples ...@@ -234,7 +234,7 @@ Examples
.. _Quantile regression: https://en.wikipedia.org/wiki/Quantile_regression .. _Quantile regression: https://en.wikipedia.org/wiki/Quantile_regression
.. _log loss: https://www.kaggle.com/wiki/LogLoss .. _log loss: https://en.wikipedia.org/wiki/Cross_entropy
.. _softmax: https://en.wikipedia.org/wiki/Softmax_function .. _softmax: https://en.wikipedia.org/wiki/Softmax_function
......
...@@ -35,17 +35,29 @@ public: ...@@ -35,17 +35,29 @@ public:
static BruckMap Construct(int rank, int num_machines); static BruckMap Construct(int rank, int num_machines);
}; };
/*!
* \brief Role of a machine in the recursive halving algorithm.
* When the number of machines is not a power of 2, the machines are arranged into
* a power-of-2 number of groups, each group holding at most 2 machines.
* A group with a single machine consists of one Normal node.
* A group with 2 machines consists of a GroupLeader, which represents the group
* and communicates with the other groups, and one Other node.
*/
enum RecursiveHalvingNodeType {
  /*! \brief The only machine of its group */
  Normal = 0,
  /*! \brief Leader of a group that holds 2 machines */
  GroupLeader = 1,
  /*! \brief Non-leader machine of a group that holds 2 machines */
  Other = 2
};
/*! \brief Network structure for recursive halving algorithm */ /*! \brief Network structure for recursive halving algorithm */
class RecursiveHalvingMap { class RecursiveHalvingMap {
public: public:
/*! \brief If number workers is powers of 2 */
bool is_prof2;
/*! \brief Communication times for one recursize halving algorithm */ /*! \brief Communication times for one recursize halving algorithm */
int k; int k;
/*! \brief Number workers subtract powers of 2 */ /*! \brief Node type */
int num_remain; RecursiveHalvingNodeType type;
/*! \brief Virtual rank for recursize halving algorithm */ bool is_power_of_2;
int virtual_rank; int neighbor;
/*! \brief ranks[i] means the machines that will communicate with on i-th communication*/ /*! \brief ranks[i] means the machines that will communicate with on i-th communication*/
std::vector<int> ranks; std::vector<int> ranks;
/*! \brief send_block_start[i] means send block start index at i-th communication*/ /*! \brief send_block_start[i] means send block start index at i-th communication*/
...@@ -59,7 +71,7 @@ public: ...@@ -59,7 +71,7 @@ public:
RecursiveHalvingMap(); RecursiveHalvingMap();
RecursiveHalvingMap(int k, int num_remain, int virtual_rank, bool is_prof2); RecursiveHalvingMap(int k, RecursiveHalvingNodeType _type, bool _is_power_of_2);
/*! /*!
* \brief Create the object of recursive halving map * \brief Create the object of recursive halving map
...@@ -134,13 +146,6 @@ public: ...@@ -134,13 +146,6 @@ public:
*/ */
static void Allgather(char* input, const comm_size_t* block_start, const comm_size_t* block_len, char* output, comm_size_t all_size); static void Allgather(char* input, const comm_size_t* block_start, const comm_size_t* block_len, char* output, comm_size_t all_size);
static void AllgatherBruck(char* input, const comm_size_t* block_start, const comm_size_t* block_len, char* output, comm_size_t all_size);
static void AllgatherRecursiveDoubling(char* input, const comm_size_t* block_start, const comm_size_t* block_len, char* output, comm_size_t all_size);
static void AllgatherRing(char* input, const comm_size_t* block_start, const comm_size_t* block_len, char* output, comm_size_t all_size);
/*! /*!
* \brief Perform reduce scatter by using recursive halving algorithm. * \brief Perform reduce scatter by using recursive halving algorithm.
Communication times is O(log(n)), and communication cost is O(input_size) Communication times is O(log(n)), and communication cost is O(input_size)
...@@ -206,6 +211,21 @@ public: ...@@ -206,6 +211,21 @@ public:
} }
private: private:
static void AllgatherBruck(char* input, const comm_size_t* block_start, const comm_size_t* block_len, char* output, comm_size_t all_size);
static void AllgatherRecursiveDoubling(char* input, const comm_size_t* block_start, const comm_size_t* block_len, char* output, comm_size_t all_size);
static void AllgatherRing(char* input, const comm_size_t* block_start, const comm_size_t* block_len, char* output, comm_size_t all_size);
/*!
* \brief Reduce scatter by the recursive halving algorithm.
*        When the number of machines is not a power of 2, a non-leader machine first sends
*        its whole input to its group leader, and receives its own result block back at the end.
*        The definition does not use output_size.
*/
static void ReduceScatterRecursiveHalving(char* input, comm_size_t input_size, int type_size,
const comm_size_t* block_start, const comm_size_t* block_len, char* output, comm_size_t output_size,
const ReduceFunction& reducer);
/*!
* \brief Reduce scatter by passing blocks around a ring (each rank sends to rank + 1 and
*        receives from rank - 1), using num_machines - 1 send/receive steps.
*        The definition does not use input_size or output_size.
*/
static void ReduceScatterRing(char* input, comm_size_t input_size, int type_size,
const comm_size_t* block_start, const comm_size_t* block_len, char* output, comm_size_t output_size,
const ReduceFunction& reducer);
/*! \brief Number of all machines */ /*! \brief Number of all machines */
static THREAD_LOCAL int num_machines_; static THREAD_LOCAL int num_machines_;
/*! \brief Rank of local machine */ /*! \brief Rank of local machine */
......
...@@ -44,22 +44,21 @@ BruckMap BruckMap::Construct(int rank, int num_machines) { ...@@ -44,22 +44,21 @@ BruckMap BruckMap::Construct(int rank, int num_machines) {
RecursiveHalvingMap::RecursiveHalvingMap() { RecursiveHalvingMap::RecursiveHalvingMap() {
k = 0; k = 0;
is_prof2 = true;
num_remain = 0;
} }
RecursiveHalvingMap::RecursiveHalvingMap(int in_k, int in_remain, int in_rank, bool is_power_of2) { RecursiveHalvingMap::RecursiveHalvingMap(int in_k, RecursiveHalvingNodeType _type, bool _is_power_of_2) {
type = _type;
k = in_k; k = in_k;
is_prof2 = is_power_of2; is_power_of_2 = _is_power_of_2;
num_remain = in_remain; if (type != RecursiveHalvingNodeType::Other) {
virtual_rank = in_rank; for (int i = 0; i < k; ++i) {
for (int i = 0; i < k; ++i) { // defalut set as -1
// defalut set as -1 ranks.push_back(-1);
ranks.push_back(-1); send_block_start.push_back(-1);
send_block_start.push_back(-1); send_block_len.push_back(-1);
send_block_len.push_back(-1); recv_block_start.push_back(-1);
recv_block_start.push_back(-1); recv_block_len.push_back(-1);
recv_block_len.push_back(-1); }
} }
} }
...@@ -75,31 +74,17 @@ RecursiveHalvingMap RecursiveHalvingMap::Construct(int rank, int num_machines) { ...@@ -75,31 +74,17 @@ RecursiveHalvingMap RecursiveHalvingMap::Construct(int rank, int num_machines) {
distance.push_back(1 << (k - 1 - i)); distance.push_back(1 << (k - 1 - i));
} }
int remain = num_machines - (1 << k); if ((1 << k) == num_machines) {
int virtual_rank = rank; RecursiveHalvingMap rec_map(k, RecursiveHalvingNodeType::Normal, true);
// if virtual_rank not -1 will not excute recursize halving algorithm // if num_machines = 2^k, don't need to group machines
if (rank < 2 * remain) {
if (rank % 2 == 0) {
virtual_rank = -1;
} else {
virtual_rank = rank / 2;
}
} else {
virtual_rank = rank - remain;
}
bool is_power_of2 = false;
if ((1 << k) == num_machines) { is_power_of2 = true; }
RecursiveHalvingMap rec_map(k, remain, virtual_rank, is_power_of2);
if (virtual_rank != -1) {
for (int i = 0; i < k; ++i) { for (int i = 0; i < k; ++i) {
// communication direction, %2 == 0 is positive // communication direction, %2 == 0 is positive
const int dir = ((virtual_rank / distance[i]) % 2 == 0) ? 1 : -1; const int dir = ((rank / distance[i]) % 2 == 0) ? 1 : -1;
// neighbor at k-th communication // neighbor at k-th communication
const int next_node_idx = virtual_rank + dir * distance[i]; const int next_node_idx = rank + dir * distance[i];
rec_map.ranks[i] = next_node_idx; rec_map.ranks[i] = next_node_idx;
// receive data block at k-th communication // receive data block at k-th communication
const int recv_block_start = virtual_rank / distance[i]; const int recv_block_start = rank / distance[i];
rec_map.recv_block_start[i] = recv_block_start * distance[i]; rec_map.recv_block_start[i] = recv_block_start * distance[i];
rec_map.recv_block_len[i] = distance[i]; rec_map.recv_block_len[i] = distance[i];
// send data block at k-th communication // send data block at k-th communication
...@@ -107,8 +92,85 @@ RecursiveHalvingMap RecursiveHalvingMap::Construct(int rank, int num_machines) { ...@@ -107,8 +92,85 @@ RecursiveHalvingMap RecursiveHalvingMap::Construct(int rank, int num_machines) {
rec_map.send_block_start[i] = send_block_start * distance[i]; rec_map.send_block_start[i] = send_block_start * distance[i];
rec_map.send_block_len[i] = distance[i]; rec_map.send_block_len[i] = distance[i];
} }
return rec_map;
} else {
// if num_machines != 2^k, need to group machines
int lower_power_of_2 = 1 << k;
int rest = num_machines - lower_power_of_2;
std::vector<RecursiveHalvingNodeType> node_type(num_machines);
for (int i = 0; i < num_machines; ++i) {
node_type[i] = RecursiveHalvingNodeType::Normal;
}
// grouping: exactly "rest" groups will contain 2 machines each; all other groups keep 1 machine.
for (int i = 0; i < rest; ++i) {
int right = num_machines - i * 2 - 1;
int left = num_machines - i * 2 - 2;
// let left machine as group leader
node_type[left] = RecursiveHalvingNodeType::GroupLeader;
node_type[right] = RecursiveHalvingNodeType::Other;
}
int group_cnt = 0;
// cache block information for groups, group with 2 machines will have double block size
std::vector<int> group_block_start(lower_power_of_2);
std::vector<int> group_block_len(lower_power_of_2, 0);
// convert from group to node leader
std::vector<int> group_to_node(lower_power_of_2);
// convert from node to group
std::vector<int> node_to_group(num_machines);
for (int i = 0; i < num_machines; ++i) {
// meet new group
if (node_type[i] == RecursiveHalvingNodeType::Normal || node_type[i] == RecursiveHalvingNodeType::GroupLeader) {
group_to_node[group_cnt++] = i;
}
node_to_group[i] = group_cnt - 1;
// add block len for this group
group_block_len[group_cnt - 1]++;
}
// calculate the group block start
group_block_start[0] = 0;
for (int i = 1; i < lower_power_of_2; ++i) {
group_block_start[i] = group_block_start[i - 1] + group_block_len[i - 1];
}
RecursiveHalvingMap rec_map(k, node_type[rank], false);
if (node_type[rank] == RecursiveHalvingNodeType::Other) {
rec_map.neighbor = rank - 1;
// not need to construct
return rec_map;
}
if (node_type[rank] == RecursiveHalvingNodeType::GroupLeader) {
rec_map.neighbor = rank + 1;
}
const int cur_group_idx = node_to_group[rank];
for (int i = 0; i < k; ++i) {
const int dir = ((cur_group_idx / distance[i]) % 2 == 0) ? 1 : -1;
const int next_node_idx = group_to_node[(cur_group_idx + dir * distance[i])];
rec_map.ranks[i] = next_node_idx;
// get receive block information
const int recv_block_start = cur_group_idx / distance[i];
rec_map.recv_block_start[i] = group_block_start[recv_block_start * distance[i]];
int recv_block_len = 0;
// accumulate block len
for (int j = 0; j < distance[i]; ++j) {
recv_block_len += group_block_len[recv_block_start * distance[i] + j];
}
rec_map.recv_block_len[i] = recv_block_len;
// get send block information
const int send_block_start = (cur_group_idx + dir * distance[i]) / distance[i];
rec_map.send_block_start[i] = group_block_start[send_block_start * distance[i]];
int send_block_len = 0;
// accumulate block len
for (int j = 0; j < distance[i]; ++j) {
send_block_len += group_block_len[send_block_start * distance[i] + j];
}
rec_map.send_block_len[i] = send_block_len;
}
return rec_map;
} }
return rec_map;
} }
} // namespace LightGBM } // namespace LightGBM
......
...@@ -163,22 +163,11 @@ void Linkers::ListenThread(int incoming_cnt) { ...@@ -163,22 +163,11 @@ void Linkers::ListenThread(int incoming_cnt) {
void Linkers::Construct() { void Linkers::Construct() {
// save ranks that need to connect with // save ranks that need to connect with
std::unordered_map<int, int> need_connect; std::unordered_map<int, int> need_connect;
if (!recursive_halving_map_.is_prof2) { for (int i = 0; i < num_machines_; ++i) {
for (int i = 0; i < num_machines_; ++i) { if (i != rank_) {
if (i != rank_) { need_connect[i] = 1;
need_connect[i] = 1;
}
}
} else {
for (int i = 0; i < bruck_map_.k; ++i) {
need_connect[bruck_map_.out_ranks[i]] = 1;
need_connect[bruck_map_.in_ranks[i]] = 1;
}
for (int i = 0; i < recursive_halving_map_.k; ++i) {
need_connect[recursive_halving_map_.ranks[i]] = 1;
} }
} }
int need_connect_cnt = 0; int need_connect_cnt = 0;
int incoming_cnt = 0; int incoming_cnt = 0;
for (auto it = need_connect.begin(); it != need_connect.end(); ++it) { for (auto it = need_connect.begin(); it != need_connect.end(); ++it) {
......
...@@ -137,15 +137,15 @@ void Network::Allgather(char* input, const comm_size_t* block_start, const comm_ ...@@ -137,15 +137,15 @@ void Network::Allgather(char* input, const comm_size_t* block_start, const comm_
if (allgather_ext_fun_ != nullptr) { if (allgather_ext_fun_ != nullptr) {
return allgather_ext_fun_(input, block_len[rank_], block_start, block_len, num_machines_, output, all_size); return allgather_ext_fun_(input, block_len[rank_], block_start, block_len, num_machines_, output, all_size);
} }
const comm_size_t kRecursiveDoublingThreshold = 1024 * 1024; // 1MB const comm_size_t kRingThreshold = 10 * 1024 * 1024; // 10MB
const comm_size_t kBruckThreshold = 512 * 1024; // 512KB const int kRingNodeThreshold = 64;
const bool is_power_of2 = (num_machines_ & (num_machines_ - 1)) == 0; if (all_size > kRingThreshold && num_machines_ < kRingNodeThreshold) {
if (is_power_of2 && all_size < kRecursiveDoublingThreshold) { // when num_machines is small and data is large
AllgatherRing(input, block_start, block_len, output, all_size);
} else if (recursive_halving_map_.is_power_of_2) {
AllgatherRecursiveDoubling(input, block_start, block_len, output, all_size); AllgatherRecursiveDoubling(input, block_start, block_len, output, all_size);
} else if (all_size < kBruckThreshold) {
AllgatherBruck(input, block_start, block_len, output, all_size);
} else { } else {
AllgatherRing(input, block_start, block_len, output, all_size); AllgatherBruck(input, block_start, block_len, output, all_size);
} }
} }
...@@ -204,7 +204,7 @@ void Network::AllgatherRecursiveDoubling(char* input, const comm_size_t* block_s ...@@ -204,7 +204,7 @@ void Network::AllgatherRecursiveDoubling(char* input, const comm_size_t* block_s
need_recv_len += block_len[(target_vrank + j)]; need_recv_len += block_len[(target_vrank + j)];
} }
// send and recv at same time // send and recv at same time
linkers_->SendRecv(target, output + block_start[vrank], need_send_len, linkers_->SendRecv(target, output + block_start[vrank], need_send_len,
target, output + block_start[target_vrank], need_recv_len); target, output + block_start[target_vrank], need_recv_len);
} }
} }
...@@ -214,86 +214,50 @@ void Network::AllgatherRing(char* input, const comm_size_t* block_start, const c ...@@ -214,86 +214,50 @@ void Network::AllgatherRing(char* input, const comm_size_t* block_start, const c
std::memcpy(output + block_start[rank_], input, block_len[rank_]); std::memcpy(output + block_start[rank_], input, block_len[rank_]);
int out_rank = (rank_ + 1) % num_machines_; int out_rank = (rank_ + 1) % num_machines_;
int in_rank = (rank_ - 1 + num_machines_) % num_machines_; int in_rank = (rank_ - 1 + num_machines_) % num_machines_;
int out_place = rank_; int out_block = rank_;
int in_place = in_rank; int in_block = in_rank;
for (int i = 1; i < num_machines_; ++i) { for (int i = 1; i < num_machines_; ++i) {
// send and recv at same time // send and recv at same time
linkers_->SendRecv(out_rank, output + block_start[out_place], block_len[out_place], linkers_->SendRecv(out_rank, output + block_start[out_block], block_len[out_block],
in_rank, output + block_start[in_place], block_len[in_place]); in_rank, output + block_start[in_block], block_len[in_block]);
out_place = (out_place - 1 + num_machines_) % num_machines_; out_block = (out_block - 1 + num_machines_) % num_machines_;
in_place = (in_place - 1 + num_machines_) % num_machines_; in_block = (in_block - 1 + num_machines_) % num_machines_;
} }
} }
void Network::ReduceScatter(char* input, comm_size_t input_size, int type_size, const comm_size_t* block_start, void Network::ReduceScatter(char* input, comm_size_t input_size, int type_size,
const comm_size_t* block_len, char* output, comm_size_t output_size, const ReduceFunction& reducer) { const comm_size_t* block_start, const comm_size_t* block_len, char* output,
comm_size_t output_size, const ReduceFunction& reducer) {
if (num_machines_ <= 1) { if (num_machines_ <= 1) {
Log::Fatal("Please initilize the network interface first"); Log::Fatal("Please initilize the network interface first");
} }
if (reduce_scatter_ext_fun_ != nullptr) { if (reduce_scatter_ext_fun_ != nullptr) {
return reduce_scatter_ext_fun_(input, input_size, type_size, block_start, block_len, num_machines_, output, output_size, reducer); return reduce_scatter_ext_fun_(input, input_size, type_size, block_start, block_len, num_machines_, output, output_size, reducer);
} }
if (!recursive_halving_map_.is_prof2) { const comm_size_t kRingThreshold = 10 * 1024 * 1024; // 10MB
int remain = recursive_halving_map_.num_remain; if (recursive_halving_map_.is_power_of_2 || input_size < kRingThreshold) {
std::vector<int> rcsv_block_start(1 << recursive_halving_map_.k); ReduceScatterRecursiveHalving(input, input_size, type_size, block_start, block_len, output, output_size, reducer);
std::vector<int> rcsv_block_len(1 << recursive_halving_map_.k);
std::vector<int> real_ranks;
int brush = 0;
// build block_start and block_len for remain powers of 2 workers
for (int i = 0; i < num_machines_; ++i) {
if ((i < 2 * remain) && (i % 2 != 0)) {
real_ranks.push_back(i);
rcsv_block_start[i - 1 - brush] = block_start[i - 1];
rcsv_block_len[i - 1 - brush] = block_len[i] + block_len[i - 1];
brush++;
}
if (i >= 2 * remain) {
real_ranks.push_back(i);
rcsv_block_start[i - remain] = block_start[i];
rcsv_block_len[i - remain] = block_len[i];
}
}
// if local rank is remain, send local data to rank+1
if (rank_ < 2 * remain) {
if (rank_ % 2 == 0) {
linkers_->Send(rank_ + 1, input, input_size);
} else {
linkers_->Recv(rank_ - 1, output, input_size);
reducer(output, input, type_size, input_size);
}
}
// excute recursize halving algorithm for powers of 2 workers
if (recursive_halving_map_.virtual_rank != -1) {
for (int i = 0; i < recursive_halving_map_.k; ++i) {
int virtual_rank = recursive_halving_map_.ranks[i];
int target = real_ranks[virtual_rank];
int send_block_start = recursive_halving_map_.send_block_start[i];
int recv_block_start = recursive_halving_map_.recv_block_start[i];
// get send information
int send_size = 0;
for (int j = 0; j < recursive_halving_map_.send_block_len[i]; ++j) {
send_size += rcsv_block_len[send_block_start + j];
}
// get recv information
int need_recv_cnt = 0;
for (int j = 0; j < recursive_halving_map_.recv_block_len[i]; ++j) {
need_recv_cnt += rcsv_block_len[recv_block_start + j];
}
// send and recv at same time
linkers_->SendRecv(target, input + rcsv_block_start[send_block_start], send_size, target, output, need_recv_cnt);
// reduce
reducer(output, input + rcsv_block_start[recv_block_start], type_size, need_recv_cnt);
}
}
// send result back to remain workers
if (rank_ < 2 * remain) {
if (rank_ % 2 != 0) {
linkers_->Send(rank_ - 1, input + block_start[rank_ - 1], block_len[rank_ - 1]);
} else {
linkers_->Recv(rank_ + 1, input + block_start[rank_], block_len[rank_]);
}
}
} else { } else {
ReduceScatterRing(input, input_size, type_size, block_start, block_len, output, output_size, reducer);
}
}
void Network::ReduceScatterRecursiveHalving(char* input, comm_size_t input_size, int type_size,
const comm_size_t* block_start, const comm_size_t* block_len, char* output,
comm_size_t, const ReduceFunction& reducer) {
if (!recursive_halving_map_.is_power_of_2) {
if (recursive_halving_map_.type == RecursiveHalvingNodeType::Other) {
// send local data to neighbor first
linkers_->Send(recursive_halving_map_.neighbor, input, input_size);
} else if (recursive_halving_map_.type == RecursiveHalvingNodeType::GroupLeader) {
// receive neighbor data first
int need_recv_cnt = input_size;
linkers_->Recv(recursive_halving_map_.neighbor, output, need_recv_cnt);
// reduce
reducer(output, input, type_size, input_size);
}
}
if (recursive_halving_map_.type != RecursiveHalvingNodeType::Other) {
for (int i = 0; i < recursive_halving_map_.k; ++i) { for (int i = 0; i < recursive_halving_map_.k; ++i) {
// get target // get target
int target = recursive_halving_map_.ranks[i]; int target = recursive_halving_map_.ranks[i];
...@@ -315,8 +279,38 @@ void Network::ReduceScatter(char* input, comm_size_t input_size, int type_size, ...@@ -315,8 +279,38 @@ void Network::ReduceScatter(char* input, comm_size_t input_size, int type_size,
reducer(output, input + block_start[recv_block_start], type_size, need_recv_cnt); reducer(output, input + block_start[recv_block_start], type_size, need_recv_cnt);
} }
} }
if (!recursive_halving_map_.is_power_of_2) {
if (recursive_halving_map_.type == RecursiveHalvingNodeType::GroupLeader) {
// send result to neighbor
linkers_->Send(recursive_halving_map_.neighbor,
input + block_start[recursive_halving_map_.neighbor],
block_len[recursive_halving_map_.neighbor]);
} else if (recursive_halving_map_.type == RecursiveHalvingNodeType::Other) {
// receive result from neighbor
int need_recv_cnt = block_len[rank_];
linkers_->Recv(recursive_halving_map_.neighbor, output, need_recv_cnt);
return;
}
}
// copy result // copy result
std::memcpy(output, input + block_start[rank_], block_len[rank_]); std::memcpy(output, input + block_start[rank_], block_len[rank_]);
} }
/*!
* \brief Reduce scatter over a ring: every machine sends to rank + 1 and receives from rank - 1.
*        In each of the num_machines - 1 steps one partially reduced block is forwarded while the
*        preceding block is received into output and reduced into input.
*        The block received in the last step is the local one, which is finally copied to output.
* \param input Local data of all blocks; used as the accumulation buffer and modified in place
* \param type_size Passed through to reducer
* \param block_start Start offset of every block inside input
* \param block_len Length of every block
* \param output Receive buffer for each step; holds the local reduced block on return
* \param reducer Called as reducer(received, accumulator, type_size, length)
*/
void Network::ReduceScatterRing(char* input, comm_size_t, int type_size,
                                const comm_size_t* block_start, const comm_size_t* block_len, char* output,
                                comm_size_t, const ReduceFunction& reducer) {
  const int successor = (rank_ + 1) % num_machines_;
  const int predecessor = (rank_ + num_machines_ - 1) % num_machines_;
  // the block received in a step is always the one right before the block sent in that step
  int send_block = predecessor;
  for (int steps_left = num_machines_ - 1; steps_left > 0; --steps_left) {
    const int recv_block = (send_block + num_machines_ - 1) % num_machines_;
    char* accumulator = input + block_start[recv_block];
    // send and receive at the same time
    linkers_->SendRecv(successor, input + block_start[send_block], block_len[send_block],
                       predecessor, output, block_len[recv_block]);
    // fold the received data into the local copy of that block
    reducer(output, accumulator, type_size, block_len[recv_block]);
    // the block just reduced is the one forwarded in the next step
    send_block = recv_block;
  }
  // copy result
  std::memcpy(output, input + block_start[rank_], block_len[rank_]);
}
} // namespace LightGBM } // namespace LightGBM
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment