Commit f70a0532 authored by Laurae's avatar Laurae Committed by James Lamb
Browse files

[R-package] Fix best_iter and best_score (#2159)

* Callback for NA handling

* lgb.Booster default score => NA

* lgb.cv default best score => NA

* Fix back callback

* lgb.train with booster check at the end

manual tests done: 
* With early stopping + with validation set
* With early stopping + without validation set
* Without early stopping + with validation set
* Without early stopping + without validation set

And with multiple metrics / validation sets.

* lgb.cv with booster check at the end

manual tests done: 
* With early stopping + with validation set
* With early stopping + without validation set
* Without early stopping + with validation set
* Without early stopping + without validation set

And with multiple metrics / validation sets.
parent 2459362a
...@@ -10,7 +10,7 @@ CB_ENV <- R6::R6Class( ...@@ -10,7 +10,7 @@ CB_ENV <- R6::R6Class(
eval_list = list(), eval_list = list(),
eval_err_list = list(), eval_err_list = list(),
best_iter = -1, best_iter = -1,
best_score = -1, best_score = NA,
met_early_stop = FALSE met_early_stop = FALSE
) )
) )
...@@ -360,6 +360,7 @@ cb.early.stop <- function(stopping_rounds, verbose = TRUE) { ...@@ -360,6 +360,7 @@ cb.early.stop <- function(stopping_rounds, verbose = TRUE) {
} }
} }
if (!isTRUE(env$met_early_stop) && cur_iter == env$end_iteration) { if (!isTRUE(env$met_early_stop) && cur_iter == env$end_iteration) {
# Check if model is not null # Check if model is not null
if (!is.null(env$model)) { if (!is.null(env$model)) {
......
...@@ -5,7 +5,7 @@ Booster <- R6::R6Class( ...@@ -5,7 +5,7 @@ Booster <- R6::R6Class(
public = list( public = list(
best_iter = -1, best_iter = -1,
best_score = -1, best_score = NA,
record_evals = list(), record_evals = list(),
# Finalize will free up the handles # Finalize will free up the handles
......
...@@ -4,7 +4,7 @@ CVBooster <- R6::R6Class( ...@@ -4,7 +4,7 @@ CVBooster <- R6::R6Class(
cloneable = FALSE, cloneable = FALSE,
public = list( public = list(
best_iter = -1, best_iter = -1,
best_score = -1, best_score = NA,
record_evals = list(), record_evals = list(),
boosters = list(), boosters = list(),
initialize = function(x) { initialize = function(x) {
...@@ -305,6 +305,17 @@ lgb.cv <- function(params = list(), ...@@ -305,6 +305,17 @@ lgb.cv <- function(params = list(),
if (env$met_early_stop) break if (env$met_early_stop) break
} }
if (record && is.na(env$best_score)) {
if (env$eval_list[[1]]$higher_better[1] == TRUE) {
cv_booster$best_iter <- unname(which.max(unlist(cv_booster$record_evals[[2]][[1]][[1]])))
cv_booster$best_score <- cv_booster$record_evals[[2]][[1]][[1]][[cv_booster$best_iter]]
} else {
cv_booster$best_iter <- unname(which.min(unlist(cv_booster$record_evals[[2]][[1]][[1]])))
cv_booster$best_score <- cv_booster$record_evals[[2]][[1]][[1]][[cv_booster$best_iter]]
}
}
if (reset_data) { if (reset_data) {
lapply(cv_booster$boosters, function(fd) { lapply(cv_booster$boosters, function(fd) {
# Store temporarily model data elsewhere # Store temporarily model data elsewhere
...@@ -318,6 +329,7 @@ lgb.cv <- function(params = list(), ...@@ -318,6 +329,7 @@ lgb.cv <- function(params = list(),
fd$booster$record_evals <- booster_old$record_evals fd$booster$record_evals <- booster_old$record_evals
}) })
} }
# Return booster # Return booster
return(cv_booster) return(cv_booster)
......
...@@ -268,6 +268,17 @@ lgb.train <- function(params = list(), ...@@ -268,6 +268,17 @@ lgb.train <- function(params = list(),
} }
# When early stopping is not activated, we compute the best iteration / score ourselves by selecting the first metric and the first dataset
if (record && length(valids) > 0 && is.na(env$best_score)) {
if (env$eval_list[[1]]$higher_better[1] == TRUE) {
booster$best_iter <- unname(which.max(unlist(booster$record_evals[[2]][[1]][[1]])))
booster$best_score <- booster$record_evals[[2]][[1]][[1]][[booster$best_iter]]
} else {
booster$best_iter <- unname(which.min(unlist(booster$record_evals[[2]][[1]][[1]])))
booster$best_score <- booster$record_evals[[2]][[1]][[1]][[booster$best_iter]]
}
}
# Check for booster model conversion to predictor model # Check for booster model conversion to predictor model
if (reset_data) { if (reset_data) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment