Commit 7aaba32e authored by wxchan's avatar wxchan Committed by Guolin Ke
Browse files

refine early-stopping feature (#27)

* change Print to PrintAndGetLoss: return loss
parent 563e1464
......@@ -32,7 +32,7 @@ public:
* \param iter Current iteration
* \param score Current prediction score
*/
virtual void Print(int iter, const score_t* score, score_t& loss) const = 0;
virtual score_t PrintAndGetLoss(int iter, const score_t* score) const = 0;
/*!
* \brief Create object of metrics
......
......@@ -239,16 +239,15 @@ void GBDT::UpdateScore(const Tree* tree) {
}
bool GBDT::OutputMetric(int iter) {
score_t train_score_ = 0, test_score_ = 0;
bool ret = false;
// print training metric
for (auto& sub_metric : training_metrics_) {
sub_metric->Print(iter, train_score_updater_->score(), train_score_);
sub_metric->PrintAndGetLoss(iter, train_score_updater_->score());
}
// print validation metric
for (size_t i = 0; i < valid_metrics_.size(); ++i) {
for (size_t j = 0; j < valid_metrics_[i].size(); ++j) {
valid_metrics_[i][j]->Print(iter, valid_score_updater_[i]->score(), test_score_);
score_t test_score_ = valid_metrics_[i][j]->PrintAndGetLoss(iter, valid_score_updater_[i]->score());
if (!ret && early_stopping_round_ > 0){
bool the_bigger_the_better_ = valid_metrics_[i][j]->the_bigger_the_better;
if (best_score_[i][j] < 0
......
......@@ -50,7 +50,7 @@ public:
}
}
void Print(int iter, const score_t* score, score_t& loss) const override {
score_t PrintAndGetLoss(int iter, const score_t* score) const override {
score_t sum_loss = 0.0f;
if (early_stopping_round_ > 0 || (output_freq_ > 0 && iter % output_freq_ == 0)) {
if (weights_ == nullptr) {
......@@ -70,11 +70,13 @@ public:
sum_loss += PointWiseLossCalculator::LossOnPoint(label_[i], prob) * weights_[i];
}
}
loss = sum_loss / sum_weights_;
score_t loss = sum_loss / sum_weights_;
if (output_freq_ > 0 && iter % output_freq_ == 0){
Log::Stdout("Iteration:%d, %s's %s: %f", iter, name, PointWiseLossCalculator::Name(), loss);
}
return loss;
}
return 0.0f;
}
private:
......@@ -170,7 +172,7 @@ public:
}
}
void Print(int iter, const score_t* score, score_t& loss) const override {
score_t PrintAndGetLoss(int iter, const score_t* score) const override {
if (early_stopping_round_ > 0 || (output_freq_ > 0 && iter % output_freq_ == 0)) {
// get indices sorted by score, descent order
std::vector<data_size_t> sorted_idx;
......@@ -227,11 +229,12 @@ public:
if (sum_pos > 0.0f && sum_pos != sum_weights_) {
auc = accum / (sum_pos *(sum_weights_ - sum_pos));
}
loss = auc;
if (output_freq_ > 0 && iter % output_freq_ == 0){
Log::Stdout("iteration:%d, %s's %s: %f", iter, name, "auc", loss);
Log::Stdout("iteration:%d, %s's %s: %f", iter, name, "auc", auc);
}
return auc;
}
return 0.0f;
}
private:
......
......@@ -75,7 +75,7 @@ public:
}
}
void Print(int iter, const score_t* score, score_t& loss) const override {
score_t PrintAndGetLoss(int iter, const score_t* score) const override {
if (early_stopping_round_ > 0 || (output_freq_ > 0 && iter % output_freq_ == 0)) {
// some buffers for multi-threading sum up
std::vector<std::vector<double>> result_buffer_;
......@@ -134,11 +134,12 @@ public:
result[j] /= sum_query_weights_;
result_ss << "NDCG@" << eval_at_[j] << ":" << result[j] << "\t";
}
loss = result[0];
if (output_freq_ > 0 && iter % output_freq_ == 0){
Log::Stdout("Iteration:%d, Test:%s, %s ", iter, name, result_ss.str().c_str());
}
return result[0];
}
return 0.0f;
}
private:
......
......@@ -42,7 +42,7 @@ public:
}
}
void Print(int iter, const score_t* score, score_t& loss) const override {
score_t PrintAndGetLoss(int iter, const score_t* score) const override {
if (early_stopping_round_ > 0 || (output_freq_ > 0 && iter % output_freq_ == 0)) {
score_t sum_loss = 0.0;
if (weights_ == nullptr) {
......@@ -58,11 +58,13 @@ public:
sum_loss += PointWiseLossCalculator::LossOnPoint(label_[i], score[i]) * weights_[i];
}
}
loss = PointWiseLossCalculator::AverageLoss(sum_loss, sum_weights_);
score_t loss = PointWiseLossCalculator::AverageLoss(sum_loss, sum_weights_);
if (output_freq_ > 0 && iter % output_freq_ == 0){
Log::Stdout("Iteration:%d, %s's %s : %f", iter, name, PointWiseLossCalculator::Name(), loss);
}
return loss;
}
return 0.0f;
}
inline static score_t AverageLoss(score_t sum_loss, score_t sum_weights) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment