Unverified Commit 70fc45b0 authored by Guolin Ke's avatar Guolin Ke Committed by GitHub
Browse files

code refactoring: cost effective gradient boosting (#2407)

* refactoring

* fix style

* fix style

* Update cost_effective_gradient_boosting.hpp

* Update serial_tree_learner.cpp

* Update serial_tree_learner.h

* fix style

* update vc project

* Update cost_effective_gradient_boosting.hpp
parent e771538f
/*!
* Copyright (c) 2019 Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See LICENSE file in the project root for license information.
*/
#ifndef LIGHTGBM_TREELEARNER_COST_EFFECTIVE_GRADIENT_BOOSTING_HPP_
#define LIGHTGBM_TREELEARNER_COST_EFFECTIVE_GRADIENT_BOOSTING_HPP_
#include <LightGBM/config.h>
#include <LightGBM/dataset.h>
#include <LightGBM/utils/common.h>
#include <LightGBM/utils/log.h>

#include <utility>
#include <vector>

#include "data_partition.hpp"
#include "serial_tree_learner.h"
namespace LightGBM {
class CostEfficientGradientBoosting {
public:
explicit CostEfficientGradientBoosting(const SerialTreeLearner* tree_learner):tree_learner_(tree_learner) {
}
static bool IsEnable(const Config* config) {
if (config->cegb_tradeoff >= 1.0f && config->cegb_penalty_split <= 0.0f
&& config->cegb_penalty_feature_coupled.empty() && config->cegb_penalty_feature_lazy.empty()) {
return false;
} else {
return true;
}
}
void Init() {
auto train_data = tree_learner_->train_data_;
splits_per_leaf_.resize(static_cast<size_t>(tree_learner_->config_->num_leaves) * train_data->num_features());
is_feature_used_in_split_.clear();
is_feature_used_in_split_.resize(train_data->num_features());
if (!tree_learner_->config_->cegb_penalty_feature_coupled.empty()
&& tree_learner_->config_->cegb_penalty_feature_coupled.size() != static_cast<size_t>(train_data->num_total_features())) {
Log::Fatal("cegb_penalty_feature_coupled should be the same size as feature number.");
}
if (!tree_learner_->config_->cegb_penalty_feature_lazy.empty()) {
if (tree_learner_->config_->cegb_penalty_feature_lazy.size() != static_cast<size_t>(train_data->num_total_features())) {
Log::Fatal("cegb_penalty_feature_lazy should be the same size as feature number.");
}
feature_used_in_data_ = Common::EmptyBitset(train_data->num_features() * tree_learner_->num_data_);
}
}
double DetlaGain(int feature_index, int real_fidx, int leaf_index, int num_data_in_leaf, SplitInfo split_info) {
auto config = tree_learner_->config_;
double delta = config->cegb_tradeoff * config->cegb_penalty_split * num_data_in_leaf;
if (!config->cegb_penalty_feature_coupled.empty() && !is_feature_used_in_split_[feature_index]) {
delta += config->cegb_tradeoff * config->cegb_penalty_feature_coupled[real_fidx];
}
if (!config->cegb_penalty_feature_lazy.empty()) {
delta += config->cegb_tradeoff * CalculateOndemandCosts(feature_index, real_fidx, leaf_index);
}
splits_per_leaf_[static_cast<size_t>(leaf_index) * tree_learner_->train_data_->num_features() + feature_index] = split_info;
return delta;
}
void UpdateLeafBestSplits(Tree* tree, int best_leaf, const SplitInfo* best_split_info, std::vector<SplitInfo>* best_split_per_leaf) {
auto config = tree_learner_->config_;
auto train_data = tree_learner_->train_data_;
const int inner_feature_index = train_data->InnerFeatureIndex(best_split_info->feature);
if (!config->cegb_penalty_feature_coupled.empty() && !is_feature_used_in_split_[inner_feature_index]) {
is_feature_used_in_split_[inner_feature_index] = true;
for (int i = 0; i < tree->num_leaves(); ++i) {
if (i == best_leaf) continue;
auto split = &splits_per_leaf_[static_cast<size_t>(i) * train_data->num_features() + inner_feature_index];
split->gain += config->cegb_tradeoff * config->cegb_penalty_feature_coupled[best_split_info->feature];
if (*split > best_split_per_leaf->at(i))
best_split_per_leaf->at(i) = *split;
}
}
if (!config->cegb_penalty_feature_lazy.empty()) {
data_size_t cnt_leaf_data = 0;
auto tmp_idx = tree_learner_->data_partition_->GetIndexOnLeaf(best_leaf, &cnt_leaf_data);
for (data_size_t i_input = 0; i_input < cnt_leaf_data; ++i_input) {
int real_idx = tmp_idx[i_input];
Common::InsertBitset(&feature_used_in_data_, train_data->num_data() * inner_feature_index + real_idx);
}
}
}
private:
double CalculateOndemandCosts(int feature_index, int real_fidx, int leaf_index) const {
if (tree_learner_->config_->cegb_penalty_feature_lazy.empty()) {
return 0.0f;
}
auto train_data = tree_learner_->train_data_;
double penalty = tree_learner_->config_->cegb_penalty_feature_lazy[real_fidx];
double total = 0.0f;
data_size_t cnt_leaf_data = 0;
auto tmp_idx = tree_learner_->data_partition_->GetIndexOnLeaf(leaf_index, &cnt_leaf_data);
for (data_size_t i_input = 0; i_input < cnt_leaf_data; ++i_input) {
int real_idx = tmp_idx[i_input];
if (Common::FindInBitset(feature_used_in_data_.data(), train_data->num_data() * train_data->num_features(), train_data->num_data() * feature_index + real_idx)) {
continue;
}
total += penalty;
}
return total;
}
const SerialTreeLearner* tree_learner_;
std::vector<SplitInfo> splits_per_leaf_;
std::vector<bool> is_feature_used_in_split_;
std::vector<uint32_t> feature_used_in_data_;
};
} // namespace LightGBM
#endif // LIGHTGBM_TREELEARNER_COST_EFFECTIVE_GRADIENT_BOOSTING_HPP_
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
#include <unordered_map> #include <unordered_map>
#include <utility> #include <utility>
#include "cost_effective_gradient_boosting.hpp"
namespace LightGBM { namespace LightGBM {
#ifdef TIMETAG #ifdef TIMETAG
...@@ -69,8 +71,6 @@ void SerialTreeLearner::Init(const Dataset* train_data, bool is_constant_hessian ...@@ -69,8 +71,6 @@ void SerialTreeLearner::Init(const Dataset* train_data, bool is_constant_hessian
histogram_pool_.DynamicChangeSize(train_data_, config_, max_cache_size, config_->num_leaves); histogram_pool_.DynamicChangeSize(train_data_, config_, max_cache_size, config_->num_leaves);
// push split information for all leaves // push split information for all leaves
best_split_per_leaf_.resize(config_->num_leaves); best_split_per_leaf_.resize(config_->num_leaves);
splits_per_leaf_.resize(config_->num_leaves*train_data_->num_features());
// get ordered bin // get ordered bin
train_data_->CreateOrderedBins(&ordered_bins_); train_data_->CreateOrderedBins(&ordered_bins_);
...@@ -104,15 +104,9 @@ void SerialTreeLearner::Init(const Dataset* train_data, bool is_constant_hessian ...@@ -104,15 +104,9 @@ void SerialTreeLearner::Init(const Dataset* train_data, bool is_constant_hessian
} }
} }
Log::Info("Number of data: %d, number of used features: %d", num_data_, num_features_); Log::Info("Number of data: %d, number of used features: %d", num_data_, num_features_);
is_feature_used_in_split_.clear(); if (CostEfficientGradientBoosting::IsEnable(config_)) {
is_feature_used_in_split_.resize(train_data->num_features()); cegb_.reset(new CostEfficientGradientBoosting(this));
cegb_->Init();
if (!config_->cegb_penalty_feature_coupled.empty()) {
CHECK(config_->cegb_penalty_feature_coupled.size() == static_cast<size_t>(train_data_->num_total_features()));
}
if (!config_->cegb_penalty_feature_lazy.empty()) {
CHECK(config_->cegb_penalty_feature_lazy.size() == static_cast<size_t>(train_data_->num_total_features()));
feature_used_in_data = Common::EmptyBitset(train_data->num_features() * num_data_);
} }
} }
...@@ -139,6 +133,9 @@ void SerialTreeLearner::ResetTrainingData(const Dataset* train_data) { ...@@ -139,6 +133,9 @@ void SerialTreeLearner::ResetTrainingData(const Dataset* train_data) {
is_data_in_leaf_.resize(num_data_); is_data_in_leaf_.resize(num_data_);
std::fill(is_data_in_leaf_.begin(), is_data_in_leaf_.end(), static_cast<char>(0)); std::fill(is_data_in_leaf_.begin(), is_data_in_leaf_.end(), static_cast<char>(0));
} }
if (cegb_ != nullptr) {
cegb_->Init();
}
} }
void SerialTreeLearner::ResetConfig(const Config* config) { void SerialTreeLearner::ResetConfig(const Config* config) {
...@@ -166,8 +163,11 @@ void SerialTreeLearner::ResetConfig(const Config* config) { ...@@ -166,8 +163,11 @@ void SerialTreeLearner::ResetConfig(const Config* config) {
} else { } else {
config_ = config; config_ = config;
} }
histogram_pool_.ResetConfig(config_); histogram_pool_.ResetConfig(config_);
if (CostEfficientGradientBoosting::IsEnable(config_)) {
cegb_.reset(new CostEfficientGradientBoosting(this));
cegb_->Init();
}
} }
Tree* SerialTreeLearner::Train(const score_t* gradients, const score_t *hessians, bool is_constant_hessian, const Json& forced_split_json) { Tree* SerialTreeLearner::Train(const score_t* gradients, const score_t *hessians, bool is_constant_hessian, const Json& forced_split_json) {
...@@ -521,28 +521,6 @@ void SerialTreeLearner::ConstructHistograms(const std::vector<int8_t>& is_featur ...@@ -521,28 +521,6 @@ void SerialTreeLearner::ConstructHistograms(const std::vector<int8_t>& is_featur
#endif #endif
} }
double SerialTreeLearner::CalculateOndemandCosts(int feature_index, int leaf_index) {
if (config_->cegb_penalty_feature_lazy.empty())
return 0.0f;
double penalty = config_->cegb_penalty_feature_lazy[feature_index];
const int inner_fidx = train_data_->InnerFeatureIndex(feature_index);
double total = 0.0f;
data_size_t cnt_leaf_data = 0;
auto tmp_idx = data_partition_->GetIndexOnLeaf(leaf_index, &cnt_leaf_data);
for (data_size_t i_input = 0; i_input < cnt_leaf_data; ++i_input) {
int real_idx = tmp_idx[i_input];
if (Common::FindInBitset(feature_used_in_data.data(), train_data_->num_data()*train_data_->num_features(), train_data_->num_data() * inner_fidx + real_idx))
continue;
total += penalty;
}
return total;
}
void SerialTreeLearner::FindBestSplitsFromHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract) { void SerialTreeLearner::FindBestSplitsFromHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract) {
#ifdef TIMETAG #ifdef TIMETAG
auto start_time = std::chrono::steady_clock::now(); auto start_time = std::chrono::steady_clock::now();
...@@ -576,14 +554,9 @@ void SerialTreeLearner::FindBestSplitsFromHistograms(const std::vector<int8_t>& ...@@ -576,14 +554,9 @@ void SerialTreeLearner::FindBestSplitsFromHistograms(const std::vector<int8_t>&
smaller_leaf_splits_->max_constraint(), smaller_leaf_splits_->max_constraint(),
&smaller_split); &smaller_split);
smaller_split.feature = real_fidx; smaller_split.feature = real_fidx;
smaller_split.gain -= config_->cegb_tradeoff * config_->cegb_penalty_split * smaller_leaf_splits_->num_data_in_leaf(); if (cegb_ != nullptr) {
if (!config_->cegb_penalty_feature_coupled.empty() && !is_feature_used_in_split_[feature_index]) { smaller_split.gain -= cegb_->DetlaGain(feature_index, real_fidx, smaller_leaf_splits_->LeafIndex(), smaller_leaf_splits_->num_data_in_leaf(), smaller_split);
smaller_split.gain -= config_->cegb_tradeoff * config_->cegb_penalty_feature_coupled[real_fidx];
} }
if (!config_->cegb_penalty_feature_lazy.empty()) {
smaller_split.gain -= config_->cegb_tradeoff * CalculateOndemandCosts(real_fidx, smaller_leaf_splits_->LeafIndex());
}
splits_per_leaf_[smaller_leaf_splits_->LeafIndex()*train_data_->num_features() + feature_index] = smaller_split;
if (smaller_split > smaller_best[tid] && smaller_node_used_features[feature_index]) { if (smaller_split > smaller_best[tid] && smaller_node_used_features[feature_index]) {
smaller_best[tid] = smaller_split; smaller_best[tid] = smaller_split;
} }
...@@ -607,14 +580,9 @@ void SerialTreeLearner::FindBestSplitsFromHistograms(const std::vector<int8_t>& ...@@ -607,14 +580,9 @@ void SerialTreeLearner::FindBestSplitsFromHistograms(const std::vector<int8_t>&
larger_leaf_splits_->max_constraint(), larger_leaf_splits_->max_constraint(),
&larger_split); &larger_split);
larger_split.feature = real_fidx; larger_split.feature = real_fidx;
larger_split.gain -= config_->cegb_tradeoff * config_->cegb_penalty_split * larger_leaf_splits_->num_data_in_leaf(); if (cegb_ != nullptr) {
if (!config_->cegb_penalty_feature_coupled.empty() && !is_feature_used_in_split_[feature_index]) { larger_split.gain -= cegb_->DetlaGain(feature_index, real_fidx, larger_leaf_splits_->LeafIndex(), larger_leaf_splits_->num_data_in_leaf(), larger_split);
larger_split.gain -= config_->cegb_tradeoff * config_->cegb_penalty_feature_coupled[real_fidx];
}
if (!config_->cegb_penalty_feature_lazy.empty()) {
larger_split.gain -= config_->cegb_tradeoff*CalculateOndemandCosts(real_fidx, larger_leaf_splits_->LeafIndex());
} }
splits_per_leaf_[larger_leaf_splits_->LeafIndex()*train_data_->num_features() + feature_index] = larger_split;
if (larger_split > larger_best[tid] && larger_node_used_features[feature_index]) { if (larger_split > larger_best[tid] && larger_node_used_features[feature_index]) {
larger_best[tid] = larger_split; larger_best[tid] = larger_split;
} }
...@@ -803,26 +771,9 @@ int32_t SerialTreeLearner::ForceSplits(Tree* tree, const Json& forced_split_json ...@@ -803,26 +771,9 @@ int32_t SerialTreeLearner::ForceSplits(Tree* tree, const Json& forced_split_json
void SerialTreeLearner::Split(Tree* tree, int best_leaf, int* left_leaf, int* right_leaf) { void SerialTreeLearner::Split(Tree* tree, int best_leaf, int* left_leaf, int* right_leaf) {
const SplitInfo& best_split_info = best_split_per_leaf_[best_leaf]; const SplitInfo& best_split_info = best_split_per_leaf_[best_leaf];
const int inner_feature_index = train_data_->InnerFeatureIndex(best_split_info.feature); const int inner_feature_index = train_data_->InnerFeatureIndex(best_split_info.feature);
if (!config_->cegb_penalty_feature_coupled.empty() && !is_feature_used_in_split_[inner_feature_index]) { if (cegb_ != nullptr) {
is_feature_used_in_split_[inner_feature_index] = true; cegb_->UpdateLeafBestSplits(tree, best_leaf, &best_split_info, &best_split_per_leaf_);
for (int i = 0; i < tree->num_leaves(); ++i) {
if (i == best_leaf) continue;
auto split = &splits_per_leaf_[i*train_data_->num_features() + inner_feature_index];
split->gain += config_->cegb_tradeoff*config_->cegb_penalty_feature_coupled[best_split_info.feature];
if (*split > best_split_per_leaf_[i])
best_split_per_leaf_[i] = *split;
}
}
if (!config_->cegb_penalty_feature_lazy.empty()) {
data_size_t cnt_leaf_data = 0;
auto tmp_idx = data_partition_->GetIndexOnLeaf(best_leaf, &cnt_leaf_data);
for (data_size_t i_input = 0; i_input < cnt_leaf_data; ++i_input) {
int real_idx = tmp_idx[i_input];
Common::InsertBitset(&feature_used_in_data, train_data_->num_data() * inner_feature_index + real_idx);
}
} }
// left = parent // left = parent
*left_leaf = best_leaf; *left_leaf = best_leaf;
bool is_numerical_split = train_data_->FeatureBinMapper(inner_feature_index)->bin_type() == BinType::NumericalBin; bool is_numerical_split = train_data_->FeatureBinMapper(inner_feature_index)->bin_type() == BinType::NumericalBin;
......
...@@ -32,12 +32,14 @@ ...@@ -32,12 +32,14 @@
using namespace json11; using namespace json11;
namespace LightGBM { namespace LightGBM {
/*! \brief forward declaration */
class CostEfficientGradientBoosting;
/*! /*!
* \brief Used for learning a tree by single machine * \brief Used for learning a tree by single machine
*/ */
class SerialTreeLearner: public TreeLearner { class SerialTreeLearner: public TreeLearner {
public: public:
friend CostEfficientGradientBoosting;
explicit SerialTreeLearner(const Config* config); explicit SerialTreeLearner(const Config* config);
~SerialTreeLearner(); ~SerialTreeLearner();
...@@ -116,8 +118,6 @@ class SerialTreeLearner: public TreeLearner { ...@@ -116,8 +118,6 @@ class SerialTreeLearner: public TreeLearner {
*/ */
inline virtual data_size_t GetGlobalDataCountInLeaf(int leaf_idx) const; inline virtual data_size_t GetGlobalDataCountInLeaf(int leaf_idx) const;
double CalculateOndemandCosts(int feature_index, int leaf_index);
/*! \brief number of data */ /*! \brief number of data */
data_size_t num_data_; data_size_t num_data_;
/*! \brief number of features */ /*! \brief number of features */
...@@ -179,9 +179,7 @@ class SerialTreeLearner: public TreeLearner { ...@@ -179,9 +179,7 @@ class SerialTreeLearner: public TreeLearner {
int num_threads_; int num_threads_;
std::vector<int> ordered_bin_indices_; std::vector<int> ordered_bin_indices_;
bool is_constant_hessian_; bool is_constant_hessian_;
std::unique_ptr<CostEfficientGradientBoosting> cegb_;
std::vector<bool> is_feature_used_in_split_;
std::vector<uint32_t> feature_used_in_data;
}; };
inline data_size_t SerialTreeLearner::GetGlobalDataCountInLeaf(int leaf_idx) const { inline data_size_t SerialTreeLearner::GetGlobalDataCountInLeaf(int leaf_idx) const {
......
...@@ -240,6 +240,7 @@ ...@@ -240,6 +240,7 @@
<ClInclude Include="..\src\objective\regression_objective.hpp" /> <ClInclude Include="..\src\objective\regression_objective.hpp" />
<ClInclude Include="..\src\objective\multiclass_objective.hpp" /> <ClInclude Include="..\src\objective\multiclass_objective.hpp" />
<ClInclude Include="..\src\objective\xentropy_objective.hpp" /> <ClInclude Include="..\src\objective\xentropy_objective.hpp" />
<ClInclude Include="..\src\treelearner\cost_effective_gradient_boosting.hpp" />
<ClInclude Include="..\src\treelearner\data_partition.hpp" /> <ClInclude Include="..\src\treelearner\data_partition.hpp" />
<ClInclude Include="..\src\treelearner\feature_histogram.hpp" /> <ClInclude Include="..\src\treelearner\feature_histogram.hpp" />
<ClInclude Include="..\src\treelearner\leaf_splits.hpp" /> <ClInclude Include="..\src\treelearner\leaf_splits.hpp" />
...@@ -283,4 +284,4 @@ ...@@ -283,4 +284,4 @@
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.targets" /> <Import Project="$(VCTargetsPath)\Microsoft.Cpp.targets" />
<ImportGroup Label="ExtensionTargets"> <ImportGroup Label="ExtensionTargets">
</ImportGroup> </ImportGroup>
</Project> </Project>
\ No newline at end of file
...@@ -207,6 +207,9 @@ ...@@ -207,6 +207,9 @@
<ClInclude Include="..\include\LightGBM\json11.hpp"> <ClInclude Include="..\include\LightGBM\json11.hpp">
<Filter>include\LightGBM\utils</Filter> <Filter>include\LightGBM\utils</Filter>
</ClInclude> </ClInclude>
<ClInclude Include="..\src\treelearner\cost_effective_gradient_boosting.hpp">
<Filter>src\treelearner</Filter>
</ClInclude>
</ItemGroup> </ItemGroup>
<ItemGroup> <ItemGroup>
<ClCompile Include="..\src\application\application.cpp"> <ClCompile Include="..\src\application\application.cpp">
...@@ -303,4 +306,4 @@ ...@@ -303,4 +306,4 @@
<Filter>src\io</Filter> <Filter>src\io</Filter>
</ClCompile> </ClCompile>
</ItemGroup> </ItemGroup>
</Project> </Project>
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment