"vscode:/vscode.git/clone" did not exist on "8869e1cb67b6273fada8778a34ef90bbc87f954f"
Commit 6d0eae0c authored by Guolin Ke

clean code for Boosting.

parent 3db907cc
@@ -60,12 +60,11 @@ public:
   /*!
   * \brief Training logic
-  * \param gradient nullptr for using default objective, otherwise use self-defined boosting
-  * \param hessian nullptr for using default objective, otherwise use self-defined boosting
-  * \param is_eval true if need evaluation or early stop
-  * \return True if meet early stopping or cannot boosting
+  * \param gradients nullptr for using default objective, otherwise use self-defined boosting
+  * \param hessians nullptr for using default objective, otherwise use self-defined boosting
+  * \return True if cannot train anymore
   */
-  virtual bool TrainOneIter(const score_t* gradient, const score_t* hessian, bool is_eval) = 0;
+  virtual bool TrainOneIter(const score_t* gradients, const score_t* hessians) = 0;
   /*!
   * \brief Rollback one iteration
@@ -77,10 +76,6 @@ public:
   */
   virtual int GetCurrentIteration() const = 0;
-  /*!
-  * \brief Eval metrics and check is met early stopping or not
-  */
-  virtual bool EvalAndCheckEarlyStopping() = 0;
   /*!
   * \brief Get evaluation result at data_idx data
   * \param data_idx 0: training data, 1: 1st validation data
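With is_eval removed from TrainOneIter and EvalAndCheckEarlyStopping dropped from the public interface, the return value now only reports whether boosting can continue; metric printing and early stopping live behind the boosting implementation (see the protected GBDT::EvalAndCheckEarlyStopping and the new Train entry point in gbdt.h below). A minimal sketch of a caller-side loop against the new contract; the driver function and its arguments are illustrative, not part of this commit:

    // Sketch only: hypothetical driver against the updated Boosting interface.
    // "boosting" is assumed to be an already-initialized LightGBM::Boosting object.
    void TrainingDriver(LightGBM::Boosting* boosting, int max_iterations) {
      for (int iter = 0; iter < max_iterations; ++iter) {
        // nullptr gradients/hessians -> use the objective configured at Init time.
        bool finished = boosting->TrainOneIter(nullptr, nullptr);
        if (finished) {
          break;  // the booster reports it cannot train any further
        }
        // Evaluation and early stopping are now the caller's responsibility,
        // e.g. inside GBDT::Train or the application layer.
      }
    }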
@@ -101,6 +96,7 @@ public:
   * \return out_len length of returned score
   */
   virtual int64_t GetNumPredictAt(int data_idx) const = 0;
+
   /*!
   * \brief Get prediction result at data_idx data
   * \param data_idx 0: training data, 1: 1st validation data
@@ -115,7 +111,7 @@ public:
   * \brief Prediction for one record, not sigmoid transform
   * \param feature_values Feature value on this record
   * \param output Prediction result for this record
-  * \param early_stop Early stopping instance. If nullptr, no early stopping is applied and all trees are evaluated.
+  * \param early_stop Early stopping instance. If nullptr, no early stopping is applied and all models are evaluated.
   */
   virtual void PredictRaw(const double* features, double* output,
                           const PredictionEarlyStopInstance* early_stop) const = 0;
@@ -124,7 +120,7 @@ public:
   * \brief Prediction for one record, sigmoid transformation will be used if needed
   * \param feature_values Feature value on this record
   * \param output Prediction result for this record
-  * \param early_stop Early stopping instance. If nullptr, no early stopping is applied and all trees are evaluated.
+  * \param early_stop Early stopping instance. If nullptr, no early stopping is applied and all models are evaluated.
   */
   virtual void Predict(const double* features, double* output,
                        const PredictionEarlyStopInstance* early_stop) const = 0;
@@ -141,7 +137,7 @@ public:
   * \brief Feature contributions for the model's prediction of one record
   * \param feature_values Feature value on this record
   * \param output Prediction result for this record
-  * \param early_stop Early stopping instance. If nullptr, no early stopping is applied and all trees are evaluated.
+  * \param early_stop Early stopping instance. If nullptr, no early stopping is applied and all models are evaluated.
   */
   virtual void PredictContrib(const double* features, double* output,
                               const PredictionEarlyStopInstance* early_stop) const = 0;
@@ -224,10 +220,10 @@ public:
   virtual int NumberOfTotalModel() const = 0;
   /*!
-  * \brief Get number of trees per iteration
-  * \return Number of trees per iteration
+  * \brief Get number of models per iteration
+  * \return Number of models per iteration
   */
-  virtual int NumTreePerIteration() const = 0;
+  virtual int NumModelPerIteration() const = 0;
   /*!
   * \brief Get number of classes
@@ -275,6 +271,12 @@ public:
 };

+class GBDTBase : public Boosting {
+public:
+  virtual double GetLeafValue(int tree_idx, int leaf_idx) const = 0;
+  virtual void SetLeafValue(int tree_idx, int leaf_idx, double val) = 0;
+};
+
 }  // namespace LightGBM

 #endif   // LightGBM_BOOSTING_H_
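The new GBDTBase layer exposes only the leaf get/set hooks on top of the abstract Boosting interface, so callers such as c_api.cpp (further down in this commit) can reach leaf values through a dynamic_cast without including gbdt.h. A hedged sketch of that pattern; the function and variable names are illustrative:

    #include <memory>

    // Sketch only: "booster" is assumed to own a GBDT-family object, as in c_api.cpp.
    double ReadAndBumpLeaf(const std::unique_ptr<LightGBM::Boosting>& booster,
                           int tree_idx, int leaf_idx, double delta) {
      auto* base = dynamic_cast<LightGBM::GBDTBase*>(booster.get());
      if (base == nullptr) {
        return 0.0;  // not a GBDT-family booster, nothing to adjust
      }
      const double old_value = base->GetLeafValue(tree_idx, leaf_idx);
      base->SetLeafValue(tree_idx, leaf_idx, old_value + delta);
      return old_value;
    }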
@@ -41,7 +41,7 @@ public:
   virtual bool SkipEmptyClass() const { return false; }

-  virtual int NumTreePerIteration() const { return 1; }
+  virtual int NumModelPerIteration() const { return 1; }

   virtual int NumPredictOneRow() const { return 1; }
...
@@ -167,6 +167,15 @@ public:
     shrinkage_ *= rate;
   }

+  inline void AddBias(double val) {
+    #pragma omp parallel for schedule(static, 1024) if (num_leaves_ >= 2048)
+    for (int i = 0; i < num_leaves_; ++i) {
+      leaf_value_[i] = val + leaf_value_[i];
+    }
+    // force to 1.0
+    shrinkage_ = 1.0f;
+  }
+
   inline void AsConstantTree(double val) {
     num_leaves_ = 1;
     shrinkage_ = 1.0f;
...
@@ -48,21 +48,20 @@ public:
   /*!
   * \brief one training iteration
   */
-  bool TrainOneIter(const score_t* gradient, const score_t* hessian, bool is_eval) override {
+  bool TrainOneIter(const score_t* gradient, const score_t* hessian) override {
     is_update_score_cur_iter_ = false;
-    GBDT::TrainOneIter(gradient, hessian, false);
+    bool ret = GBDT::TrainOneIter(gradient, hessian);
+    if (ret) {
+      return ret;
+    }
     // normalize
     Normalize();
     if (!gbdt_config_->uniform_drop) {
       tree_weight_.push_back(shrinkage_rate_);
       sum_weight_ += shrinkage_rate_;
     }
-    if (is_eval) {
-      return EvalAndCheckEarlyStopping();
-    } else {
       return false;
     }
-  }
   /*!
   * \brief Get current training score
...
This diff is collapsed.
@@ -15,19 +15,23 @@
 #include <mutex>

 namespace LightGBM {

 /*!
 * \brief GBDT algorithm implementation. including Training, prediction, bagging.
 */
-class GBDT: public Boosting {
+class GBDT: public GBDTBase {
 public:
   /*!
   * \brief Constructor
   */
   GBDT();
   /*!
   * \brief Destructor
   */
   ~GBDT();
   /*!
   * \brief Initialization logic
   * \param gbdt_config Config for boosting
@@ -36,12 +40,10 @@ public:
   * \param training_metrics Training metrics
   */
   void Init(const BoostingConfig* gbdt_config, const Dataset* train_data, const ObjectiveFunction* objective_function,
-            const std::vector<const Metric*>& training_metrics)
-            override;
+            const std::vector<const Metric*>& training_metrics) override;
   /*!
-  * \brief Merge model from other boosting object
-           Will insert to the front of current boosting object
+  * \brief Merge model from other boosting object. Will insert to the front of current boosting object
   * \param other
   */
   void MergeFrom(const Boosting* other) override {
@@ -63,10 +65,21 @@ public:
     num_iteration_for_pred_ = static_cast<int>(models_.size()) / num_tree_per_iteration_;
   }

+  /*!
+  * \brief Reset the training data
+  * \param train_data New Training data
+  * \param objective_function Training objective function
+  * \param training_metrics Training metrics
+  */
   void ResetTrainingData(const Dataset* train_data, const ObjectiveFunction* objective_function,
                          const std::vector<const Metric*>& training_metrics) override;

-  void ResetConfig(const BoostingConfig* config) override;
+  /*!
+  * \brief Reset Boosting Config
+  * \param gbdt_config Config for boosting
+  */
+  void ResetConfig(const BoostingConfig* gbdt_config) override;

   /*!
   * \brief Adding a validation dataset
   * \param valid_data Validation dataset
@@ -75,26 +88,35 @@ public:
   void AddValidDataset(const Dataset* valid_data,
                        const std::vector<const Metric*>& valid_metrics) override;

+  /*!
+  * \brief Perform a full training procedure
+  * \param snapshot_freq frequence of snapshot
+  * \param model_output_path path of model file
+  */
   void Train(int snapshot_freq, const std::string& model_output_path) override;

   /*!
   * \brief Training logic
-  * \param gradient nullptr for using default objective, otherwise use self-defined boosting
-  * \param hessian nullptr for using default objective, otherwise use self-defined boosting
-  * \param is_eval true if need evaluation or early stop
-  * \return True if meet early stopping or cannot boosting
+  * \param gradients nullptr for using default objective, otherwise use self-defined boosting
+  * \param hessians nullptr for using default objective, otherwise use self-defined boosting
+  * \return True if cannot train any more
   */
-  virtual bool TrainOneIter(const score_t* gradient, const score_t* hessian, bool is_eval) override;
+  virtual bool TrainOneIter(const score_t* gradients, const score_t* hessians) override;

   /*!
   * \brief Rollback one iteration
   */
   void RollbackOneIter() override;

+  /*!
+  * \brief Get current iteration
+  */
   int GetCurrentIteration() const override { return static_cast<int>(models_.size()) / num_tree_per_iteration_; }

-  bool EvalAndCheckEarlyStopping() override;
+  /*!
+  * \brief Can use early stopping for prediction or not
+  * \return True if cannot use early stopping for prediction
+  */
   bool NeedAccuratePrediction() const override {
     if (objective_function_ == nullptr) {
       return true;
@@ -117,6 +139,11 @@ public:
   */
   virtual const double* GetTrainingScore(int64_t* out_len) override;

+  /*!
+  * \brief Get size of prediction at data_idx data
+  * \param data_idx 0: training data, 1: 1st validation data
+  * \return The size of prediction
+  */
   virtual int64_t GetNumPredictAt(int data_idx) const override {
     CHECK(data_idx >= 0 && data_idx <= static_cast<int>(valid_score_updater_.size()));
     data_size_t num_data = train_data_->num_data();
@@ -125,6 +152,7 @@ public:
     }
     return num_data * num_class_;
   }
+
   /*!
   * \brief Get prediction result at data_idx data
   * \param data_idx 0: training data, 1: 1st validation data
@@ -133,6 +161,13 @@ public:
   */
   void GetPredictAt(int data_idx, double* out_result, int64_t* out_len) override;

+  /*!
+  * \brief Get number of prediction for one data
+  * \param num_iteration number of used iterations
+  * \param is_pred_leaf True if predicting leaf index
+  * \param is_pred_contrib True if predicting feature contribution
+  * \return number of prediction
+  */
   inline int NumPredictOneRow(int num_iteration, bool is_pred_leaf, bool is_pred_contrib) const override {
     int num_preb_in_one_row = num_class_;
     if (is_pred_leaf) {
@@ -237,7 +272,7 @@ public:
   * \brief Get number of tree per iteration
   * \return number of tree per iteration
   */
-  inline int NumTreePerIteration() const override { return num_tree_per_iteration_; }
+  inline int NumModelPerIteration() const override { return num_tree_per_iteration_; }
   /*!
   * \brief Get number of classes
@@ -248,17 +283,17 @@ public:
   inline void InitPredict(int num_iteration) override {
     num_iteration_for_pred_ = static_cast<int>(models_.size()) / num_tree_per_iteration_;
     if (num_iteration > 0) {
-      num_iteration_for_pred_ = std::min(num_iteration + (boost_from_average_ ? 1 : 0), num_iteration_for_pred_);
+      num_iteration_for_pred_ = std::min(num_iteration, num_iteration_for_pred_);
     }
   }

-  inline double GetLeafValue(int tree_idx, int leaf_idx) const {
+  inline double GetLeafValue(int tree_idx, int leaf_idx) const override {
     CHECK(tree_idx >= 0 && static_cast<size_t>(tree_idx) < models_.size());
     CHECK(leaf_idx >= 0 && leaf_idx < models_[tree_idx]->num_leaves());
     return models_[tree_idx]->LeafOutput(leaf_idx);
   }

-  inline void SetLeafValue(int tree_idx, int leaf_idx, double val) {
+  inline void SetLeafValue(int tree_idx, int leaf_idx, double val) override {
     CHECK(tree_idx >= 0 && static_cast<size_t>(tree_idx) < models_.size());
     CHECK(leaf_idx >= 0 && leaf_idx < models_[tree_idx]->num_leaves());
     models_[tree_idx]->SetLeafOutput(leaf_idx, val);
@@ -270,7 +305,17 @@ public:
   virtual const char* SubModelName() const override { return "tree"; }

 protected:

+  /*!
+  * \brief Print eval result and check early stopping
+  */
+  bool EvalAndCheckEarlyStopping();
+
+  /*!
+  * \brief reset config for bagging
+  */
   void ResetBaggingConfig(const BoostingConfig* config, bool is_change_dataset);
   /*!
   * \brief Implement bagging logic
   * \param iter Current interation
@@ -285,17 +330,12 @@ protected:
   * \return count of left size
   */
   data_size_t BaggingHelper(Random& cur_rand, data_size_t start, data_size_t cnt, data_size_t* buffer);
-  /*!
-  * \brief updating score for out-of-bag data.
-  *        Data should be update since we may re-bagging data on training
-  * \param tree Trained tree of this iteration
-  * \param cur_tree_id Current tree for multiclass training
-  */
-  virtual void UpdateScoreOutOfBag(const Tree* tree, const int cur_tree_id);
   /*!
   * \brief calculate the object function
   */
   virtual void Boosting();
   /*!
   * \brief updating score after tree was trained
   * \param tree Trained tree of this iteration
@@ -303,7 +343,12 @@
   */
   virtual void UpdateScore(const Tree* tree, const int cur_tree_id);

+  /*!
+  * \brief eval results for one metric
+  */
   virtual std::vector<double> EvalOneMetric(const Metric* metric, const double* score) const;

   /*!
   * \brief Print metric result of current iteration
   * \param iter Current interation
@@ -311,6 +356,8 @@
   */
   std::string OutputMetric(int iter);

+  double BoostFromAverage();
+
   /*! \brief current iteration */
   int iter_;
   /*! \brief Pointer to training data */
@@ -382,7 +429,6 @@ protected:
   std::vector<data_size_t> right_write_pos_buf_;
   std::unique_ptr<Dataset> tmp_subset_;
   bool is_use_subset_;
-  bool boost_from_average_;
   std::vector<bool> class_need_train_;
   std::vector<double> class_default_output_;
   bool is_constant_hessian_;
...
#include "gbdt.h"
#include <LightGBM/utils/common.h>
#include <LightGBM/objective_function.h>
#include <LightGBM/metric.h>
#include <sstream>
#include <string>
#include <vector>
namespace LightGBM {
std::string GBDT::DumpModel(int num_iteration) const {
std::stringstream str_buf;
str_buf << "{";
str_buf << "\"name\":\"" << SubModelName() << "\"," << std::endl;
str_buf << "\"num_class\":" << num_class_ << "," << std::endl;
str_buf << "\"num_tree_per_iteration\":" << num_tree_per_iteration_ << "," << std::endl;
str_buf << "\"label_index\":" << label_idx_ << "," << std::endl;
str_buf << "\"max_feature_idx\":" << max_feature_idx_ << "," << std::endl;
str_buf << "\"feature_names\":[\""
<< Common::Join(feature_names_, "\",\"") << "\"],"
<< std::endl;
str_buf << "\"tree_info\":[";
int num_used_model = static_cast<int>(models_.size());
if (num_iteration > 0) {
num_used_model = std::min(num_iteration * num_tree_per_iteration_, num_used_model);
}
for (int i = 0; i < num_used_model; ++i) {
if (i > 0) {
str_buf << ",";
}
str_buf << "{";
str_buf << "\"tree_index\":" << i << ",";
str_buf << models_[i]->ToJSON();
str_buf << "}";
}
str_buf << "]" << std::endl;
str_buf << "}" << std::endl;
return str_buf.str();
}
std::string GBDT::ModelToIfElse(int num_iteration) const {
std::stringstream str_buf;
str_buf << "#include \"gbdt.h\"" << std::endl;
str_buf << "#include <LightGBM/utils/common.h>" << std::endl;
str_buf << "#include <LightGBM/objective_function.h>" << std::endl;
str_buf << "#include <LightGBM/metric.h>" << std::endl;
str_buf << "#include <LightGBM/prediction_early_stop.h>" << std::endl;
str_buf << "#include <ctime>" << std::endl;
str_buf << "#include <sstream>" << std::endl;
str_buf << "#include <chrono>" << std::endl;
str_buf << "#include <string>" << std::endl;
str_buf << "#include <vector>" << std::endl;
str_buf << "#include <utility>" << std::endl;
str_buf << "namespace LightGBM {" << std::endl;
int num_used_model = static_cast<int>(models_.size());
if (num_iteration > 0) {
num_used_model = std::min(num_iteration * num_tree_per_iteration_, num_used_model);
}
// PredictRaw
for (int i = 0; i < num_used_model; ++i) {
str_buf << models_[i]->ToIfElse(i, false) << std::endl;
}
str_buf << "double (*PredictTreePtr[])(const double*) = { ";
for (int i = 0; i < num_used_model; ++i) {
if (i > 0) {
str_buf << " , ";
}
str_buf << "PredictTree" << i;
}
str_buf << " };" << std::endl << std::endl;
std::stringstream pred_str_buf;
pred_str_buf << "\t" << "int early_stop_round_counter = 0;" << std::endl;
pred_str_buf << "\t" << "std::memset(output, 0, sizeof(double) * num_tree_per_iteration_);" << std::endl;
pred_str_buf << "\t" << "for (int i = 0; i < num_iteration_for_pred_; ++i) {" << std::endl;
pred_str_buf << "\t\t" << "for (int k = 0; k < num_tree_per_iteration_; ++k) {" << std::endl;
pred_str_buf << "\t\t\t" << "output[k] += (*PredictTreePtr[i * num_tree_per_iteration_ + k])(features);" << std::endl;
pred_str_buf << "\t\t" << "}" << std::endl;
pred_str_buf << "\t\t" << "++early_stop_round_counter;" << std::endl;
pred_str_buf << "\t\t" << "if (early_stop->round_period == early_stop_round_counter) {" << std::endl;
pred_str_buf << "\t\t\t" << "if (early_stop->callback_function(output, num_tree_per_iteration_))" << std::endl;
pred_str_buf << "\t\t\t\t" << "return;" << std::endl;
pred_str_buf << "\t\t\t" << "early_stop_round_counter = 0;" << std::endl;
pred_str_buf << "\t\t" << "}" << std::endl;
pred_str_buf << "\t" << "}" << std::endl;
str_buf << "void GBDT::PredictRaw(const double* features, double *output, const PredictionEarlyStopInstance* early_stop) const {" << std::endl;
str_buf << pred_str_buf.str();
str_buf << "}" << std::endl;
str_buf << std::endl;
// Predict
str_buf << "void GBDT::Predict(const double* features, double *output, const PredictionEarlyStopInstance* early_stop) const {" << std::endl;
str_buf << "\t" << "PredictRaw(features, output, early_stop);" << std::endl;
str_buf << "\t" << "if (average_output_) {" << std::endl;
str_buf << "\t\t" << "for (int k = 0; k < num_tree_per_iteration_; ++k) {" << std::endl;
str_buf << "\t\t\t" << "output[k] /= num_iteration_for_pred_;" << std::endl;
str_buf << "\t\t" << "}" << std::endl;
str_buf << "\t" << "}" << std::endl;
str_buf << "\t" << "else if (objective_function_ != nullptr) {" << std::endl;
str_buf << "\t\t" << "objective_function_->ConvertOutput(output, output);" << std::endl;
str_buf << "\t" << "}" << std::endl;
str_buf << "}" << std::endl;
str_buf << std::endl;
// PredictLeafIndex
for (int i = 0; i < num_used_model; ++i) {
str_buf << models_[i]->ToIfElse(i, true) << std::endl;
}
str_buf << "double (*PredictTreeLeafPtr[])(const double*) = { ";
for (int i = 0; i < num_used_model; ++i) {
if (i > 0) {
str_buf << " , ";
}
str_buf << "PredictTree" << i << "Leaf";
}
str_buf << " };" << std::endl << std::endl;
str_buf << "void GBDT::PredictLeafIndex(const double* features, double *output) const {" << std::endl;
str_buf << "\t" << "int total_tree = num_iteration_for_pred_ * num_tree_per_iteration_;" << std::endl;
str_buf << "\t" << "for (int i = 0; i < total_tree; ++i) {" << std::endl;
str_buf << "\t\t" << "output[i] = (*PredictTreeLeafPtr[i])(features);" << std::endl;
str_buf << "\t" << "}" << std::endl;
str_buf << "}" << std::endl;
str_buf << "} // namespace LightGBM" << std::endl;
return str_buf.str();
}
bool GBDT::SaveModelToIfElse(int num_iteration, const char* filename) const {
/*! \brief File to write models */
std::ofstream output_file;
std::ifstream ifs(filename);
if (ifs.good()) {
std::string origin((std::istreambuf_iterator<char>(ifs)),
(std::istreambuf_iterator<char>()));
output_file.open(filename);
output_file << "#define USE_HARD_CODE 0" << std::endl;
output_file << "#ifndef USE_HARD_CODE" << std::endl;
output_file << origin << std::endl;
output_file << "#else" << std::endl;
output_file << ModelToIfElse(num_iteration);
output_file << "#endif" << std::endl;
} else {
output_file.open(filename);
output_file << ModelToIfElse(num_iteration);
}
ifs.close();
output_file.close();
return (bool)output_file;
}
std::string GBDT::SaveModelToString(int num_iteration) const {
std::stringstream ss;
// output model type
ss << SubModelName() << std::endl;
// output number of class
ss << "num_class=" << num_class_ << std::endl;
ss << "num_tree_per_iteration=" << num_tree_per_iteration_ << std::endl;
// output label index
ss << "label_index=" << label_idx_ << std::endl;
// output max_feature_idx
ss << "max_feature_idx=" << max_feature_idx_ << std::endl;
// output objective
if (objective_function_ != nullptr) {
ss << "objective=" << objective_function_->ToString() << std::endl;
}
if (average_output_) {
ss << "average_output" << std::endl;
}
ss << "feature_names=" << Common::Join(feature_names_, " ") << std::endl;
ss << "feature_infos=" << Common::Join(feature_infos_, " ") << std::endl;
std::vector<double> feature_importances = FeatureImportance(num_iteration, 0);
ss << std::endl;
int num_used_model = static_cast<int>(models_.size());
if (num_iteration > 0) {
num_used_model = std::min(num_iteration * num_tree_per_iteration_, num_used_model);
}
// output tree models
for (int i = 0; i < num_used_model; ++i) {
ss << "Tree=" << i << std::endl;
ss << models_[i]->ToString() << std::endl;
}
// store the importance first
std::vector<std::pair<size_t, std::string>> pairs;
for (size_t i = 0; i < feature_importances.size(); ++i) {
size_t feature_importances_int = static_cast<size_t>(feature_importances[i]);
if (feature_importances_int > 0) {
pairs.emplace_back(feature_importances_int, feature_names_[i]);
}
}
// sort the importance
std::sort(pairs.begin(), pairs.end(),
[](const std::pair<size_t, std::string>& lhs,
const std::pair<size_t, std::string>& rhs) {
return lhs.first > rhs.first;
});
ss << std::endl << "feature importances:" << std::endl;
for (size_t i = 0; i < pairs.size(); ++i) {
ss << pairs[i].second << "=" << std::to_string(pairs[i].first) << std::endl;
}
return ss.str();
}
bool GBDT::SaveModelToFile(int num_iteration, const char* filename) const {
/*! \brief File to write models */
std::ofstream output_file;
output_file.open(filename);
output_file << SaveModelToString(num_iteration);
output_file.close();
return (bool)output_file;
}
bool GBDT::LoadModelFromString(const std::string& model_str) {
// use serialized string to restore this object
models_.clear();
std::vector<std::string> lines = Common::SplitLines(model_str.c_str());
// get number of classes
auto line = Common::FindFromLines(lines, "num_class=");
if (line.size() > 0) {
Common::Atoi(Common::Split(line.c_str(), '=')[1].c_str(), &num_class_);
} else {
Log::Fatal("Model file doesn't specify the number of classes");
return false;
}
line = Common::FindFromLines(lines, "num_tree_per_iteration=");
if (line.size() > 0) {
Common::Atoi(Common::Split(line.c_str(), '=')[1].c_str(), &num_tree_per_iteration_);
} else {
num_tree_per_iteration_ = num_class_;
}
// get index of label
line = Common::FindFromLines(lines, "label_index=");
if (line.size() > 0) {
Common::Atoi(Common::Split(line.c_str(), '=')[1].c_str(), &label_idx_);
} else {
Log::Fatal("Model file doesn't specify the label index");
return false;
}
// get max_feature_idx first
line = Common::FindFromLines(lines, "max_feature_idx=");
if (line.size() > 0) {
Common::Atoi(Common::Split(line.c_str(), '=')[1].c_str(), &max_feature_idx_);
} else {
Log::Fatal("Model file doesn't specify max_feature_idx");
return false;
}
// get average_output
line = Common::FindFromLines(lines, "average_output");
if (line.size() > 0) {
average_output_ = true;
}
// get feature names
line = Common::FindFromLines(lines, "feature_names=");
if (line.size() > 0) {
feature_names_ = Common::Split(line.substr(std::strlen("feature_names=")).c_str(), ' ');
if (feature_names_.size() != static_cast<size_t>(max_feature_idx_ + 1)) {
Log::Fatal("Wrong size of feature_names");
return false;
}
} else {
Log::Fatal("Model file doesn't contain feature names");
return false;
}
line = Common::FindFromLines(lines, "feature_infos=");
if (line.size() > 0) {
feature_infos_ = Common::Split(line.substr(std::strlen("feature_infos=")).c_str(), ' ');
if (feature_infos_.size() != static_cast<size_t>(max_feature_idx_ + 1)) {
Log::Fatal("Wrong size of feature_infos");
return false;
}
} else {
Log::Fatal("Model file doesn't contain feature infos");
return false;
}
line = Common::FindFromLines(lines, "objective=");
if (line.size() > 0) {
auto str = Common::Split(line.c_str(), '=')[1];
loaded_objective_.reset(ObjectiveFunction::CreateObjectiveFunction(str));
objective_function_ = loaded_objective_.get();
}
// get tree models
size_t i = 0;
while (i < lines.size()) {
size_t find_pos = lines[i].find("Tree=");
if (find_pos != std::string::npos) {
++i;
int start = static_cast<int>(i);
while (i < lines.size() && lines[i].find("Tree=") == std::string::npos) { ++i; }
int end = static_cast<int>(i);
std::string tree_str = Common::Join<std::string>(lines, start, end, "\n");
models_.emplace_back(new Tree(tree_str));
} else {
++i;
}
}
Log::Info("Finished loading %d models", models_.size());
num_iteration_for_pred_ = static_cast<int>(models_.size()) / num_tree_per_iteration_;
num_init_iteration_ = num_iteration_for_pred_;
iter_ = 0;
return true;
}
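SaveModelToString and LoadModelFromString are meant to be inverses over the plain-text layout above (key=value header lines, one Tree=i block per model, then the feature-importance footer). A rough round-trip sketch, assuming an already trained GBDT instance; the helper name is illustrative:

    #include <string>

    // Sketch only: "trained" is assumed to be a fully initialized and trained GBDT.
    bool CloneViaString(const LightGBM::GBDT& trained, LightGBM::GBDT* restored) {
      // num_iteration <= 0 means "use every model currently stored in models_".
      std::string model_str = trained.SaveModelToString(0);
      // Rebuilds num_class_, feature names, objective and the tree models from the text.
      return restored->LoadModelFromString(model_str);
    }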
std::vector<double> GBDT::FeatureImportance(int num_iteration, int importance_type) const {
int num_used_model = static_cast<int>(models_.size());
if (num_iteration > 0) {
num_iteration += 0;
num_used_model = std::min(num_iteration * num_tree_per_iteration_, num_used_model);
}
std::vector<double> feature_importances(max_feature_idx_ + 1, 0.0);
if (importance_type == 0) {
for (int iter = 0; iter < num_used_model; ++iter) {
for (int split_idx = 0; split_idx < models_[iter]->num_leaves() - 1; ++split_idx) {
if (models_[iter]->split_gain(split_idx) > 0) {
feature_importances[models_[iter]->split_feature(split_idx)] += 1.0;
}
}
}
} else if (importance_type == 1) {
for (int iter = 0; iter < num_used_model; ++iter) {
for (int split_idx = 0; split_idx < models_[iter]->num_leaves() - 1; ++split_idx) {
if (models_[iter]->split_gain(split_idx) > 0) {
feature_importances[models_[iter]->split_feature(split_idx)] += models_[iter]->split_gain(split_idx);
}
}
}
} else {
Log::Fatal("Unknown importance type: only support split=0 and gain=1.");
}
return feature_importances;
}
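FeatureImportance supports two modes: importance_type 0 counts the splits that use each feature, importance_type 1 accumulates split gain instead, and anything else is a fatal error. A small usage sketch; the printing helper and feature-name handling are illustrative:

    #include <cstdio>
    #include <string>
    #include <vector>

    // Sketch only: "gbdt" is assumed to be a trained GBDT and "names" its feature names.
    void PrintImportance(const LightGBM::GBDT& gbdt, const std::vector<std::string>& names) {
      std::vector<double> split_cnt = gbdt.FeatureImportance(0, 0);  // 0 = split counts
      std::vector<double> gain_sum = gbdt.FeatureImportance(0, 1);   // 1 = total gain
      for (size_t i = 0; i < names.size(); ++i) {
        std::printf("%s: splits=%.0f gain=%.4f\n", names[i].c_str(), split_cnt[i], gain_sum[i]);
      }
    }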
} // namespace LightGBM
@@ -86,33 +86,33 @@ public:
     GetGradients(tmp_score.data(), gradients_.data(), hessians_.data());
   }

-  bool TrainOneIter(const score_t* gradient, const score_t* hessian, bool is_eval) override {
+  bool TrainOneIter(const score_t* gradients, const score_t* hessians) override {
     // bagging logic
     Bagging(iter_);

-    if (gradient == nullptr || hessian == nullptr) {
-      gradient = gradients_.data();
-      hessian = hessians_.data();
+    if (gradients == nullptr || hessians == nullptr) {
+      gradients = gradients_.data();
+      hessians = hessians_.data();
     }
-    if (is_use_subset_ && bag_data_cnt_ < num_data_) {
-      // get sub gradients
     for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
+      std::unique_ptr<Tree> new_tree(new Tree(2));
+      if (class_need_train_[cur_tree_id]) {
         size_t bias = static_cast<size_t>(cur_tree_id)* num_data_;
+        // cannot multi-threading here.
+        auto grad = gradients + bias;
+        auto hess = hessians + bias;
+        // need to copy gradients for bagging subset.
+        if (is_use_subset_ && bag_data_cnt_ < num_data_) {
           for (int i = 0; i < bag_data_cnt_; ++i) {
-            tmp_grad_[bias + i] = gradient[bias + bag_data_indices_[i]];
-            tmp_hess_[bias + i] = hessian[bias + bag_data_indices_[i]];
+            tmp_grad_[bias + i] = grad[bag_data_indices_[i]];
+            tmp_hess_[bias + i] = hess[bag_data_indices_[i]];
           }
-        }
-      gradient = tmp_grad_.data();
-      hessian = tmp_hess_.data();
+          grad = tmp_grad_.data() + bias;
+          hess = tmp_hess_.data() + bias;
         }
-    for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
-      std::unique_ptr<Tree> new_tree(new Tree(2));
-      if (class_need_train_[cur_tree_id]) {
-        size_t bias = static_cast<size_t>(cur_tree_id)* num_data_;
-        new_tree.reset(
-          tree_learner_->Train(gradient + bias, hessian + bias, is_constant_hessian_));
+        new_tree.reset(tree_learner_->Train(grad, hess, is_constant_hessian_));
       }
       if (new_tree->num_leaves() > 1) {
@@ -120,7 +120,6 @@ public:
         MultiplyScore(cur_tree_id, (iter_ + num_init_iteration_));
         ConvertTreeOutput(new_tree.get());
         UpdateScore(new_tree.get(), cur_tree_id);
-        UpdateScoreOutOfBag(new_tree.get(), cur_tree_id);
         MultiplyScore(cur_tree_id, 1.0 / (iter_ + num_init_iteration_ + 1));
       } else {
         // only add default score one-time
@@ -138,12 +137,8 @@ public:
       models_.push_back(std::move(new_tree));
     }
     ++iter_;
-    if (is_eval) {
-      return EvalAndCheckEarlyStopping();
-    } else {
       return false;
     }
-  }

   void RollbackOneIter() override {
     if (iter_ <= 0) { return; }
...
@@ -22,7 +22,6 @@
 #include <functional>

 #include "./application/predictor.hpp"
-#include "./boosting/gbdt.h"

 namespace LightGBM {

@@ -158,12 +157,12 @@ public:
   bool TrainOneIter() {
     std::lock_guard<std::mutex> lock(mutex_);
-    return boosting_->TrainOneIter(nullptr, nullptr, false);
+    return boosting_->TrainOneIter(nullptr, nullptr);
   }

   bool TrainOneIter(const float* gradients, const float* hessians) {
     std::lock_guard<std::mutex> lock(mutex_);
-    return boosting_->TrainOneIter(gradients, hessians, false);
+    return boosting_->TrainOneIter(gradients, hessians);
   }

   void RollbackOneIter() {
@@ -253,12 +252,12 @@ public:
   }

   double GetLeafValue(int tree_idx, int leaf_idx) const {
-    return dynamic_cast<GBDT*>(boosting_.get())->GetLeafValue(tree_idx, leaf_idx);
+    return dynamic_cast<GBDTBase*>(boosting_.get())->GetLeafValue(tree_idx, leaf_idx);
   }

   void SetLeafValue(int tree_idx, int leaf_idx, double val) {
     std::lock_guard<std::mutex> lock(mutex_);
-    dynamic_cast<GBDT*>(boosting_.get())->SetLeafValue(tree_idx, leaf_idx, val);
+    dynamic_cast<GBDTBase*>(boosting_.get())->SetLeafValue(tree_idx, leaf_idx, val);
   }

   int GetEvalCounts() const {
...
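The Booster wrapper keeps both entry points: the no-argument TrainOneIter() for the built-in objective, and TrainOneIter(gradients, hessians) for externally supplied gradients; both simply forward to the two-argument Boosting call. A hedged sketch of a custom-objective loop over this wrapper; the callback and loop bounds are stand-ins, not LightGBM API:

    #include <cstdint>
    #include <functional>
    #include <vector>

    // Sketch only: "booster" is the Booster wrapper shown above; "fill_grad_hess"
    // stands in for whatever custom objective the application implements.
    void CustomObjectiveLoop(Booster* booster, int max_iterations,
                             int64_t num_data, int num_class,
                             const std::function<void(float*, float*)>& fill_grad_hess) {
      std::vector<float> grad(static_cast<size_t>(num_data) * num_class);
      std::vector<float> hess(grad.size());
      for (int iter = 0; iter < max_iterations; ++iter) {
        fill_grad_hess(grad.data(), hess.data());
        if (booster->TrainOneIter(grad.data(), hess.data())) {
          break;  // booster reports it cannot train any further
        }
      }
    }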
@@ -102,7 +102,15 @@ for (data_size_t i = start; i < end; ++i) {\
  }\

void Tree::AddPredictionToScore(const Dataset* data, data_size_t num_data, double* score) const {
-  if (num_leaves_ <= 1) { return; }
+  if (num_leaves_ <= 1) {
+    if (leaf_value_[0] != 0.0f) {
+      #pragma omp parallel for schedule(static)
+      for (data_size_t i = 0; i < num_data; ++i) {
+        score[i] += leaf_value_[0];
+      }
+    }
+    return;
+  }
   std::vector<uint32_t> default_bins(num_leaves_ - 1);
   std::vector<uint32_t> max_bins(num_leaves_ - 1);
   for (int i = 0; i < num_leaves_ - 1; ++i) {
@@ -141,7 +149,15 @@ void Tree::AddPredictionToScore(const Dataset* data, data_size_t num_data, doubl
void Tree::AddPredictionToScore(const Dataset* data,
                                const data_size_t* used_data_indices,
                                data_size_t num_data, double* score) const {
-  if (num_leaves_ <= 1) { return; }
+  if (num_leaves_ <= 1) {
+    if (leaf_value_[0] != 0.0f) {
+      #pragma omp parallel for schedule(static)
+      for (data_size_t i = 0; i < num_data; ++i) {
+        score[used_data_indices[i]] += leaf_value_[0];
+      }
+    }
+    return;
+  }
   std::vector<uint32_t> default_bins(num_leaves_ - 1);
   std::vector<uint32_t> max_bins(num_leaves_ - 1);
   for (int i = 0; i < num_leaves_ - 1; ++i) {
...
@@ -54,7 +54,7 @@ public:
   int num_tree_per_iteration = num_class_;
   int num_pred_per_row = num_class_;
   if (objective != nullptr) {
-    num_tree_per_iteration = objective->NumTreePerIteration();
+    num_tree_per_iteration = objective->NumModelPerIteration();
     num_pred_per_row = objective->NumPredictOneRow();
   }
   if (objective != nullptr) {
...
@@ -114,7 +114,7 @@ public:
   bool SkipEmptyClass() const override { return true; }

-  int NumTreePerIteration() const override { return num_class_; }
+  int NumModelPerIteration() const override { return num_class_; }

   int NumPredictOneRow() const override { return num_class_; }
@@ -206,7 +206,7 @@ public:
   bool SkipEmptyClass() const override { return true; }

-  int NumTreePerIteration() const override { return num_class_; }
+  int NumModelPerIteration() const override { return num_class_; }

   int NumPredictOneRow() const override { return num_class_; }
...
@@ -247,6 +247,7 @@
     <ClCompile Include="..\src\application\application.cpp" />
     <ClCompile Include="..\src\boosting\boosting.cpp" />
     <ClCompile Include="..\src\boosting\gbdt.cpp" />
+    <ClCompile Include="..\src\boosting\gbdt_model.cpp" />
     <ClCompile Include="..\src\boosting\gbdt_prediction.cpp" />
     <ClCompile Include="..\src\boosting\prediction_early_stop.cpp" />
     <ClCompile Include="..\src\c_api.cpp" />
...
@@ -278,5 +278,8 @@
     <ClCompile Include="..\src\lightgbm_R.cpp">
       <Filter>src</Filter>
     </ClCompile>
+    <ClCompile Include="..\src\boosting\gbdt_model.cpp">
+      <Filter>src\boosting</Filter>
+    </ClCompile>
   </ItemGroup>
 </Project>
\ No newline at end of file