Unverified Commit da91c613 authored by Guolin Ke's avatar Guolin Ke Committed by GitHub
Browse files

fix bug in parallel learning (#2851)

* refix

* fix config

* avoid to rely on config
parent 9c386db1
...@@ -908,7 +908,7 @@ struct Config { ...@@ -908,7 +908,7 @@ struct Config {
size_t file_load_progress_interval_bytes = size_t(10) * 1024 * 1024 * 1024; size_t file_load_progress_interval_bytes = size_t(10) * 1024 * 1024 * 1024;
bool is_parallel = false; bool is_parallel = false;
bool is_parallel_find_bin = false; bool is_data_based_parallel = false;
LIGHTGBM_EXPORT void Set(const std::unordered_map<std::string, std::string>& params); LIGHTGBM_EXPORT void Set(const std::unordered_map<std::string, std::string>& params);
static const std::unordered_map<std::string, std::string>& alias_table(); static const std::unordered_map<std::string, std::string>& alias_table();
static const std::unordered_set<std::string>& parameter_set(); static const std::unordered_set<std::string>& parameter_set();
......
...@@ -93,7 +93,7 @@ void Application::LoadData() { ...@@ -93,7 +93,7 @@ void Application::LoadData() {
} }
// sync up random seed for data partition // sync up random seed for data partition
if (config_.is_parallel_find_bin) { if (config_.is_data_based_parallel) {
config_.data_random_seed = Network::GlobalSyncUpByMin(config_.data_random_seed); config_.data_random_seed = Network::GlobalSyncUpByMin(config_.data_random_seed);
} }
...@@ -101,7 +101,7 @@ void Application::LoadData() { ...@@ -101,7 +101,7 @@ void Application::LoadData() {
DatasetLoader dataset_loader(config_, predict_fun, DatasetLoader dataset_loader(config_, predict_fun,
config_.num_class, config_.data.c_str()); config_.num_class, config_.data.c_str());
// load Training data // load Training data
if (config_.is_parallel_find_bin) { if (config_.is_data_based_parallel) {
// load data for parallel training // load data for parallel training
train_data_.reset(dataset_loader.LoadFromFile(config_.data.c_str(), train_data_.reset(dataset_loader.LoadFromFile(config_.data.c_str(),
Network::rank(), Network::num_machines())); Network::rank(), Network::num_machines()));
......
...@@ -280,10 +280,10 @@ void Config::CheckParamConflict() { ...@@ -280,10 +280,10 @@ void Config::CheckParamConflict() {
} }
if (is_single_tree_learner || tree_learner == std::string("feature")) { if (is_single_tree_learner || tree_learner == std::string("feature")) {
is_parallel_find_bin = false; is_data_based_parallel = false;
} else if (tree_learner == std::string("data") } else if (tree_learner == std::string("data")
|| tree_learner == std::string("voting")) { || tree_learner == std::string("voting")) {
is_parallel_find_bin = true; is_data_based_parallel = true;
if (histogram_pool_size >= 0 if (histogram_pool_size >= 0
&& tree_learner == std::string("data")) { && tree_learner == std::string("data")) {
Log::Warning("Histogram LRU queue was enabled (histogram_pool_size=%f).\n" Log::Warning("Histogram LRU queue was enabled (histogram_pool_size=%f).\n"
......
...@@ -241,7 +241,7 @@ void DataParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(const ...@@ -241,7 +241,7 @@ void DataParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(const
template <typename TREELEARNER_T> template <typename TREELEARNER_T>
void DataParallelTreeLearner<TREELEARNER_T>::Split(Tree* tree, int best_Leaf, int* left_leaf, int* right_leaf) { void DataParallelTreeLearner<TREELEARNER_T>::Split(Tree* tree, int best_Leaf, int* left_leaf, int* right_leaf) {
TREELEARNER_T::Split(tree, best_Leaf, left_leaf, right_leaf); this->SplitInner(tree, best_Leaf, left_leaf, right_leaf, false);
const SplitInfo& best_split_info = this->best_split_per_leaf_[best_Leaf]; const SplitInfo& best_split_info = this->best_split_per_leaf_[best_Leaf];
// need update global number of data in leaf // need update global number of data in leaf
global_data_count_in_leaf_[*left_leaf] = best_split_info.left_count; global_data_count_in_leaf_[*left_leaf] = best_split_info.left_count;
......
...@@ -648,89 +648,111 @@ int32_t SerialTreeLearner::ForceSplits(Tree* tree, const Json& forced_split_json ...@@ -648,89 +648,111 @@ int32_t SerialTreeLearner::ForceSplits(Tree* tree, const Json& forced_split_json
return result_count; return result_count;
} }
void SerialTreeLearner::Split(Tree* tree, int best_leaf, int* left_leaf, int* right_leaf) { void SerialTreeLearner::SplitInner(Tree* tree, int best_leaf, int* left_leaf,
Common::FunctionTimer fun_timer("SerialTreeLearner::Split", global_timer); int* right_leaf, bool update_cnt) {
Common::FunctionTimer fun_timer("SerialTreeLearner::SplitInner", global_timer);
SplitInfo& best_split_info = best_split_per_leaf_[best_leaf]; SplitInfo& best_split_info = best_split_per_leaf_[best_leaf];
const int inner_feature_index = train_data_->InnerFeatureIndex(best_split_info.feature); const int inner_feature_index =
train_data_->InnerFeatureIndex(best_split_info.feature);
if (cegb_ != nullptr) { if (cegb_ != nullptr) {
cegb_->UpdateLeafBestSplits(tree, best_leaf, &best_split_info, &best_split_per_leaf_); cegb_->UpdateLeafBestSplits(tree, best_leaf, &best_split_info,
&best_split_per_leaf_);
} }
*left_leaf = best_leaf; *left_leaf = best_leaf;
auto next_leaf_id = tree->NextLeafId(); auto next_leaf_id = tree->NextLeafId();
bool is_numerical_split = train_data_->FeatureBinMapper(inner_feature_index)->bin_type() == BinType::NumericalBin; bool is_numerical_split =
train_data_->FeatureBinMapper(inner_feature_index)->bin_type() ==
BinType::NumericalBin;
if (is_numerical_split) { if (is_numerical_split) {
auto threshold_double = train_data_->RealThreshold(inner_feature_index, best_split_info.threshold); auto threshold_double = train_data_->RealThreshold(
inner_feature_index, best_split_info.threshold);
data_partition_->Split(best_leaf, train_data_, inner_feature_index, data_partition_->Split(best_leaf, train_data_, inner_feature_index,
&best_split_info.threshold, 1, best_split_info.default_left, next_leaf_id); &best_split_info.threshold, 1,
best_split_info.left_count = data_partition_->leaf_count(*left_leaf); best_split_info.default_left, next_leaf_id);
best_split_info.right_count = data_partition_->leaf_count(next_leaf_id); if (update_cnt) {
// don't need to update this in data-based parallel model
best_split_info.left_count = data_partition_->leaf_count(*left_leaf);
best_split_info.right_count = data_partition_->leaf_count(next_leaf_id);
}
// split tree, will return right leaf // split tree, will return right leaf
*right_leaf = tree->Split(best_leaf, *right_leaf = tree->Split(
inner_feature_index, best_leaf, inner_feature_index, best_split_info.feature,
best_split_info.feature, best_split_info.threshold, threshold_double,
best_split_info.threshold, static_cast<double>(best_split_info.left_output),
threshold_double, static_cast<double>(best_split_info.right_output),
static_cast<double>(best_split_info.left_output), static_cast<data_size_t>(best_split_info.left_count),
static_cast<double>(best_split_info.right_output), static_cast<data_size_t>(best_split_info.right_count),
static_cast<data_size_t>(best_split_info.left_count), static_cast<double>(best_split_info.left_sum_hessian),
static_cast<data_size_t>(best_split_info.right_count), static_cast<double>(best_split_info.right_sum_hessian),
static_cast<double>(best_split_info.left_sum_hessian), static_cast<float>(best_split_info.gain),
static_cast<double>(best_split_info.right_sum_hessian), train_data_->FeatureBinMapper(inner_feature_index)->missing_type(),
static_cast<float>(best_split_info.gain), best_split_info.default_left);
train_data_->FeatureBinMapper(inner_feature_index)->missing_type(),
best_split_info.default_left);
} else { } else {
std::vector<uint32_t> cat_bitset_inner = Common::ConstructBitset(best_split_info.cat_threshold.data(), best_split_info.num_cat_threshold); std::vector<uint32_t> cat_bitset_inner =
Common::ConstructBitset(best_split_info.cat_threshold.data(),
best_split_info.num_cat_threshold);
std::vector<int> threshold_int(best_split_info.num_cat_threshold); std::vector<int> threshold_int(best_split_info.num_cat_threshold);
for (int i = 0; i < best_split_info.num_cat_threshold; ++i) { for (int i = 0; i < best_split_info.num_cat_threshold; ++i) {
threshold_int[i] = static_cast<int>(train_data_->RealThreshold(inner_feature_index, best_split_info.cat_threshold[i])); threshold_int[i] = static_cast<int>(train_data_->RealThreshold(
inner_feature_index, best_split_info.cat_threshold[i]));
} }
std::vector<uint32_t> cat_bitset = Common::ConstructBitset(threshold_int.data(), best_split_info.num_cat_threshold); std::vector<uint32_t> cat_bitset = Common::ConstructBitset(
threshold_int.data(), best_split_info.num_cat_threshold);
data_partition_->Split(best_leaf, train_data_, inner_feature_index, data_partition_->Split(best_leaf, train_data_, inner_feature_index,
cat_bitset_inner.data(), static_cast<int>(cat_bitset_inner.size()), best_split_info.default_left, next_leaf_id); cat_bitset_inner.data(),
static_cast<int>(cat_bitset_inner.size()),
best_split_info.left_count = data_partition_->leaf_count(*left_leaf); best_split_info.default_left, next_leaf_id);
best_split_info.right_count = data_partition_->leaf_count(next_leaf_id);
if (update_cnt) {
*right_leaf = tree->SplitCategorical(best_leaf, // don't need to update this in data-based parallel model
inner_feature_index, best_split_info.left_count = data_partition_->leaf_count(*left_leaf);
best_split_info.feature, best_split_info.right_count = data_partition_->leaf_count(next_leaf_id);
cat_bitset_inner.data(), }
static_cast<int>(cat_bitset_inner.size()),
cat_bitset.data(), *right_leaf = tree->SplitCategorical(
static_cast<int>(cat_bitset.size()), best_leaf, inner_feature_index, best_split_info.feature,
static_cast<double>(best_split_info.left_output), cat_bitset_inner.data(), static_cast<int>(cat_bitset_inner.size()),
static_cast<double>(best_split_info.right_output), cat_bitset.data(), static_cast<int>(cat_bitset.size()),
static_cast<data_size_t>(best_split_info.left_count), static_cast<double>(best_split_info.left_output),
static_cast<data_size_t>(best_split_info.right_count), static_cast<double>(best_split_info.right_output),
static_cast<double>(best_split_info.left_sum_hessian), static_cast<data_size_t>(best_split_info.left_count),
static_cast<double>(best_split_info.right_sum_hessian), static_cast<data_size_t>(best_split_info.right_count),
static_cast<float>(best_split_info.gain), static_cast<double>(best_split_info.left_sum_hessian),
train_data_->FeatureBinMapper(inner_feature_index)->missing_type()); static_cast<double>(best_split_info.right_sum_hessian),
} static_cast<float>(best_split_info.gain),
train_data_->FeatureBinMapper(inner_feature_index)->missing_type());
#ifdef DEBUG }
#ifdef DEBUG
CHECK(*right_leaf == next_leaf_id); CHECK(*right_leaf == next_leaf_id);
#endif #endif
// init the leaves that used on next iteration // init the leaves that used on next iteration
if (best_split_info.left_count < best_split_info.right_count) { if (best_split_info.left_count < best_split_info.right_count) {
CHECK_GT(best_split_info.left_count, 0); CHECK_GT(best_split_info.left_count, 0);
smaller_leaf_splits_->Init(*left_leaf, data_partition_.get(), best_split_info.left_sum_gradient, best_split_info.left_sum_hessian); smaller_leaf_splits_->Init(*left_leaf, data_partition_.get(),
larger_leaf_splits_->Init(*right_leaf, data_partition_.get(), best_split_info.right_sum_gradient, best_split_info.right_sum_hessian); best_split_info.left_sum_gradient,
best_split_info.left_sum_hessian);
larger_leaf_splits_->Init(*right_leaf, data_partition_.get(),
best_split_info.right_sum_gradient,
best_split_info.right_sum_hessian);
} else { } else {
CHECK_GT(best_split_info.right_count, 0); CHECK_GT(best_split_info.right_count, 0);
smaller_leaf_splits_->Init(*right_leaf, data_partition_.get(), best_split_info.right_sum_gradient, best_split_info.right_sum_hessian); smaller_leaf_splits_->Init(*right_leaf, data_partition_.get(),
larger_leaf_splits_->Init(*left_leaf, data_partition_.get(), best_split_info.left_sum_gradient, best_split_info.left_sum_hessian); best_split_info.right_sum_gradient,
} best_split_info.right_sum_hessian);
constraints_->UpdateConstraints( larger_leaf_splits_->Init(*left_leaf, data_partition_.get(),
is_numerical_split, *left_leaf, *right_leaf, best_split_info.left_sum_gradient,
best_split_info.monotone_type, best_split_info.right_output, best_split_info.left_sum_hessian);
best_split_info.left_output); }
} constraints_->UpdateConstraints(is_numerical_split, *left_leaf, *right_leaf,
best_split_info.monotone_type,
best_split_info.right_output,
best_split_info.left_output);
}
void SerialTreeLearner::RenewTreeOutput(Tree* tree, const ObjectiveFunction* obj, std::function<double(const label_t*, int)> residual_getter, void SerialTreeLearner::RenewTreeOutput(Tree* tree, const ObjectiveFunction* obj, std::function<double(const label_t*, int)> residual_getter,
data_size_t total_num_data, const data_size_t* bag_indices, data_size_t bag_cnt) const { data_size_t total_num_data, const data_size_t* bag_indices, data_size_t bag_cnt) const {
......
...@@ -137,7 +137,12 @@ class SerialTreeLearner: public TreeLearner { ...@@ -137,7 +137,12 @@ class SerialTreeLearner: public TreeLearner {
* \param left_leaf The index of left leaf after splitted. * \param left_leaf The index of left leaf after splitted.
* \param right_leaf The index of right leaf after splitted. * \param right_leaf The index of right leaf after splitted.
*/ */
virtual void Split(Tree* tree, int best_leaf, int* left_leaf, int* right_leaf); inline virtual void Split(Tree* tree, int best_leaf, int* left_leaf,
int* right_leaf) {
SplitInner(tree, best_leaf, left_leaf, right_leaf, true);
}
void SplitInner(Tree* tree, int best_leaf, int* left_leaf, int* right_leaf, bool update_cnt);
/* Force splits with forced_split_json dict and then return num splits forced.*/ /* Force splits with forced_split_json dict and then return num splits forced.*/
virtual int32_t ForceSplits(Tree* tree, const Json& forced_split_json, int* left_leaf, virtual int32_t ForceSplits(Tree* tree, const Json& forced_split_json, int* left_leaf,
......
...@@ -429,7 +429,7 @@ void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(cons ...@@ -429,7 +429,7 @@ void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(cons
template <typename TREELEARNER_T> template <typename TREELEARNER_T>
void VotingParallelTreeLearner<TREELEARNER_T>::Split(Tree* tree, int best_Leaf, int* left_leaf, int* right_leaf) { void VotingParallelTreeLearner<TREELEARNER_T>::Split(Tree* tree, int best_Leaf, int* left_leaf, int* right_leaf) {
TREELEARNER_T::Split(tree, best_Leaf, left_leaf, right_leaf); this->SplitInner(tree, best_Leaf, left_leaf, right_leaf, false);
const SplitInfo& best_split_info = this->best_split_per_leaf_[best_Leaf]; const SplitInfo& best_split_info = this->best_split_per_leaf_[best_Leaf];
// set the global number of data for leaves // set the global number of data for leaves
global_data_count_in_leaf_[*left_leaf] = best_split_info.left_count; global_data_count_in_leaf_[*left_leaf] = best_split_info.left_count;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment