Commit 7d4b6d44 authored by wxchan's avatar wxchan Committed by Guolin Ke
Browse files

add early-stopping feature (#21)

* a python regression example

* add early-stopping

* fix bugs

* remove not needed files
parent b028ac5f
......@@ -120,6 +120,7 @@ public:
struct MetricConfig: public ConfigBase {
public:
virtual ~MetricConfig() {}
int early_stopping_round = 0;
int output_freq = 1;
double sigmoid = 1;
bool is_provide_training_metric = false;
......@@ -155,6 +156,7 @@ public:
double bagging_fraction = 1.0;
int bagging_seed = 3;
int bagging_freq = 0;
int early_stopping_round = 0;
void Set(const std::unordered_map<std::string, std::string>& params) override;
};
......
......@@ -32,7 +32,7 @@ public:
* \param iter Current iteration
* \param score Current prediction score
*/
virtual void Print(int iter, const score_t* score) const = 0;
virtual void Print(int iter, const score_t* score, score_t& loss) const = 0;
/*!
* \brief Create object of metrics
......@@ -40,6 +40,9 @@ public:
* \param config Config for metric
*/
static Metric* CreateMetric(const std::string& type, const MetricConfig& config);
bool the_bigger_the_better = false;
int early_stopping_round_ = 0;
};
/*!
......
......@@ -22,6 +22,7 @@ GBDT::GBDT(const BoostingConfig* config)
out_of_bag_data_indices_(nullptr), bag_data_indices_(nullptr) {
max_feature_idx_ = 0;
gbdt_config_ = dynamic_cast<const GBDTConfig*>(config);
early_stopping_round_ = gbdt_config_->early_stopping_round;
}
GBDT::~GBDT() {
......@@ -92,8 +93,12 @@ void GBDT::AddDataset(const Dataset* valid_data,
// for a validation dataset, we need its score and metric
valid_score_updater_.push_back(new ScoreUpdater(valid_data));
valid_metrics_.emplace_back();
best_iter_.emplace_back();
best_score_.emplace_back();
for (const auto& metric : valid_metrics) {
valid_metrics_.back().push_back(metric);
best_iter_.back().push_back(0);
best_score_.back().push_back(-1);
}
}
......@@ -180,7 +185,7 @@ void GBDT::Train() {
UpdateScore(new_tree);
UpdateScoreOutOfBag(new_tree);
// print message for metric
OutputMetric(iter + 1);
if (OutputMetric(iter + 1)) return;
// add model
models_.push_back(new_tree);
// save model to file per iteration
......@@ -209,17 +214,32 @@ void GBDT::UpdateScore(const Tree* tree) {
}
}
void GBDT::OutputMetric(int iter) {
bool GBDT::OutputMetric(int iter) {
score_t train_score_ = 0, test_score_ = 0;
bool ret = false;
// print training metric
for (auto& sub_metric : training_metrics_) {
sub_metric->Print(iter, train_score_updater_->score());
sub_metric->Print(iter, train_score_updater_->score(), train_score_);
}
// print validation metric
for (size_t i = 0; i < valid_metrics_.size(); ++i) {
for (auto& sub_metric : valid_metrics_[i]) {
sub_metric->Print(iter, valid_score_updater_[i]->score());
for (size_t j = 0; j < valid_metrics_[i].size(); ++j) {
valid_metrics_[i][j]->Print(iter, valid_score_updater_[i]->score(), test_score_);
if (!ret && early_stopping_round_ > 0){
bool the_bigger_the_better_ = valid_metrics_[i][j]->the_bigger_the_better;
if (best_score_[i][j] < 0
|| (!the_bigger_the_better_ && test_score_ < best_score_[i][j])
|| ( the_bigger_the_better_ && test_score_ > best_score_[i][j])){
best_score_[i][j] = test_score_;
best_iter_[i][j] = iter;
}
else {
if (iter - best_iter_[i][j] >= early_stopping_round_) ret = true;
}
}
}
}
return ret;
}
void GBDT::Boosting() {
......@@ -303,7 +323,7 @@ void GBDT::ModelsFromString(const std::string& model_str, int num_used_model) {
}
}
Log::Stdout("Loaded %d modles\n", models_.size());
Log::Stdout("Loaded %d models\n", models_.size());
}
double GBDT::PredictRaw(const double* value) const {
......
......@@ -110,8 +110,10 @@ private:
* \brief Print Metric result of current iteration
* \param iter Current iteration
*/
void OutputMetric(int iter);
bool OutputMetric(int iter);
int early_stopping_round_;
/*! \brief Pointer to training data */
const Dataset* train_data_;
/*! \brief Config of gbdt */
......@@ -128,6 +130,9 @@ private:
std::vector<ScoreUpdater*> valid_score_updater_;
/*! \brief Metric for validation data */
std::vector<std::vector<const Metric*>> valid_metrics_;
/*! \brief Best score(s) for early stopping */
std::vector<std::vector<int>> best_iter_;
std::vector<std::vector<score_t>> best_score_;
/*! \brief Trained models(trees) */
std::vector<Tree*> models_;
/*! \brief Max feature index of training data*/
......
......@@ -167,6 +167,7 @@ void ObjectiveConfig::Set(const std::unordered_map<std::string, std::string>& pa
void MetricConfig::Set(const std::unordered_map<std::string, std::string>& params) {
GetInt(params, "early_stopping_round", &early_stopping_round);
GetInt(params, "metric_freq", &output_freq);
CHECK(output_freq >= 0);
GetDouble(params, "sigmoid", &sigmoid);
......@@ -219,6 +220,8 @@ void BoostingConfig::Set(const std::unordered_map<std::string, std::string>& par
CHECK(bagging_fraction > 0.0 && bagging_fraction <= 1.0);
GetDouble(params, "learning_rate", &learning_rate);
CHECK(learning_rate > 0.0);
GetInt(params, "early_stopping_round", &early_stopping_round);
CHECK(early_stopping_round >= 0);
}
void GBDTConfig::GetTreeLearnerType(const std::unordered_map<std::string, std::string>& params) {
......
......@@ -209,7 +209,7 @@ void Dataset::ConstructBinMappers(int rank, int num_machines, const std::vector<
num_data_, is_enable_sparse_));
} else {
// if feature is trival(only 1 bin), free spaces
Log::Stdout("Warning: feture %d only contains one value, will ignore it", i);
Log::Stdout("Warning: feature %d only contains one value, will ignore it", i);
delete bin_mappers[i];
}
}
......
......@@ -18,7 +18,9 @@ template<typename PointWiseLossCalculator>
class BinaryMetric: public Metric {
public:
explicit BinaryMetric(const MetricConfig& config) {
early_stopping_round_ = config.early_stopping_round;
output_freq_ = config.output_freq;
the_bigger_the_better = false;
sigmoid_ = static_cast<score_t>(config.sigmoid);
if (sigmoid_ <= 0.0f) {
Log::Stderr("sigmoid param %f should greater than zero", sigmoid_);
......@@ -48,9 +50,9 @@ public:
}
}
void Print(int iter, const score_t* score) const override {
void Print(int iter, const score_t* score, score_t& loss) const override {
score_t sum_loss = 0.0f;
if (output_freq_ > 0 && iter % output_freq_ == 0) {
if (early_stopping_round_ > 0 || output_freq_ > 0 && iter % output_freq_ == 0) {
if (weights_ == nullptr) {
#pragma omp parallel for schedule(static) reduction(+:sum_loss)
for (data_size_t i = 0; i < num_data_; ++i) {
......@@ -68,7 +70,10 @@ public:
sum_loss += PointWiseLossCalculator::LossOnPoint(label_[i], prob) * weights_[i];
}
}
Log::Stdout("Iteration:%d, %s's %s: %f", iter, name, PointWiseLossCalculator::Name(), sum_loss / sum_weights_);
loss = sum_loss / sum_weights_;
if (output_freq_ > 0 && iter % output_freq_ == 0){
Log::Stdout("Iteration:%d, %s's %s: %f", iter, name, PointWiseLossCalculator::Name(), loss);
}
}
}
......@@ -139,7 +144,9 @@ public:
class AUCMetric: public Metric {
public:
// Construct from config: copies the early-stopping window (0 disables early
// stopping) and the output frequency for metric printing.
explicit AUCMetric(const MetricConfig& config) {
early_stopping_round_ = config.early_stopping_round;
output_freq_ = config.output_freq;
// AUC is a score, not a loss: a larger value means a better model, so the
// early-stopping logic in GBDT::OutputMetric maximizes it.
the_bigger_the_better = true;
}
virtual ~AUCMetric() {
......@@ -163,8 +170,8 @@ public:
}
}
void Print(int iter, const score_t* score) const override {
if (output_freq_ > 0 && iter % output_freq_ == 0) {
void Print(int iter, const score_t* score, score_t& loss) const override {
if (early_stopping_round_ > 0 || output_freq_ > 0 && iter % output_freq_ == 0) {
// get indices sorted by score, descent order
std::vector<data_size_t> sorted_idx;
for (data_size_t i = 0; i < num_data_; ++i) {
......@@ -220,7 +227,10 @@ public:
if (sum_pos > 0.0f && sum_pos != sum_weights_) {
auc = accum / (sum_pos *(sum_weights_ - sum_pos));
}
Log::Stdout("iteration:%d, %s's %s: %f", iter, name, "auc", auc);
loss = auc;
if (output_freq_ > 0 && iter % output_freq_ == 0){
Log::Stdout("iteration:%d, %s's %s: %f", iter, name, "auc", loss);
}
}
}
......
......@@ -16,7 +16,9 @@ namespace LightGBM {
class NDCGMetric:public Metric {
public:
explicit NDCGMetric(const MetricConfig& config) {
early_stopping_round_ = config.early_stopping_round;
output_freq_ = config.output_freq;
the_bigger_the_better = true;
// get eval position
for (auto k : config.eval_at) {
eval_at_.push_back(static_cast<data_size_t>(k));
......@@ -73,8 +75,8 @@ public:
}
}
void Print(int iter, const score_t* score) const override {
if (output_freq_ > 0 && iter % output_freq_ == 0) {
void Print(int iter, const score_t* score, score_t& loss) const override {
if (early_stopping_round_ > 0 || output_freq_ > 0 && iter % output_freq_ == 0) {
// some buffers for multi-threading sum up
std::vector<std::vector<double>> result_buffer_;
for (int i = 0; i < num_threads_; ++i) {
......@@ -132,7 +134,10 @@ public:
result[j] /= sum_query_weights_;
result_ss << "NDCG@" << eval_at_[j] << ":" << result[j] << "\t";
}
Log::Stdout("Iteration:%d, Test:%s, %s ", iter, name, result_ss.str().c_str());
loss = result[0];
if (output_freq_ > 0 && iter % output_freq_ == 0){
Log::Stdout("Iteration:%d, Test:%s, %s ", iter, name, result_ss.str().c_str());
}
}
}
......
......@@ -16,7 +16,9 @@ template<typename PointWiseLossCalculator>
class RegressionMetric: public Metric {
public:
// Construct from config: copies the early-stopping window (0 disables early
// stopping) and the output frequency for metric printing.
explicit RegressionMetric(const MetricConfig& config) {
early_stopping_round_ = config.early_stopping_round;
output_freq_ = config.output_freq;
// Point-wise regression metrics are losses: a smaller value means a better
// model, so the early-stopping logic in GBDT::OutputMetric minimizes them.
the_bigger_the_better = false;
}
virtual ~RegressionMetric() {
......@@ -39,9 +41,9 @@ public:
}
}
}
void Print(int iter, const score_t* score) const override {
if (output_freq_ > 0 && iter % output_freq_ == 0) {
void Print(int iter, const score_t* score, score_t& loss) const override {
if (early_stopping_round_ > 0 || output_freq_ > 0 && iter % output_freq_ == 0) {
score_t sum_loss = 0.0;
if (weights_ == nullptr) {
#pragma omp parallel for schedule(static) reduction(+:sum_loss)
......@@ -56,7 +58,10 @@ public:
sum_loss += PointWiseLossCalculator::LossOnPoint(label_[i], score[i]) * weights_[i];
}
}
Log::Stdout("Iteration:%d, %s's %s : %f", iter, name, PointWiseLossCalculator::Name(), PointWiseLossCalculator::AverageLoss(sum_loss, sum_weights_));
loss = PointWiseLossCalculator::AverageLoss(sum_loss, sum_weights_);
if (output_freq_ > 0 && iter % output_freq_ == 0){
Log::Stdout("Iteration:%d, %s's %s : %f", iter, name, PointWiseLossCalculator::Name(), loss);
}
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment