"python-package/vscode:/vscode.git/clone" did not exist on "465d1262eb1d8eb3cfa7cc505140c035e52c8118"
Unverified commit 8ead7cc1, authored by Ilya Matiach and committed by GitHub
Browse files

memory corruption fix for distributed data parallel version before SyncUpGlobalBestSplit (#3110)

* memory corruption fix for distributed data parallel version before SyncUpGlobalBestSplit

* updated based on comments

* updated voting and feature parallel based on comments

* fixing macos failure

* rename variable
parent 51b84df8
...@@ -26,8 +26,14 @@ void DataParallelTreeLearner<TREELEARNER_T>::Init(const Dataset* train_data, boo ...@@ -26,8 +26,14 @@ void DataParallelTreeLearner<TREELEARNER_T>::Init(const Dataset* train_data, boo
// Get local rank and global machine size // Get local rank and global machine size
rank_ = Network::rank(); rank_ = Network::rank();
num_machines_ = Network::num_machines(); num_machines_ = Network::num_machines();
auto max_cat_threshold = this->config_->max_cat_threshold;
// need to be able to hold smaller and larger best splits in SyncUpGlobalBestSplit
size_t split_info_size = static_cast<size_t>(SplitInfo::Size(max_cat_threshold) * 2);
size_t histogram_size = static_cast<size_t>(this->train_data_->NumTotalBin() * kHistEntrySize);
// allocate buffer for communication // allocate buffer for communication
size_t buffer_size = this->train_data_->NumTotalBin() * kHistEntrySize; size_t buffer_size = std::max(histogram_size, split_info_size);
input_buffer_.resize(buffer_size); input_buffer_.resize(buffer_size);
output_buffer_.resize(buffer_size); output_buffer_.resize(buffer_size);
......
...@@ -24,8 +24,13 @@ void FeatureParallelTreeLearner<TREELEARNER_T>::Init(const Dataset* train_data, ...@@ -24,8 +24,13 @@ void FeatureParallelTreeLearner<TREELEARNER_T>::Init(const Dataset* train_data,
TREELEARNER_T::Init(train_data, is_constant_hessian); TREELEARNER_T::Init(train_data, is_constant_hessian);
rank_ = Network::rank(); rank_ = Network::rank();
num_machines_ = Network::num_machines(); num_machines_ = Network::num_machines();
input_buffer_.resize((sizeof(SplitInfo) + sizeof(uint32_t) * this->config_->max_cat_threshold) * 2);
output_buffer_.resize((sizeof(SplitInfo) + sizeof(uint32_t) * this->config_->max_cat_threshold) * 2); auto max_cat_threshold = this->config_->max_cat_threshold;
// need to be able to hold smaller and larger best splits in SyncUpGlobalBestSplit
int split_info_size = SplitInfo::Size(max_cat_threshold) * 2;
input_buffer_.resize(split_info_size);
output_buffer_.resize(split_info_size);
} }
......
...@@ -49,7 +49,7 @@ struct SplitInfo { ...@@ -49,7 +49,7 @@ struct SplitInfo {
bool default_left = true; bool default_left = true;
int8_t monotone_type = 0; int8_t monotone_type = 0;
inline static int Size(int max_cat_threshold) { inline static int Size(int max_cat_threshold) {
return 2 * sizeof(int) + sizeof(uint32_t) + sizeof(bool) + sizeof(double) * 9 + sizeof(data_size_t) * 2 + max_cat_threshold * sizeof(uint32_t) + sizeof(int8_t); return 2 * sizeof(int) + sizeof(uint32_t) + sizeof(bool) + sizeof(double) * 7 + sizeof(data_size_t) * 2 + max_cat_threshold * sizeof(uint32_t) + sizeof(int8_t);
} }
inline void CopyTo(char* buffer) const { inline void CopyTo(char* buffer) const {
......
...@@ -37,6 +37,10 @@ void VotingParallelTreeLearner<TREELEARNER_T>::Init(const Dataset* train_data, b ...@@ -37,6 +37,10 @@ void VotingParallelTreeLearner<TREELEARNER_T>::Init(const Dataset* train_data, b
} }
// calculate buffer size // calculate buffer size
size_t buffer_size = 2 * top_k_ * std::max(max_bin * kHistEntrySize, sizeof(LightSplitInfo) * num_machines_); size_t buffer_size = 2 * top_k_ * std::max(max_bin * kHistEntrySize, sizeof(LightSplitInfo) * num_machines_);
auto max_cat_threshold = this->config_->max_cat_threshold;
// need to be able to hold smaller and larger best splits in SyncUpGlobalBestSplit
size_t split_info_size = static_cast<size_t>(SplitInfo::Size(max_cat_threshold) * 2);
buffer_size = std::max(buffer_size, split_info_size);
// left and right on same time, so need double size // left and right on same time, so need double size
input_buffer_.resize(buffer_size); input_buffer_.resize(buffer_size);
output_buffer_.resize(buffer_size); output_buffer_.resize(buffer_size);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment