Unverified commit 8ead7cc1, authored by Ilya Matiach and committed by GitHub
Browse files

memory corruption fix for distributed data parallel version before SyncUpGlobalBestSplit (#3110)

* memory corruption fix for distributed data parallel version before SyncUpGlobalBestSplit

* updated based on comments

* updated voting and feature parallel based on comments

* fixing macos failure

* rename variable
parent 51b84df8
......@@ -26,8 +26,14 @@ void DataParallelTreeLearner<TREELEARNER_T>::Init(const Dataset* train_data, boo
// Get local rank and global machine size
rank_ = Network::rank();
num_machines_ = Network::num_machines();
auto max_cat_threshold = this->config_->max_cat_threshold;
// need to be able to hold smaller and larger best splits in SyncUpGlobalBestSplit
size_t split_info_size = static_cast<size_t>(SplitInfo::Size(max_cat_threshold) * 2);
size_t histogram_size = static_cast<size_t>(this->train_data_->NumTotalBin() * kHistEntrySize);
// allocate buffer for communication
size_t buffer_size = this->train_data_->NumTotalBin() * kHistEntrySize;
size_t buffer_size = std::max(histogram_size, split_info_size);
input_buffer_.resize(buffer_size);
output_buffer_.resize(buffer_size);
......
......@@ -24,8 +24,13 @@ void FeatureParallelTreeLearner<TREELEARNER_T>::Init(const Dataset* train_data,
TREELEARNER_T::Init(train_data, is_constant_hessian);
rank_ = Network::rank();
num_machines_ = Network::num_machines();
input_buffer_.resize((sizeof(SplitInfo) + sizeof(uint32_t) * this->config_->max_cat_threshold) * 2);
output_buffer_.resize((sizeof(SplitInfo) + sizeof(uint32_t) * this->config_->max_cat_threshold) * 2);
auto max_cat_threshold = this->config_->max_cat_threshold;
// need to be able to hold smaller and larger best splits in SyncUpGlobalBestSplit
int split_info_size = SplitInfo::Size(max_cat_threshold) * 2;
input_buffer_.resize(split_info_size);
output_buffer_.resize(split_info_size);
}
......
......@@ -49,7 +49,7 @@ struct SplitInfo {
bool default_left = true;
int8_t monotone_type = 0;
// Returns a byte count assembled from sizeof() terms for fixed-size fields plus
// max_cat_threshold uint32_t slots. The Init hunks elsewhere in this diff double
// the result to size the input/output communication buffers.
// The sizeof() sum is unsigned (size_t) and is implicitly narrowed to int on return.
// NOTE(review): this hunk is shown without +/- markers, so two return statements
// appear; presumably only one of them exists in the actual file -- confirm
// against the repository before relying on either value.
inline static int Size(int max_cat_threshold) {
// Presumably the pre-change line (diffs list removed lines first): counts 9 doubles.
return 2 * sizeof(int) + sizeof(uint32_t) + sizeof(bool) + sizeof(double) * 9 + sizeof(data_size_t) * 2 + max_cat_threshold * sizeof(uint32_t) + sizeof(int8_t);
// Presumably the post-change line: identical except that it counts 7 doubles.
return 2 * sizeof(int) + sizeof(uint32_t) + sizeof(bool) + sizeof(double) * 7 + sizeof(data_size_t) * 2 + max_cat_threshold * sizeof(uint32_t) + sizeof(int8_t);
}
inline void CopyTo(char* buffer) const {
......
......@@ -37,6 +37,10 @@ void VotingParallelTreeLearner<TREELEARNER_T>::Init(const Dataset* train_data, b
}
// calculate buffer size
size_t buffer_size = 2 * top_k_ * std::max(max_bin * kHistEntrySize, sizeof(LightSplitInfo) * num_machines_);
auto max_cat_threshold = this->config_->max_cat_threshold;
// need to be able to hold smaller and larger best splits in SyncUpGlobalBestSplit
size_t split_info_size = static_cast<size_t>(SplitInfo::Size(max_cat_threshold) * 2);
buffer_size = std::max(buffer_size, split_info_size);
// left and right on same time, so need double size
input_buffer_.resize(buffer_size);
output_buffer_.resize(buffer_size);
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment