Commit ecc8b8cd authored by Guolin Ke's avatar Guolin Ke
Browse files

Slightly reduce communication cost in the parallel tree learners.

parent 6c4a9750
...@@ -233,14 +233,7 @@ void DataParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(const ...@@ -233,14 +233,7 @@ void DataParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(const
} }
// sync global best info // sync global best info
std::memcpy(input_buffer_.data(), &smaller_best_split, sizeof(SplitInfo)); SyncUpGlobalBestSplit(input_buffer_.data(), input_buffer_.data(), &smaller_best_split, &larger_best_split);
std::memcpy(input_buffer_.data() + sizeof(SplitInfo), &larger_best_split, sizeof(SplitInfo));
Network::Allreduce(input_buffer_.data(), sizeof(SplitInfo) * 2, sizeof(SplitInfo),
output_buffer_.data(), &SplitInfo::MaxReducer);
std::memcpy(&smaller_best_split, output_buffer_.data(), sizeof(SplitInfo));
std::memcpy(&larger_best_split, output_buffer_.data() + sizeof(SplitInfo), sizeof(SplitInfo));
// set best split // set best split
this->best_split_per_leaf_[this->smaller_leaf_splits_->LeafIndex()] = smaller_best_split; this->best_split_per_leaf_[this->smaller_leaf_splits_->LeafIndex()] = smaller_best_split;
......
...@@ -60,14 +60,7 @@ void FeatureParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(con ...@@ -60,14 +60,7 @@ void FeatureParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(con
larger_best_split = this->best_split_per_leaf_[this->larger_leaf_splits_->LeafIndex()]; larger_best_split = this->best_split_per_leaf_[this->larger_leaf_splits_->LeafIndex()];
} }
// sync global best info // sync global best info
std::memcpy(input_buffer_.data(), &smaller_best_split, sizeof(SplitInfo)); SyncUpGlobalBestSplit(input_buffer_.data(), input_buffer_.data(), &smaller_best_split, &larger_best_split);
std::memcpy(input_buffer_.data() + sizeof(SplitInfo), &larger_best_split, sizeof(SplitInfo));
Network::Allreduce(input_buffer_.data(), sizeof(SplitInfo) * 2, sizeof(SplitInfo),
output_buffer_.data(), &SplitInfo::MaxReducer);
// copy back
std::memcpy(&smaller_best_split, output_buffer_.data(), sizeof(SplitInfo));
std::memcpy(&larger_best_split, output_buffer_.data() + sizeof(SplitInfo), sizeof(SplitInfo));
// update best split // update best split
this->best_split_per_leaf_[this->smaller_leaf_splits_->LeafIndex()] = smaller_best_split; this->best_split_per_leaf_[this->smaller_leaf_splits_->LeafIndex()] = smaller_best_split;
if (this->larger_leaf_splits_->LeafIndex() >= 0) { if (this->larger_leaf_splits_->LeafIndex() >= 0) {
......
...@@ -125,7 +125,7 @@ protected: ...@@ -125,7 +125,7 @@ protected:
* \param splits All splits from local voting * \param splits All splits from local voting
* \param out Result of global voting, only stores feature indices * \param out Result of global voting, only stores feature indices
*/ */
void GlobalVoting(int leaf_idx, const std::vector<SplitInfo>& splits, void GlobalVoting(int leaf_idx, const std::vector<LightSplitInfo>& splits,
std::vector<int>* out); std::vector<int>* out);
/*! /*!
* \brief Copy local histogram to buffer * \brief Copy local histogram to buffer
...@@ -180,6 +180,32 @@ private: ...@@ -180,6 +180,32 @@ private:
std::vector<FeatureMetainfo> feature_metas_; std::vector<FeatureMetainfo> feature_metas_;
}; };
// To-do: reduce the communication cost by using bitset to communicate.
inline void SyncUpGlobalBestSplit(char* input_buffer_, char* output_buffer_, SplitInfo* smaller_best_split, SplitInfo* larger_best_split) {
// sync global best info
int size = SplitInfo::Size();
smaller_best_split->CopyTo(input_buffer_);
larger_best_split->CopyTo(input_buffer_ + size);
Network::Allreduce(input_buffer_, size * 2, size, output_buffer_,
[&size] (const char* src, char* dst, int len) {
int used_size = 0;
LightSplitInfo p1, p2;
while (used_size < len) {
p1.CopyFrom(src);
p2.CopyFrom(dst);
if (p1 > p2) {
std::memcpy(dst, src, size);
}
src += size;
dst += size;
used_size += size;
}
});
// copy back
smaller_best_split->CopyFrom(output_buffer_);
larger_best_split->CopyFrom(output_buffer_ + size);
}
} // namespace LightGBM } // namespace LightGBM
#endif // LightGBM_TREELEARNER_PARALLEL_TREE_LEARNER_H_ #endif // LightGBM_TREELEARNER_PARALLEL_TREE_LEARNER_H_
...@@ -17,34 +17,86 @@ namespace LightGBM { ...@@ -17,34 +17,86 @@ namespace LightGBM {
struct SplitInfo { struct SplitInfo {
public: public:
/*! \brief Feature index */ /*! \brief Feature index */
int feature; int feature = -1;
/*! \brief Split threshold */ /*! \brief Split threshold */
uint32_t threshold; uint32_t threshold = 0;
/*! \brief True if default split is left */ /*! \brief Left number of data after split */
bool default_left; data_size_t left_count = 0;
/*! \brief Right number of data after split */
data_size_t right_count = 0;
/*! \brief Left output after split */ /*! \brief Left output after split */
double left_output; double left_output = 0.0;
/*! \brief Right output after split */ /*! \brief Right output after split */
double right_output; double right_output = 0.0;
/*! \brief Split gain */ /*! \brief Split gain */
double gain; double gain = kMinScore;
/*! \brief Left number of data after split */
data_size_t left_count;
/*! \brief Right number of data after split */
data_size_t right_count;
/*! \brief Left sum gradient after split */ /*! \brief Left sum gradient after split */
double left_sum_gradient; double left_sum_gradient = 0;
/*! \brief Left sum hessian after split */ /*! \brief Left sum hessian after split */
double left_sum_hessian; double left_sum_hessian = 0;
/*! \brief Right sum gradient after split */ /*! \brief Right sum gradient after split */
double right_sum_gradient; double right_sum_gradient = 0;
/*! \brief Right sum hessian after split */ /*! \brief Right sum hessian after split */
double right_sum_hessian; double right_sum_hessian = 0;
/*! \brief True if default split is left */
bool default_left = true;
SplitInfo() { inline static int Size() {
// initialize with -1 and -inf gain return sizeof(int) + sizeof(uint32_t) + sizeof(bool) + sizeof(double) * 7 + sizeof(data_size_t) * 2;
feature = -1; }
gain = kMinScore;
/*!
 * \brief Serialize this split into a raw byte buffer, one field after another.
 *
 * Fields are written without padding in this order: feature, left_count,
 * right_count, gain, threshold, left_output, right_output, left_sum_gradient,
 * left_sum_hessian, right_sum_gradient, right_sum_hessian, default_left.
 * CopyFrom(const char*) reads them back in the same order.
 *
 * \param buffer Destination buffer; the sum of the listed field sizes is written
 */
inline void CopyTo(char* buffer) const {
  // Write position; advanced by the size of each field after it is stored.
  char* cursor = buffer;

  std::memcpy(cursor, &feature, sizeof(feature));
  cursor += sizeof(feature);

  std::memcpy(cursor, &left_count, sizeof(left_count));
  cursor += sizeof(left_count);

  std::memcpy(cursor, &right_count, sizeof(right_count));
  cursor += sizeof(right_count);

  std::memcpy(cursor, &gain, sizeof(gain));
  cursor += sizeof(gain);

  std::memcpy(cursor, &threshold, sizeof(threshold));
  cursor += sizeof(threshold);

  std::memcpy(cursor, &left_output, sizeof(left_output));
  cursor += sizeof(left_output);

  std::memcpy(cursor, &right_output, sizeof(right_output));
  cursor += sizeof(right_output);

  std::memcpy(cursor, &left_sum_gradient, sizeof(left_sum_gradient));
  cursor += sizeof(left_sum_gradient);

  std::memcpy(cursor, &left_sum_hessian, sizeof(left_sum_hessian));
  cursor += sizeof(left_sum_hessian);

  std::memcpy(cursor, &right_sum_gradient, sizeof(right_sum_gradient));
  cursor += sizeof(right_sum_gradient);

  std::memcpy(cursor, &right_sum_hessian, sizeof(right_sum_hessian));
  cursor += sizeof(right_sum_hessian);

  // Last field: nothing follows it, so the cursor is not advanced again.
  std::memcpy(cursor, &default_left, sizeof(default_left));
}
void CopyFrom(const char* buffer) {
std::memcpy(&feature, buffer, sizeof(feature));
buffer += sizeof(feature);
std::memcpy(&left_count, buffer, sizeof(left_count));
buffer += sizeof(left_count);
std::memcpy(&right_count, buffer, sizeof(right_count));
buffer += sizeof(right_count);
std::memcpy(&gain, buffer, sizeof(gain));
buffer += sizeof(gain);
std::memcpy(&threshold, buffer, sizeof(threshold));
buffer += sizeof(threshold);
std::memcpy(&left_output, buffer, sizeof(left_output));
buffer += sizeof(left_output);
std::memcpy(&right_output, buffer, sizeof(right_output));
buffer += sizeof(right_output);
std::memcpy(&left_sum_gradient, buffer, sizeof(left_sum_gradient));
buffer += sizeof(left_sum_gradient);
std::memcpy(&left_sum_hessian, buffer, sizeof(left_sum_hessian));
buffer += sizeof(left_sum_hessian);
std::memcpy(&right_sum_gradient, buffer, sizeof(right_sum_gradient));
buffer += sizeof(right_sum_gradient);
std::memcpy(&right_sum_hessian, buffer, sizeof(right_sum_hessian));
buffer += sizeof(right_sum_hessian);
std::memcpy(&default_left, buffer, sizeof(default_left));
buffer += sizeof(default_left);
} }
inline void Reset() { inline void Reset() {
...@@ -53,88 +105,160 @@ public: ...@@ -53,88 +105,160 @@ public:
gain = kMinScore; gain = kMinScore;
} }
inline bool operator > (const SplitInfo &si) const; inline bool operator > (const SplitInfo& si) const {
double local_gain = this->gain;
inline bool operator == (const SplitInfo &si) const; double other_gain = si.gain;
// replace nan with -inf
inline static void MaxReducer(const char* src, char* dst, int len) { if (local_gain == NAN) {
const int type_size = sizeof(SplitInfo); local_gain = kMinScore;
int used_size = 0; }
const SplitInfo* p1; // replace nan with -inf
SplitInfo* p2; if (other_gain == NAN) {
while (used_size < len) { other_gain = kMinScore;
p1 = reinterpret_cast<const SplitInfo*>(src); }
p2 = reinterpret_cast<SplitInfo*>(dst); int local_feature = this->feature;
if (*p1 > *p2) { int other_feature = si.feature;
// copy // replace -1 with max int
std::memcpy(dst, src, type_size); if (local_feature == -1) {
} local_feature = INT32_MAX;
src += type_size; }
dst += type_size; // replace -1 with max int
used_size += type_size; if (other_feature == -1) {
other_feature = INT32_MAX;
}
if (local_gain != other_gain) {
return local_gain > other_gain;
} else {
// if same gain, use smaller feature
return local_feature < other_feature;
} }
} }
};
inline bool operator == (const SplitInfo& si) const {
double local_gain = this->gain;
double other_gain = si.gain;
// replace nan with -inf
if (local_gain == NAN) {
local_gain = kMinScore;
}
// replace nan with -inf
if (other_gain == NAN) {
other_gain = kMinScore;
}
int local_feature = this->feature;
int other_feature = si.feature;
// replace -1 with max int
if (local_feature == -1) {
local_feature = INT32_MAX;
}
// replace -1 with max int
if (other_feature == -1) {
other_feature = INT32_MAX;
}
if (local_gain != other_gain) {
return local_gain == other_gain;
} else {
// if same gain, use smaller feature
return local_feature == other_feature;
}
}
};
inline bool SplitInfo::operator > (const SplitInfo& si) const { struct LightSplitInfo {
double local_gain = this->gain; public:
double other_gain = si.gain; /*! \brief Feature index */
// replace nan with -inf int feature = -1;
if (local_gain == NAN) { /*! \brief Split gain */
local_gain = kMinScore; double gain = kMinScore;
} /*! \brief Left number of data after split */
// replace nan with -inf data_size_t left_count = 0;
if (other_gain == NAN) { /*! \brief Right number of data after split */
other_gain = kMinScore; data_size_t right_count = 0;
}
int local_feature = this->feature; inline void Reset() {
int other_feature = si.feature; // initialize with -1 and -inf gain
// replace -1 with max int feature = -1;
if (local_feature == -1) { gain = kMinScore;
local_feature = INT32_MAX;
}
// replace -1 with max int
if (other_feature == -1) {
other_feature = INT32_MAX;
}
if (local_gain != other_gain) {
return local_gain > other_gain;
} else {
// if same gain, use smaller feature
return local_feature < other_feature;
}
}
inline bool SplitInfo::operator == (const SplitInfo& si) const {
double local_gain = this->gain;
double other_gain = si.gain;
// replace nan with -inf
if (local_gain == NAN) {
local_gain = kMinScore;
} }
// replace nan with -inf
if (other_gain == NAN) { void CopyFrom(const SplitInfo& other) {
other_gain = kMinScore; feature = other.feature;
gain = other.gain;
left_count = other.left_count;
right_count = other.right_count;
} }
int local_feature = this->feature;
int other_feature = si.feature; void CopyFrom(const char* buffer) {
// replace -1 with max int std::memcpy(&feature, buffer, sizeof(feature));
if (local_feature == -1) { buffer += sizeof(feature);
local_feature = INT32_MAX; std::memcpy(&left_count, buffer, sizeof(left_count));
buffer += sizeof(left_count);
std::memcpy(&right_count, buffer, sizeof(right_count));
buffer += sizeof(right_count);
std::memcpy(&gain, buffer, sizeof(gain));
buffer += sizeof(gain);
} }
// replace -1 with max int
if (other_feature == -1) { inline bool operator > (const LightSplitInfo& si) const {
other_feature = INT32_MAX; double local_gain = this->gain;
double other_gain = si.gain;
// replace nan with -inf
if (local_gain == NAN) {
local_gain = kMinScore;
}
// replace nan with -inf
if (other_gain == NAN) {
other_gain = kMinScore;
}
int local_feature = this->feature;
int other_feature = si.feature;
// replace -1 with max int
if (local_feature == -1) {
local_feature = INT32_MAX;
}
// replace -1 with max int
if (other_feature == -1) {
other_feature = INT32_MAX;
}
if (local_gain != other_gain) {
return local_gain > other_gain;
} else {
// if same gain, use smaller feature
return local_feature < other_feature;
}
} }
if (local_gain != other_gain) {
return local_gain == other_gain; inline bool operator == (const LightSplitInfo& si) const {
} else { double local_gain = this->gain;
// if same gain, use smaller feature double other_gain = si.gain;
return local_feature == other_feature; // replace nan with -inf
if (local_gain == NAN) {
local_gain = kMinScore;
}
// replace nan with -inf
if (other_gain == NAN) {
other_gain = kMinScore;
}
int local_feature = this->feature;
int other_feature = si.feature;
// replace -1 with max int
if (local_feature == -1) {
local_feature = INT32_MAX;
}
// replace -1 with max int
if (other_feature == -1) {
other_feature = INT32_MAX;
}
if (local_gain != other_gain) {
return local_gain == other_gain;
} else {
// if same gain, use smaller feature
return local_feature == other_feature;
}
} }
}
};
} // namespace LightGBM } // namespace LightGBM
#endif // LightGBM_TREELEARNER_SPLIT_INFO_HPP_ #endif // LightGBM_TREELEARNER_SPLIT_INFO_HPP_
...@@ -33,7 +33,7 @@ void VotingParallelTreeLearner<TREELEARNER_T>::Init(const Dataset* train_data, b ...@@ -33,7 +33,7 @@ void VotingParallelTreeLearner<TREELEARNER_T>::Init(const Dataset* train_data, b
} }
} }
// calculate buffer size // calculate buffer size
size_t buffer_size = 2 * top_k_ * std::max(max_bin * sizeof(HistogramBinEntry), sizeof(SplitInfo) * num_machines_); size_t buffer_size = 2 * top_k_ * std::max(max_bin * sizeof(HistogramBinEntry), sizeof(LightSplitInfo) * num_machines_);
// left and right on same time, so need double size // left and right on same time, so need double size
input_buffer_.resize(buffer_size); input_buffer_.resize(buffer_size);
output_buffer_.resize(buffer_size); output_buffer_.resize(buffer_size);
...@@ -162,14 +162,14 @@ bool VotingParallelTreeLearner<TREELEARNER_T>::BeforeFindBestSplit(const Tree* t ...@@ -162,14 +162,14 @@ bool VotingParallelTreeLearner<TREELEARNER_T>::BeforeFindBestSplit(const Tree* t
} }
template <typename TREELEARNER_T> template <typename TREELEARNER_T>
void VotingParallelTreeLearner<TREELEARNER_T>::GlobalVoting(int leaf_idx, const std::vector<SplitInfo>& splits, std::vector<int>* out) { void VotingParallelTreeLearner<TREELEARNER_T>::GlobalVoting(int leaf_idx, const std::vector<LightSplitInfo>& splits, std::vector<int>* out) {
out->clear(); out->clear();
if (leaf_idx < 0) { if (leaf_idx < 0) {
return; return;
} }
// get mean number on machines // get mean number on machines
score_t mean_num_data = GetGlobalDataCountInLeaf(leaf_idx) / static_cast<score_t>(num_machines_); score_t mean_num_data = GetGlobalDataCountInLeaf(leaf_idx) / static_cast<score_t>(num_machines_);
std::vector<SplitInfo> feature_best_split(this->num_features_, SplitInfo()); std::vector<LightSplitInfo> feature_best_split(this->num_features_, LightSplitInfo());
for (auto & split : splits) { for (auto & split : splits) {
int fid = split.feature; int fid = split.feature;
if (fid < 0) { if (fid < 0) {
...@@ -183,8 +183,8 @@ void VotingParallelTreeLearner<TREELEARNER_T>::GlobalVoting(int leaf_idx, const ...@@ -183,8 +183,8 @@ void VotingParallelTreeLearner<TREELEARNER_T>::GlobalVoting(int leaf_idx, const
} }
} }
// get top k // get top k
std::vector<SplitInfo> top_k_splits; std::vector<LightSplitInfo> top_k_splits;
ArrayArgs<SplitInfo>::MaxK(feature_best_split, top_k_, &top_k_splits); ArrayArgs<LightSplitInfo>::MaxK(feature_best_split, top_k_, &top_k_splits);
for (auto& split : top_k_splits) { for (auto& split : top_k_splits) {
if (split.gain == kMinScore || split.feature == -1) { if (split.gain == kMinScore || split.feature == -1) {
continue; continue;
...@@ -318,27 +318,35 @@ void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplits() { ...@@ -318,27 +318,35 @@ void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplits() {
// local voting // local voting
ArrayArgs<SplitInfo>::MaxK(smaller_bestsplit_per_features, top_k_, &smaller_top_k_splits); ArrayArgs<SplitInfo>::MaxK(smaller_bestsplit_per_features, top_k_, &smaller_top_k_splits);
ArrayArgs<SplitInfo>::MaxK(larger_bestsplit_per_features, top_k_, &larger_top_k_splits); ArrayArgs<SplitInfo>::MaxK(larger_bestsplit_per_features, top_k_, &larger_top_k_splits);
std::vector<LightSplitInfo> smaller_top_k_light_splits(top_k_);
std::vector<LightSplitInfo> larger_top_k_light_splits(top_k_);
for (int i = 0; i < top_k_; ++i) {
smaller_top_k_light_splits[i].CopyFrom(smaller_top_k_splits[i]);
larger_top_k_light_splits[i].CopyFrom(larger_top_k_splits[i]);
}
// gather // gather
int offset = 0; int offset = 0;
for (int i = 0; i < top_k_; ++i) { for (int i = 0; i < top_k_; ++i) {
std::memcpy(input_buffer_.data() + offset, &smaller_top_k_splits[i], sizeof(SplitInfo)); std::memcpy(input_buffer_.data() + offset, &smaller_top_k_light_splits[i], sizeof(LightSplitInfo));
offset += sizeof(SplitInfo); offset += sizeof(LightSplitInfo);
std::memcpy(input_buffer_.data() + offset, &larger_top_k_splits[i], sizeof(SplitInfo)); std::memcpy(input_buffer_.data() + offset, &larger_top_k_light_splits[i], sizeof(LightSplitInfo));
offset += sizeof(SplitInfo); offset += sizeof(LightSplitInfo);
} }
Network::Allgather(input_buffer_.data(), offset, output_buffer_.data()); Network::Allgather(input_buffer_.data(), offset, output_buffer_.data());
// get all top-k from all machines // get all top-k from all machines
std::vector<SplitInfo> smaller_top_k_splits_global; std::vector<LightSplitInfo> smaller_top_k_splits_global;
std::vector<SplitInfo> larger_top_k_splits_global; std::vector<LightSplitInfo> larger_top_k_splits_global;
offset = 0; offset = 0;
for (int i = 0; i < num_machines_; ++i) { for (int i = 0; i < num_machines_; ++i) {
for (int j = 0; j < top_k_; ++j) { for (int j = 0; j < top_k_; ++j) {
smaller_top_k_splits_global.push_back(SplitInfo()); smaller_top_k_splits_global.push_back(LightSplitInfo());
std::memcpy(&smaller_top_k_splits_global.back(), output_buffer_.data() + offset, sizeof(SplitInfo)); std::memcpy(&smaller_top_k_splits_global.back(), output_buffer_.data() + offset, sizeof(LightSplitInfo));
offset += sizeof(SplitInfo); offset += sizeof(LightSplitInfo);
larger_top_k_splits_global.push_back(SplitInfo()); larger_top_k_splits_global.push_back(LightSplitInfo());
std::memcpy(&larger_top_k_splits_global.back(), output_buffer_.data() + offset, sizeof(SplitInfo)); std::memcpy(&larger_top_k_splits_global.back(), output_buffer_.data() + offset, sizeof(LightSplitInfo));
offset += sizeof(SplitInfo); offset += sizeof(LightSplitInfo);
} }
} }
// global voting // global voting
...@@ -434,13 +442,7 @@ void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(cons ...@@ -434,13 +442,7 @@ void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(cons
larger_best_split = this->best_split_per_leaf_[this->larger_leaf_splits_->LeafIndex()]; larger_best_split = this->best_split_per_leaf_[this->larger_leaf_splits_->LeafIndex()];
} }
// sync global best info // sync global best info
std::memcpy(input_buffer_.data(), &smaller_best_split, sizeof(SplitInfo)); SyncUpGlobalBestSplit(input_buffer_.data(), input_buffer_.data(), &smaller_best_split, &larger_best_split);
std::memcpy(input_buffer_.data() + sizeof(SplitInfo), &larger_best_split, sizeof(SplitInfo));
Network::Allreduce(input_buffer_.data(), sizeof(SplitInfo) * 2, sizeof(SplitInfo), output_buffer_.data(), &SplitInfo::MaxReducer);
std::memcpy(&smaller_best_split, output_buffer_.data(), sizeof(SplitInfo));
std::memcpy(&larger_best_split, output_buffer_.data() + sizeof(SplitInfo), sizeof(SplitInfo));
// copy back // copy back
this->best_split_per_leaf_[smaller_leaf_splits_global_->LeafIndex()] = smaller_best_split; this->best_split_per_leaf_[smaller_leaf_splits_global_->LeafIndex()] = smaller_best_split;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment