Commit f70a0532 authored by Laurae's avatar Laurae Committed by James Lamb
Browse files

[R-package] Fix best_iter and best_score (#2159)

* Callback for NA handling

* lgb.Booster default score => NA

* lgb.cv default best score => NA

* Fix back callback

* lgb.train with booster check at the end

manual tests done: 
* With early stopping + with validation set
* With early stopping + without validation set
* Without early stopping + with validation set
* Without early stopping + without validation set

And with multiple metrics / validation sets.

* lgb.cv with booster check at the end

manual tests done: 
* With early stopping + with validation set
* With early stopping + without validation set
* Without early stopping + with validation set
* Without early stopping + without validation set

And with multiple metrics / validation sets.
parent 2459362a
...@@ -10,80 +10,80 @@ CB_ENV <- R6::R6Class( ...@@ -10,80 +10,80 @@ CB_ENV <- R6::R6Class(
eval_list = list(), eval_list = list(),
eval_err_list = list(), eval_err_list = list(),
best_iter = -1, best_iter = -1,
best_score = -1, best_score = NA,
met_early_stop = FALSE met_early_stop = FALSE
) )
) )
cb.reset.parameters <- function(new_params) { cb.reset.parameters <- function(new_params) {
# Check for parameter list # Check for parameter list
if (!is.list(new_params)) { if (!is.list(new_params)) {
stop(sQuote("new_params"), " must be a list") stop(sQuote("new_params"), " must be a list")
} }
# Deparse parameter list # Deparse parameter list
pnames <- gsub("\\.", "_", names(new_params)) pnames <- gsub("\\.", "_", names(new_params))
nrounds <- NULL nrounds <- NULL
# Run some checks in the beginning # Run some checks in the beginning
init <- function(env) { init <- function(env) {
# Store boosting rounds # Store boosting rounds
nrounds <<- env$end_iteration - env$begin_iteration + 1 nrounds <<- env$end_iteration - env$begin_iteration + 1
# Check for model environment # Check for model environment
if (is.null(env$model)) { stop("Env should have a ", sQuote("model")) } if (is.null(env$model)) { stop("Env should have a ", sQuote("model")) }
# Some parameters are not allowed to be changed, # Some parameters are not allowed to be changed,
# since changing them would simply wreck some chaos # since changing them would simply wreck some chaos
not_allowed <- c("num_class", "metric", "boosting_type") not_allowed <- c("num_class", "metric", "boosting_type")
if (any(pnames %in% not_allowed)) { if (any(pnames %in% not_allowed)) {
stop("Parameters ", paste0(pnames[pnames %in% not_allowed], collapse = ", "), " cannot be changed during boosting") stop("Parameters ", paste0(pnames[pnames %in% not_allowed], collapse = ", "), " cannot be changed during boosting")
} }
# Check parameter names # Check parameter names
for (n in pnames) { for (n in pnames) {
# Set name # Set name
p <- new_params[[n]] p <- new_params[[n]]
# Check if function for parameter # Check if function for parameter
if (is.function(p)) { if (is.function(p)) {
# Check if requires at least two arguments # Check if requires at least two arguments
if (length(formals(p)) != 2) { if (length(formals(p)) != 2) {
stop("Parameter ", sQuote(n), " is a function but not of two arguments") stop("Parameter ", sQuote(n), " is a function but not of two arguments")
} }
# Check if numeric or character # Check if numeric or character
} else if (is.numeric(p) || is.character(p)) { } else if (is.numeric(p) || is.character(p)) {
# Check if length is matching # Check if length is matching
if (length(p) != nrounds) { if (length(p) != nrounds) {
stop("Length of ", sQuote(n), " has to be equal to length of ", sQuote("nrounds")) stop("Length of ", sQuote(n), " has to be equal to length of ", sQuote("nrounds"))
} }
} else { } else {
stop("Parameter ", sQuote(n), " is not a function or a vector") stop("Parameter ", sQuote(n), " is not a function or a vector")
} }
} }
} }
callback <- function(env) { callback <- function(env) {
# Check if rounds is null # Check if rounds is null
if (is.null(nrounds)) { if (is.null(nrounds)) {
init(env) init(env)
} }
# Store iteration # Store iteration
i <- env$iteration - env$begin_iteration i <- env$iteration - env$begin_iteration
# Apply list on parameters # Apply list on parameters
pars <- lapply(new_params, function(p) { pars <- lapply(new_params, function(p) {
if (is.function(p)) { if (is.function(p)) {
...@@ -91,14 +91,14 @@ cb.reset.parameters <- function(new_params) { ...@@ -91,14 +91,14 @@ cb.reset.parameters <- function(new_params) {
} }
p[i] p[i]
}) })
# To-do check pars # To-do check pars
if (!is.null(env$model)) { if (!is.null(env$model)) {
env$model$reset_parameter(pars) env$model$reset_parameter(pars)
} }
} }
attr(callback, "call") <- match.call() attr(callback, "call") <- match.call()
attr(callback, "is_pre_iteration") <- TRUE attr(callback, "is_pre_iteration") <- TRUE
attr(callback, "name") <- "cb.reset.parameters" attr(callback, "name") <- "cb.reset.parameters"
...@@ -107,327 +107,328 @@ cb.reset.parameters <- function(new_params) { ...@@ -107,327 +107,328 @@ cb.reset.parameters <- function(new_params) {
# Format the evaluation metric string # Format the evaluation metric string
format.eval.string <- function(eval_res, eval_err = NULL) { format.eval.string <- function(eval_res, eval_err = NULL) {
# Check for empty evaluation string # Check for empty evaluation string
if (is.null(eval_res) || length(eval_res) == 0) { if (is.null(eval_res) || length(eval_res) == 0) {
stop("no evaluation results") stop("no evaluation results")
} }
# Check for empty evaluation error # Check for empty evaluation error
if (!is.null(eval_err)) { if (!is.null(eval_err)) {
sprintf("%s\'s %s:%g+%g", eval_res$data_name, eval_res$name, eval_res$value, eval_err) sprintf("%s\'s %s:%g+%g", eval_res$data_name, eval_res$name, eval_res$value, eval_err)
} else { } else {
sprintf("%s\'s %s:%g", eval_res$data_name, eval_res$name, eval_res$value) sprintf("%s\'s %s:%g", eval_res$data_name, eval_res$name, eval_res$value)
} }
} }
merge.eval.string <- function(env) { merge.eval.string <- function(env) {
# Check length of evaluation list # Check length of evaluation list
if (length(env$eval_list) <= 0) { if (length(env$eval_list) <= 0) {
return("") return("")
} }
# Get evaluation # Get evaluation
msg <- list(sprintf("[%d]:", env$iteration)) msg <- list(sprintf("[%d]:", env$iteration))
# Set if evaluation error # Set if evaluation error
is_eval_err <- length(env$eval_err_list) > 0 is_eval_err <- length(env$eval_err_list) > 0
# Loop through evaluation list # Loop through evaluation list
for (j in seq_along(env$eval_list)) { for (j in seq_along(env$eval_list)) {
# Store evaluation error # Store evaluation error
eval_err <- NULL eval_err <- NULL
if (is_eval_err) { if (is_eval_err) {
eval_err <- env$eval_err_list[[j]] eval_err <- env$eval_err_list[[j]]
} }
# Set error message # Set error message
msg <- c(msg, format.eval.string(env$eval_list[[j]], eval_err)) msg <- c(msg, format.eval.string(env$eval_list[[j]], eval_err))
} }
# Return tabulated separated message # Return tabulated separated message
paste0(msg, collapse = "\t") paste0(msg, collapse = "\t")
} }
cb.print.evaluation <- function(period = 1) { cb.print.evaluation <- function(period = 1) {
# Create callback # Create callback
callback <- function(env) { callback <- function(env) {
# Check if period is at least 1 or more # Check if period is at least 1 or more
if (period > 0) { if (period > 0) {
# Store iteration # Store iteration
i <- env$iteration i <- env$iteration
# Check if iteration matches moduo # Check if iteration matches moduo
if ((i - 1) %% period == 0 || is.element(i, c(env$begin_iteration, env$end_iteration ))) { if ((i - 1) %% period == 0 || is.element(i, c(env$begin_iteration, env$end_iteration ))) {
# Merge evaluation string # Merge evaluation string
msg <- merge.eval.string(env) msg <- merge.eval.string(env)
# Check if message is existing # Check if message is existing
if (nchar(msg) > 0) { if (nchar(msg) > 0) {
cat(merge.eval.string(env), "\n") cat(merge.eval.string(env), "\n")
} }
} }
} }
} }
# Store attributes # Store attributes
attr(callback, "call") <- match.call() attr(callback, "call") <- match.call()
attr(callback, "name") <- "cb.print.evaluation" attr(callback, "name") <- "cb.print.evaluation"
# Return callback # Return callback
callback callback
} }
cb.record.evaluation <- function() { cb.record.evaluation <- function() {
# Create callback # Create callback
callback <- function(env) { callback <- function(env) {
# Return empty if empty evaluation list # Return empty if empty evaluation list
if (length(env$eval_list) <= 0) { if (length(env$eval_list) <= 0) {
return() return()
} }
# Set if evaluation error # Set if evaluation error
is_eval_err <- length(env$eval_err_list) > 0 is_eval_err <- length(env$eval_err_list) > 0
# Check length of recorded evaluation # Check length of recorded evaluation
if (length(env$model$record_evals) == 0) { if (length(env$model$record_evals) == 0) {
# Loop through each evaluation list element # Loop through each evaluation list element
for (j in seq_along(env$eval_list)) { for (j in seq_along(env$eval_list)) {
# Store names # Store names
data_name <- env$eval_list[[j]]$data_name data_name <- env$eval_list[[j]]$data_name
name <- env$eval_list[[j]]$name name <- env$eval_list[[j]]$name
env$model$record_evals$start_iter <- env$begin_iteration env$model$record_evals$start_iter <- env$begin_iteration
# Check if evaluation record exists # Check if evaluation record exists
if (is.null(env$model$record_evals[[data_name]])) { if (is.null(env$model$record_evals[[data_name]])) {
env$model$record_evals[[data_name]] <- list() env$model$record_evals[[data_name]] <- list()
} }
# Create dummy lists # Create dummy lists
env$model$record_evals[[data_name]][[name]] <- list() env$model$record_evals[[data_name]][[name]] <- list()
env$model$record_evals[[data_name]][[name]]$eval <- list() env$model$record_evals[[data_name]][[name]]$eval <- list()
env$model$record_evals[[data_name]][[name]]$eval_err <- list() env$model$record_evals[[data_name]][[name]]$eval_err <- list()
} }
} }
# Loop through each evaluation list element # Loop through each evaluation list element
for (j in seq_along(env$eval_list)) { for (j in seq_along(env$eval_list)) {
# Get evaluation data # Get evaluation data
eval_res <- env$eval_list[[j]] eval_res <- env$eval_list[[j]]
eval_err <- NULL eval_err <- NULL
if (is_eval_err) { if (is_eval_err) {
eval_err <- env$eval_err_list[[j]] eval_err <- env$eval_err_list[[j]]
} }
# Store names # Store names
data_name <- eval_res$data_name data_name <- eval_res$data_name
name <- eval_res$name name <- eval_res$name
# Store evaluation data # Store evaluation data
env$model$record_evals[[data_name]][[name]]$eval <- c(env$model$record_evals[[data_name]][[name]]$eval, eval_res$value) env$model$record_evals[[data_name]][[name]]$eval <- c(env$model$record_evals[[data_name]][[name]]$eval, eval_res$value)
env$model$record_evals[[data_name]][[name]]$eval_err <- c(env$model$record_evals[[data_name]][[name]]$eval_err, eval_err) env$model$record_evals[[data_name]][[name]]$eval_err <- c(env$model$record_evals[[data_name]][[name]]$eval_err, eval_err)
} }
} }
# Store attributes # Store attributes
attr(callback, "call") <- match.call() attr(callback, "call") <- match.call()
attr(callback, "name") <- "cb.record.evaluation" attr(callback, "name") <- "cb.record.evaluation"
# Return callback # Return callback
callback callback
} }
cb.early.stop <- function(stopping_rounds, verbose = TRUE) { cb.early.stop <- function(stopping_rounds, verbose = TRUE) {
# Initialize variables # Initialize variables
factor_to_bigger_better <- NULL factor_to_bigger_better <- NULL
best_iter <- NULL best_iter <- NULL
best_score <- NULL best_score <- NULL
best_msg <- NULL best_msg <- NULL
eval_len <- NULL eval_len <- NULL
# Initalization function # Initalization function
init <- function(env) { init <- function(env) {
# Store evaluation length # Store evaluation length
eval_len <<- length(env$eval_list) eval_len <<- length(env$eval_list)
# Early stopping cannot work without metrics # Early stopping cannot work without metrics
if (eval_len == 0) { if (eval_len == 0) {
stop("For early stopping, valids must have at least one element") stop("For early stopping, valids must have at least one element")
} }
# Check if verbose or not # Check if verbose or not
if (isTRUE(verbose)) { if (isTRUE(verbose)) {
cat("Will train until there is no improvement in ", stopping_rounds, " rounds.\n\n", sep = "") cat("Will train until there is no improvement in ", stopping_rounds, " rounds.\n\n", sep = "")
} }
# Maximization or minimization task # Maximization or minimization task
factor_to_bigger_better <<- rep.int(1.0, eval_len) factor_to_bigger_better <<- rep.int(1.0, eval_len)
best_iter <<- rep.int(-1, eval_len) best_iter <<- rep.int(-1, eval_len)
best_score <<- rep.int(-Inf, eval_len) best_score <<- rep.int(-Inf, eval_len)
best_msg <<- list() best_msg <<- list()
# Loop through evaluation elements # Loop through evaluation elements
for (i in seq_len(eval_len)) { for (i in seq_len(eval_len)) {
# Prepend message # Prepend message
best_msg <<- c(best_msg, "") best_msg <<- c(best_msg, "")
# Check if maximization or minimization # Check if maximization or minimization
if (!env$eval_list[[i]]$higher_better) { if (!env$eval_list[[i]]$higher_better) {
factor_to_bigger_better[i] <<- -1.0 factor_to_bigger_better[i] <<- -1.0
} }
} }
} }
# Create callback # Create callback
callback <- function(env, finalize = FALSE) { callback <- function(env, finalize = FALSE) {
# Check for empty evaluation # Check for empty evaluation
if (is.null(eval_len)) { if (is.null(eval_len)) {
init(env) init(env)
} }
# Store iteration # Store iteration
cur_iter <- env$iteration cur_iter <- env$iteration
# Loop through evaluation # Loop through evaluation
for (i in seq_len(eval_len)) { for (i in seq_len(eval_len)) {
# Store score # Store score
score <- env$eval_list[[i]]$value * factor_to_bigger_better[i] score <- env$eval_list[[i]]$value * factor_to_bigger_better[i]
# Check if score is better # Check if score is better
if (score > best_score[i]) { if (score > best_score[i]) {
# Store new scores # Store new scores
best_score[i] <<- score best_score[i] <<- score
best_iter[i] <<- cur_iter best_iter[i] <<- cur_iter
# Prepare to print if verbose # Prepare to print if verbose
if (verbose) { if (verbose) {
best_msg[[i]] <<- as.character(merge.eval.string(env)) best_msg[[i]] <<- as.character(merge.eval.string(env))
}
} else {
# Check if early stopping is required
if (cur_iter - best_iter[i] >= stopping_rounds) {
# Check if model is not null
if (!is.null(env$model)) {
env$model$best_score <- best_score[i]
env$model$best_iter <- best_iter[i]
} }
# Print message if verbose } else {
if (isTRUE(verbose)) {
# Check if early stopping is required
cat("Early stopping, best iteration is:", "\n") if (cur_iter - best_iter[i] >= stopping_rounds) {
cat(best_msg[[i]], "\n")
# Check if model is not null
if (!is.null(env$model)) {
env$model$best_score <- best_score[i]
env$model$best_iter <- best_iter[i]
}
# Print message if verbose
if (isTRUE(verbose)) {
cat("Early stopping, best iteration is:", "\n")
cat(best_msg[[i]], "\n")
}
# Store best iteration and stop
env$best_iter <- best_iter[i]
env$met_early_stop <- TRUE
} }
# Store best iteration and stop
env$best_iter <- best_iter[i]
env$met_early_stop <- TRUE
} }
}
if (!isTRUE(env$met_early_stop) && cur_iter == env$end_iteration) { if (!isTRUE(env$met_early_stop) && cur_iter == env$end_iteration) {
# Check if model is not null # Check if model is not null
if (!is.null(env$model)) { if (!is.null(env$model)) {
env$model$best_score <- best_score[i] env$model$best_score <- best_score[i]
env$model$best_iter <- best_iter[i] env$model$best_iter <- best_iter[i]
} }
# Print message if verbose # Print message if verbose
if (isTRUE(verbose)) { if (isTRUE(verbose)) {
cat("Did not meet early stopping, best iteration is:", "\n") cat("Did not meet early stopping, best iteration is:", "\n")
cat(best_msg[[i]], "\n") cat(best_msg[[i]], "\n")
} }
# Store best iteration and stop # Store best iteration and stop
env$best_iter <- best_iter[i] env$best_iter <- best_iter[i]
env$met_early_stop <- TRUE env$met_early_stop <- TRUE
} }
} }
} }
# Set attributes # Set attributes
attr(callback, "call") <- match.call() attr(callback, "call") <- match.call()
attr(callback, "name") <- "cb.early.stop" attr(callback, "name") <- "cb.early.stop"
# Return callback # Return callback
callback callback
} }
# Extract callback names from the list of callbacks # Extract callback names from the list of callbacks
callback.names <- function(cb_list) { unlist(lapply(cb_list, attr, "name")) } callback.names <- function(cb_list) { unlist(lapply(cb_list, attr, "name")) }
add.cb <- function(cb_list, cb) { add.cb <- function(cb_list, cb) {
# Combine two elements # Combine two elements
cb_list <- c(cb_list, cb) cb_list <- c(cb_list, cb)
# Set names of elements # Set names of elements
names(cb_list) <- callback.names(cb_list) names(cb_list) <- callback.names(cb_list)
# Check for existence # Check for existence
if ("cb.early.stop" %in% names(cb_list)) { if ("cb.early.stop" %in% names(cb_list)) {
# Concatenate existing elements # Concatenate existing elements
cb_list <- c(cb_list, cb_list["cb.early.stop"]) cb_list <- c(cb_list, cb_list["cb.early.stop"])
# Remove only the first one # Remove only the first one
cb_list["cb.early.stop"] <- NULL cb_list["cb.early.stop"] <- NULL
} }
# Return element # Return element
cb_list cb_list
} }
categorize.callbacks <- function(cb_list) { categorize.callbacks <- function(cb_list) {
# Check for pre-iteration or post-iteration # Check for pre-iteration or post-iteration
list( list(
pre_iter = Filter(function(x) { pre_iter = Filter(function(x) {
pre <- attr(x, "is_pre_iteration") pre <- attr(x, "is_pre_iteration")
!is.null(pre) && pre !is.null(pre) && pre
}, cb_list), }, cb_list),
post_iter = Filter(function(x) { post_iter = Filter(function(x) {
pre <- attr(x, "is_pre_iteration") pre <- attr(x, "is_pre_iteration")
is.null(pre) || !pre is.null(pre) || !pre
}, cb_list) }, cb_list)
) )
} }
...@@ -5,7 +5,7 @@ Booster <- R6::R6Class( ...@@ -5,7 +5,7 @@ Booster <- R6::R6Class(
public = list( public = list(
best_iter = -1, best_iter = -1,
best_score = -1, best_score = NA,
record_evals = list(), record_evals = list(),
# Finalize will free up the handles # Finalize will free up the handles
......
...@@ -4,7 +4,7 @@ CVBooster <- R6::R6Class( ...@@ -4,7 +4,7 @@ CVBooster <- R6::R6Class(
cloneable = FALSE, cloneable = FALSE,
public = list( public = list(
best_iter = -1, best_iter = -1,
best_score = -1, best_score = NA,
record_evals = list(), record_evals = list(),
boosters = list(), boosters = list(),
initialize = function(x) { initialize = function(x) {
...@@ -90,7 +90,7 @@ lgb.cv <- function(params = list(), ...@@ -90,7 +90,7 @@ lgb.cv <- function(params = list(),
callbacks = list(), callbacks = list(),
reset_data = FALSE, reset_data = FALSE,
...) { ...) {
# Setup temporary variables # Setup temporary variables
addiction_params <- list(...) addiction_params <- list(...)
params <- append(params, addiction_params) params <- append(params, addiction_params)
...@@ -99,35 +99,35 @@ lgb.cv <- function(params = list(), ...@@ -99,35 +99,35 @@ lgb.cv <- function(params = list(),
params <- lgb.check.eval(params, eval) params <- lgb.check.eval(params, eval)
fobj <- NULL fobj <- NULL
feval <- NULL feval <- NULL
if (nrounds <= 0) { if (nrounds <= 0) {
stop("nrounds should be greater than zero") stop("nrounds should be greater than zero")
} }
# Check for objective (function or not) # Check for objective (function or not)
if (is.function(params$objective)) { if (is.function(params$objective)) {
fobj <- params$objective fobj <- params$objective
params$objective <- "NONE" params$objective <- "NONE"
} }
# Check for loss (function or not) # Check for loss (function or not)
if (is.function(eval)) { if (is.function(eval)) {
feval <- eval feval <- eval
} }
# Check for parameters # Check for parameters
lgb.check.params(params) lgb.check.params(params)
# Init predictor to empty # Init predictor to empty
predictor <- NULL predictor <- NULL
# Check for boosting from a trained model # Check for boosting from a trained model
if (is.character(init_model)) { if (is.character(init_model)) {
predictor <- Predictor$new(init_model) predictor <- Predictor$new(init_model)
} else if (lgb.is.Booster(init_model)) { } else if (lgb.is.Booster(init_model)) {
predictor <- init_model$to_predictor() predictor <- init_model$to_predictor()
} }
# Set the iteration to start from / end to (and check for boosting from a trained model, again) # Set the iteration to start from / end to (and check for boosting from a trained model, again)
begin_iteration <- 1 begin_iteration <- 1
if (!is.null(predictor)) { if (!is.null(predictor)) {
...@@ -140,7 +140,7 @@ lgb.cv <- function(params = list(), ...@@ -140,7 +140,7 @@ lgb.cv <- function(params = list(),
} else { } else {
end_iteration <- begin_iteration + nrounds - 1 end_iteration <- begin_iteration + nrounds - 1
} }
# Check for training dataset type correctness # Check for training dataset type correctness
if (!lgb.is.Dataset(data)) { if (!lgb.is.Dataset(data)) {
if (is.null(label)) { if (is.null(label)) {
...@@ -148,49 +148,49 @@ lgb.cv <- function(params = list(), ...@@ -148,49 +148,49 @@ lgb.cv <- function(params = list(),
} }
data <- lgb.Dataset(data, label = label) data <- lgb.Dataset(data, label = label)
} }
# Check for weights # Check for weights
if (!is.null(weight)) { if (!is.null(weight)) {
data$setinfo("weight", weight) data$setinfo("weight", weight)
} }
# Update parameters with parsed parameters # Update parameters with parsed parameters
data$update_params(params) data$update_params(params)
# Create the predictor set # Create the predictor set
data$.__enclos_env__$private$set_predictor(predictor) data$.__enclos_env__$private$set_predictor(predictor)
# Write column names # Write column names
if (!is.null(colnames)) { if (!is.null(colnames)) {
data$set_colnames(colnames) data$set_colnames(colnames)
} }
# Write categorical features # Write categorical features
if (!is.null(categorical_feature)) { if (!is.null(categorical_feature)) {
data$set_categorical_feature(categorical_feature) data$set_categorical_feature(categorical_feature)
} }
# Construct datasets, if needed # Construct datasets, if needed
data$construct() data$construct()
# Check for folds # Check for folds
if (!is.null(folds)) { if (!is.null(folds)) {
# Check for list of folds or for single value # Check for list of folds or for single value
if (!is.list(folds) || length(folds) < 2) { if (!is.list(folds) || length(folds) < 2) {
stop(sQuote("folds"), " must be a list with 2 or more elements that are vectors of indices for each CV-fold") stop(sQuote("folds"), " must be a list with 2 or more elements that are vectors of indices for each CV-fold")
} }
# Set number of folds # Set number of folds
nfold <- length(folds) nfold <- length(folds)
} else { } else {
# Check fold value # Check fold value
if (nfold <= 1) { if (nfold <= 1) {
stop(sQuote("nfold"), " must be > 1") stop(sQuote("nfold"), " must be > 1")
} }
# Create folds # Create folds
folds <- generate.cv.folds(nfold, folds <- generate.cv.folds(nfold,
nrow(data), nrow(data),
...@@ -198,19 +198,19 @@ lgb.cv <- function(params = list(), ...@@ -198,19 +198,19 @@ lgb.cv <- function(params = list(),
getinfo(data, "label"), getinfo(data, "label"),
getinfo(data, "group"), getinfo(data, "group"),
params) params)
} }
# Add printing log callback # Add printing log callback
if (verbose > 0 && eval_freq > 0) { if (verbose > 0 && eval_freq > 0) {
callbacks <- add.cb(callbacks, cb.print.evaluation(eval_freq)) callbacks <- add.cb(callbacks, cb.print.evaluation(eval_freq))
} }
# Add evaluation log callback # Add evaluation log callback
if (record) { if (record) {
callbacks <- add.cb(callbacks, cb.record.evaluation()) callbacks <- add.cb(callbacks, cb.record.evaluation())
} }
# Check for early stopping passed as parameter when adding early stopping callback # Check for early stopping passed as parameter when adding early stopping callback
early_stop <- c("early_stopping_round", "early_stopping_rounds", "early_stopping") early_stop <- c("early_stopping_round", "early_stopping_rounds", "early_stopping")
if (any(names(params) %in% early_stop)) { if (any(names(params) %in% early_stop)) {
...@@ -224,10 +224,10 @@ lgb.cv <- function(params = list(), ...@@ -224,10 +224,10 @@ lgb.cv <- function(params = list(),
} }
} }
} }
# Categorize callbacks # Categorize callbacks
cb <- categorize.callbacks(callbacks) cb <- categorize.callbacks(callbacks)
# Construct booster using a list apply, check if requires group or not # Construct booster using a list apply, check if requires group or not
if (!is.list(folds[[1]])) { if (!is.list(folds[[1]])) {
bst_folds <- lapply(seq_along(folds), function(k) { bst_folds <- lapply(seq_along(folds), function(k) {
...@@ -256,55 +256,66 @@ lgb.cv <- function(params = list(), ...@@ -256,55 +256,66 @@ lgb.cv <- function(params = list(),
list(booster = booster) list(booster = booster)
}) })
} }
# Create new booster # Create new booster
cv_booster <- CVBooster$new(bst_folds) cv_booster <- CVBooster$new(bst_folds)
# Callback env # Callback env
env <- CB_ENV$new() env <- CB_ENV$new()
env$model <- cv_booster env$model <- cv_booster
env$begin_iteration <- begin_iteration env$begin_iteration <- begin_iteration
env$end_iteration <- end_iteration env$end_iteration <- end_iteration
# Start training model using number of iterations to start and end with # Start training model using number of iterations to start and end with
for (i in seq.int(from = begin_iteration, to = end_iteration)) { for (i in seq.int(from = begin_iteration, to = end_iteration)) {
# Overwrite iteration in environment # Overwrite iteration in environment
env$iteration <- i env$iteration <- i
env$eval_list <- list() env$eval_list <- list()
# Loop through "pre_iter" element # Loop through "pre_iter" element
for (f in cb$pre_iter) { for (f in cb$pre_iter) {
f(env) f(env)
} }
# Update one boosting iteration # Update one boosting iteration
msg <- lapply(cv_booster$boosters, function(fd) { msg <- lapply(cv_booster$boosters, function(fd) {
fd$booster$update(fobj = fobj) fd$booster$update(fobj = fobj)
fd$booster$eval_valid(feval = feval) fd$booster$eval_valid(feval = feval)
}) })
# Prepare collection of evaluation results # Prepare collection of evaluation results
merged_msg <- lgb.merge.cv.result(msg) merged_msg <- lgb.merge.cv.result(msg)
# Write evaluation result in environment # Write evaluation result in environment
env$eval_list <- merged_msg$eval_list env$eval_list <- merged_msg$eval_list
# Check for standard deviation requirement # Check for standard deviation requirement
if(showsd) { if(showsd) {
env$eval_err_list <- merged_msg$eval_err_list env$eval_err_list <- merged_msg$eval_err_list
} }
# Loop through env # Loop through env
for (f in cb$post_iter) { for (f in cb$post_iter) {
f(env) f(env)
} }
# Check for early stopping and break if needed # Check for early stopping and break if needed
if (env$met_early_stop) break if (env$met_early_stop) break
} }
if (record && is.na(env$best_score)) {
if (env$eval_list[[1]]$higher_better[1] == TRUE) {
cv_booster$best_iter <- unname(which.max(unlist(cv_booster$record_evals[[2]][[1]][[1]])))
cv_booster$best_score <- cv_booster$record_evals[[2]][[1]][[1]][[cv_booster$best_iter]]
} else {
cv_booster$best_iter <- unname(which.min(unlist(cv_booster$record_evals[[2]][[1]][[1]])))
cv_booster$best_score <- cv_booster$record_evals[[2]][[1]][[1]][[cv_booster$best_iter]]
}
}
if (reset_data) { if (reset_data) {
lapply(cv_booster$boosters, function(fd) { lapply(cv_booster$boosters, function(fd) {
# Store temporarily model data elsewhere # Store temporarily model data elsewhere
...@@ -318,57 +329,58 @@ lgb.cv <- function(params = list(), ...@@ -318,57 +329,58 @@ lgb.cv <- function(params = list(),
fd$booster$record_evals <- booster_old$record_evals fd$booster$record_evals <- booster_old$record_evals
}) })
} }
# Return booster # Return booster
return(cv_booster) return(cv_booster)
} }
# Generates random (stratified if needed) CV folds # Generates random (stratified if needed) CV folds
generate.cv.folds <- function(nfold, nrows, stratified, label, group, params) { generate.cv.folds <- function(nfold, nrows, stratified, label, group, params) {
# Check for group existence # Check for group existence
if (is.null(group)) { if (is.null(group)) {
# Shuffle # Shuffle
rnd_idx <- sample.int(nrows) rnd_idx <- sample.int(nrows)
# Request stratified folds # Request stratified folds
if (isTRUE(stratified) && params$objective %in% c("binary", "multiclass") && length(label) == length(rnd_idx)) { if (isTRUE(stratified) && params$objective %in% c("binary", "multiclass") && length(label) == length(rnd_idx)) {
y <- label[rnd_idx] y <- label[rnd_idx]
y <- factor(y) y <- factor(y)
folds <- lgb.stratified.folds(y, nfold) folds <- lgb.stratified.folds(y, nfold)
} else { } else {
# Make simple non-stratified folds # Make simple non-stratified folds
folds <- list() folds <- list()
# Loop through each fold # Loop through each fold
for (i in seq_len(nfold)) { for (i in seq_len(nfold)) {
kstep <- length(rnd_idx) %/% (nfold - i + 1) kstep <- length(rnd_idx) %/% (nfold - i + 1)
folds[[i]] <- rnd_idx[seq_len(kstep)] folds[[i]] <- rnd_idx[seq_len(kstep)]
rnd_idx <- rnd_idx[-seq_len(kstep)] rnd_idx <- rnd_idx[-seq_len(kstep)]
} }
} }
} else { } else {
# When doing group, stratified is not possible (only random selection) # When doing group, stratified is not possible (only random selection)
if (nfold > length(group)) { if (nfold > length(group)) {
stop("\n\tYou requested too many folds for the number of available groups.\n") stop("\n\tYou requested too many folds for the number of available groups.\n")
} }
# Degroup the groups # Degroup the groups
ungrouped <- inverse.rle(list(lengths = group, values = seq_along(group))) ungrouped <- inverse.rle(list(lengths = group, values = seq_along(group)))
# Can't stratify, shuffle # Can't stratify, shuffle
rnd_idx <- sample.int(length(group)) rnd_idx <- sample.int(length(group))
# Make simple non-stratified folds # Make simple non-stratified folds
folds <- list() folds <- list()
# Loop through each fold # Loop through each fold
for (i in seq_len(nfold)) { for (i in seq_len(nfold)) {
kstep <- length(rnd_idx) %/% (nfold - i + 1) kstep <- length(rnd_idx) %/% (nfold - i + 1)
...@@ -376,12 +388,12 @@ generate.cv.folds <- function(nfold, nrows, stratified, label, group, params) { ...@@ -376,12 +388,12 @@ generate.cv.folds <- function(nfold, nrows, stratified, label, group, params) {
group = rnd_idx[seq_len(kstep)]) group = rnd_idx[seq_len(kstep)])
rnd_idx <- rnd_idx[-seq_len(kstep)] rnd_idx <- rnd_idx[-seq_len(kstep)]
} }
} }
# Return folds # Return folds
return(folds) return(folds)
} }
# Creates CV folds stratified by the values of y. # Creates CV folds stratified by the values of y.
...@@ -389,7 +401,7 @@ generate.cv.folds <- function(nfold, nrows, stratified, label, group, params) { ...@@ -389,7 +401,7 @@ generate.cv.folds <- function(nfold, nrows, stratified, label, group, params) {
# by always returning an unnamed list of fold indices. # by always returning an unnamed list of fold indices.
#' @importFrom stats quantile #' @importFrom stats quantile
lgb.stratified.folds <- function(y, k = 10) { lgb.stratified.folds <- function(y, k = 10) {
## Group the numeric data based on their magnitudes ## Group the numeric data based on their magnitudes
## and sample within those groups. ## and sample within those groups.
## When the number of samples is low, we may have ## When the number of samples is low, we may have
...@@ -399,51 +411,51 @@ lgb.stratified.folds <- function(y, k = 10) { ...@@ -399,51 +411,51 @@ lgb.stratified.folds <- function(y, k = 10) {
## At most, we will use quantiles. If the sample ## At most, we will use quantiles. If the sample
## is too small, we just do regular unstratified CV ## is too small, we just do regular unstratified CV
if (is.numeric(y)) { if (is.numeric(y)) {
cuts <- length(y) %/% k cuts <- length(y) %/% k
if (cuts < 2) { cuts <- 2 } if (cuts < 2) { cuts <- 2 }
if (cuts > 5) { cuts <- 5 } if (cuts > 5) { cuts <- 5 }
y <- cut(y, y <- cut(y,
unique(stats::quantile(y, probs = seq.int(0, 1, length.out = cuts))), unique(stats::quantile(y, probs = seq.int(0, 1, length.out = cuts))),
include.lowest = TRUE) include.lowest = TRUE)
} }
if (k < length(y)) { if (k < length(y)) {
## Reset levels so that the possible levels and ## Reset levels so that the possible levels and
## the levels in the vector are the same ## the levels in the vector are the same
y <- factor(as.character(y)) y <- factor(as.character(y))
numInClass <- table(y) numInClass <- table(y)
foldVector <- vector(mode = "integer", length(y)) foldVector <- vector(mode = "integer", length(y))
## For each class, balance the fold allocation as far ## For each class, balance the fold allocation as far
## as possible, then resample the remainder. ## as possible, then resample the remainder.
## The final assignment of folds is also randomized. ## The final assignment of folds is also randomized.
for (i in seq_along(numInClass)) { for (i in seq_along(numInClass)) {
## Create a vector of integers from 1:k as many times as possible without ## Create a vector of integers from 1:k as many times as possible without
## going over the number of samples in the class. Note that if the number ## going over the number of samples in the class. Note that if the number
## of samples in a class is less than k, nothing is producd here. ## of samples in a class is less than k, nothing is producd here.
seqVector <- rep(seq_len(k), numInClass[i] %/% k) seqVector <- rep(seq_len(k), numInClass[i] %/% k)
## Add enough random integers to get length(seqVector) == numInClass[i] ## Add enough random integers to get length(seqVector) == numInClass[i]
if (numInClass[i] %% k > 0) { if (numInClass[i] %% k > 0) {
seqVector <- c(seqVector, sample.int(k, numInClass[i] %% k)) seqVector <- c(seqVector, sample.int(k, numInClass[i] %% k))
} }
## Shuffle the integers for fold assignment and assign to this classes's data ## Shuffle the integers for fold assignment and assign to this classes's data
foldVector[y == dimnames(numInClass)$y[i]] <- sample(seqVector) foldVector[y == dimnames(numInClass)$y[i]] <- sample(seqVector)
} }
} else { } else {
foldVector <- seq(along = y) foldVector <- seq(along = y)
} }
# Return data # Return data
out <- split(seq(along = y), foldVector) out <- split(seq(along = y), foldVector)
names(out) <- NULL names(out) <- NULL
...@@ -451,53 +463,53 @@ lgb.stratified.folds <- function(y, k = 10) { ...@@ -451,53 +463,53 @@ lgb.stratified.folds <- function(y, k = 10) {
} }
lgb.merge.cv.result <- function(msg, showsd = TRUE) { lgb.merge.cv.result <- function(msg, showsd = TRUE) {
# Get CV message length # Get CV message length
if (length(msg) == 0) { if (length(msg) == 0) {
stop("lgb.cv: size of cv result error") stop("lgb.cv: size of cv result error")
} }
# Get evaluation message length # Get evaluation message length
eval_len <- length(msg[[1]]) eval_len <- length(msg[[1]])
# Is evaluation message empty? # Is evaluation message empty?
if (eval_len == 0) { if (eval_len == 0) {
stop("lgb.cv: should provide at least one metric for CV") stop("lgb.cv: should provide at least one metric for CV")
} }
# Get evaluation results using a list apply # Get evaluation results using a list apply
eval_result <- lapply(seq_len(eval_len), function(j) { eval_result <- lapply(seq_len(eval_len), function(j) {
as.numeric(lapply(seq_along(msg), function(i) { as.numeric(lapply(seq_along(msg), function(i) {
msg[[i]][[j]]$value })) msg[[i]][[j]]$value }))
}) })
# Get evaluation # Get evaluation
ret_eval <- msg[[1]] ret_eval <- msg[[1]]
# Go through evaluation length items # Go through evaluation length items
for (j in seq_len(eval_len)) { for (j in seq_len(eval_len)) {
ret_eval[[j]]$value <- mean(eval_result[[j]]) ret_eval[[j]]$value <- mean(eval_result[[j]])
} }
# Preinit evaluation error # Preinit evaluation error
ret_eval_err <- NULL ret_eval_err <- NULL
# Check for standard deviation # Check for standard deviation
if (showsd) { if (showsd) {
# Parse standard deviation # Parse standard deviation
for (j in seq_len(eval_len)) { for (j in seq_len(eval_len)) {
ret_eval_err <- c(ret_eval_err, ret_eval_err <- c(ret_eval_err,
sqrt(mean(eval_result[[j]] ^ 2) - mean(eval_result[[j]]) ^ 2)) sqrt(mean(eval_result[[j]] ^ 2) - mean(eval_result[[j]]) ^ 2))
} }
# Convert to list # Convert to list
ret_eval_err <- as.list(ret_eval_err) ret_eval_err <- as.list(ret_eval_err)
} }
# Return errors # Return errors
list(eval_list = ret_eval, list(eval_list = ret_eval,
eval_err_list = ret_eval_err) eval_err_list = ret_eval_err)
} }
...@@ -62,7 +62,7 @@ lgb.train <- function(params = list(), ...@@ -62,7 +62,7 @@ lgb.train <- function(params = list(),
callbacks = list(), callbacks = list(),
reset_data = FALSE, reset_data = FALSE,
...) { ...) {
# Setup temporary variables # Setup temporary variables
additional_params <- list(...) additional_params <- list(...)
params <- append(params, additional_params) params <- append(params, additional_params)
...@@ -71,35 +71,35 @@ lgb.train <- function(params = list(), ...@@ -71,35 +71,35 @@ lgb.train <- function(params = list(),
params <- lgb.check.eval(params, eval) params <- lgb.check.eval(params, eval)
fobj <- NULL fobj <- NULL
feval <- NULL feval <- NULL
if (nrounds <= 0) { if (nrounds <= 0) {
stop("nrounds should be greater than zero") stop("nrounds should be greater than zero")
} }
# Check for objective (function or not) # Check for objective (function or not)
if (is.function(params$objective)) { if (is.function(params$objective)) {
fobj <- params$objective fobj <- params$objective
params$objective <- "NONE" params$objective <- "NONE"
} }
# Check for loss (function or not) # Check for loss (function or not)
if (is.function(eval)) { if (is.function(eval)) {
feval <- eval feval <- eval
} }
# Check for parameters # Check for parameters
lgb.check.params(params) lgb.check.params(params)
# Init predictor to empty # Init predictor to empty
predictor <- NULL predictor <- NULL
# Check for boosting from a trained model # Check for boosting from a trained model
if (is.character(init_model)) { if (is.character(init_model)) {
predictor <- Predictor$new(init_model) predictor <- Predictor$new(init_model)
} else if (lgb.is.Booster(init_model)) { } else if (lgb.is.Booster(init_model)) {
predictor <- init_model$to_predictor() predictor <- init_model$to_predictor()
} }
# Set the iteration to start from / end to (and check for boosting from a trained model, again) # Set the iteration to start from / end to (and check for boosting from a trained model, again)
begin_iteration <- 1 begin_iteration <- 1
if (!is.null(predictor)) { if (!is.null(predictor)) {
...@@ -112,89 +112,89 @@ lgb.train <- function(params = list(), ...@@ -112,89 +112,89 @@ lgb.train <- function(params = list(),
} else { } else {
end_iteration <- begin_iteration + nrounds - 1 end_iteration <- begin_iteration + nrounds - 1
} }
# Check for training dataset type correctness # Check for training dataset type correctness
if (!lgb.is.Dataset(data)) { if (!lgb.is.Dataset(data)) {
stop("lgb.train: data only accepts lgb.Dataset object") stop("lgb.train: data only accepts lgb.Dataset object")
} }
# Check for validation dataset type correctness # Check for validation dataset type correctness
if (length(valids) > 0) { if (length(valids) > 0) {
# One or more validation dataset # One or more validation dataset
# Check for list as input and type correctness by object # Check for list as input and type correctness by object
if (!is.list(valids) || !all(vapply(valids, lgb.is.Dataset, logical(1)))) { if (!is.list(valids) || !all(vapply(valids, lgb.is.Dataset, logical(1)))) {
stop("lgb.train: valids must be a list of lgb.Dataset elements") stop("lgb.train: valids must be a list of lgb.Dataset elements")
} }
# Attempt to get names # Attempt to get names
evnames <- names(valids) evnames <- names(valids)
# Check for names existance # Check for names existance
if (is.null(evnames) || !all(nzchar(evnames))) { if (is.null(evnames) || !all(nzchar(evnames))) {
stop("lgb.train: each element of the valids must have a name tag") stop("lgb.train: each element of the valids must have a name tag")
} }
} }
# Update parameters with parsed parameters # Update parameters with parsed parameters
data$update_params(params) data$update_params(params)
# Create the predictor set # Create the predictor set
data$.__enclos_env__$private$set_predictor(predictor) data$.__enclos_env__$private$set_predictor(predictor)
# Write column names # Write column names
if (!is.null(colnames)) { if (!is.null(colnames)) {
data$set_colnames(colnames) data$set_colnames(colnames)
} }
# Write categorical features # Write categorical features
if (!is.null(categorical_feature)) { if (!is.null(categorical_feature)) {
data$set_categorical_feature(categorical_feature) data$set_categorical_feature(categorical_feature)
} }
# Construct datasets, if needed # Construct datasets, if needed
data$construct() data$construct()
vaild_contain_train <- FALSE vaild_contain_train <- FALSE
train_data_name <- "train" train_data_name <- "train"
reduced_valid_sets <- list() reduced_valid_sets <- list()
# Parse validation datasets # Parse validation datasets
if (length(valids) > 0) { if (length(valids) > 0) {
# Loop through all validation datasets using name # Loop through all validation datasets using name
for (key in names(valids)) { for (key in names(valids)) {
# Use names to get validation datasets # Use names to get validation datasets
valid_data <- valids[[key]] valid_data <- valids[[key]]
# Check for duplicate train/validation dataset # Check for duplicate train/validation dataset
if (identical(data, valid_data)) { if (identical(data, valid_data)) {
vaild_contain_train <- TRUE vaild_contain_train <- TRUE
train_data_name <- key train_data_name <- key
next next
} }
# Update parameters, data # Update parameters, data
valid_data$update_params(params) valid_data$update_params(params)
valid_data$set_reference(data) valid_data$set_reference(data)
reduced_valid_sets[[key]] <- valid_data reduced_valid_sets[[key]] <- valid_data
} }
} }
# Add printing log callback # Add printing log callback
if (verbose > 0 && eval_freq > 0) { if (verbose > 0 && eval_freq > 0) {
callbacks <- add.cb(callbacks, cb.print.evaluation(eval_freq)) callbacks <- add.cb(callbacks, cb.print.evaluation(eval_freq))
} }
# Add evaluation log callback # Add evaluation log callback
if (record && length(valids) > 0) { if (record && length(valids) > 0) {
callbacks <- add.cb(callbacks, cb.record.evaluation()) callbacks <- add.cb(callbacks, cb.record.evaluation())
} }
# Check for early stopping passed as parameter when adding early stopping callback # Check for early stopping passed as parameter when adding early stopping callback
early_stop <- c("early_stopping_round", "early_stopping_rounds", "early_stopping") early_stop <- c("early_stopping_round", "early_stopping_rounds", "early_stopping")
if (any(names(params) %in% early_stop)) { if (any(names(params) %in% early_stop)) {
...@@ -208,83 +208,94 @@ lgb.train <- function(params = list(), ...@@ -208,83 +208,94 @@ lgb.train <- function(params = list(),
} }
} }
} }
# "Categorize" callbacks # "Categorize" callbacks
cb <- categorize.callbacks(callbacks) cb <- categorize.callbacks(callbacks)
# Construct booster with datasets # Construct booster with datasets
booster <- Booster$new(params = params, train_set = data) booster <- Booster$new(params = params, train_set = data)
if (vaild_contain_train) { booster$set_train_data_name(train_data_name) } if (vaild_contain_train) { booster$set_train_data_name(train_data_name) }
for (key in names(reduced_valid_sets)) { for (key in names(reduced_valid_sets)) {
booster$add_valid(reduced_valid_sets[[key]], key) booster$add_valid(reduced_valid_sets[[key]], key)
} }
# Callback env # Callback env
env <- CB_ENV$new() env <- CB_ENV$new()
env$model <- booster env$model <- booster
env$begin_iteration <- begin_iteration env$begin_iteration <- begin_iteration
env$end_iteration <- end_iteration env$end_iteration <- end_iteration
# Start training model using number of iterations to start and end with # Start training model using number of iterations to start and end with
for (i in seq.int(from = begin_iteration, to = end_iteration)) { for (i in seq.int(from = begin_iteration, to = end_iteration)) {
# Overwrite iteration in environment # Overwrite iteration in environment
env$iteration <- i env$iteration <- i
env$eval_list <- list() env$eval_list <- list()
# Loop through "pre_iter" element # Loop through "pre_iter" element
for (f in cb$pre_iter) { for (f in cb$pre_iter) {
f(env) f(env)
} }
# Update one boosting iteration # Update one boosting iteration
booster$update(fobj = fobj) booster$update(fobj = fobj)
# Prepare collection of evaluation results # Prepare collection of evaluation results
eval_list <- list() eval_list <- list()
# Collection: Has validation dataset? # Collection: Has validation dataset?
if (length(valids) > 0) { if (length(valids) > 0) {
# Validation has training dataset? # Validation has training dataset?
if (vaild_contain_train) { if (vaild_contain_train) {
eval_list <- append(eval_list, booster$eval_train(feval = feval)) eval_list <- append(eval_list, booster$eval_train(feval = feval))
} }
# Has no validation dataset # Has no validation dataset
eval_list <- append(eval_list, booster$eval_valid(feval = feval)) eval_list <- append(eval_list, booster$eval_valid(feval = feval))
} }
# Write evaluation result in environment # Write evaluation result in environment
env$eval_list <- eval_list env$eval_list <- eval_list
# Loop through env # Loop through env
for (f in cb$post_iter) { for (f in cb$post_iter) {
f(env) f(env)
} }
# Check for early stopping and break if needed # Check for early stopping and break if needed
if (env$met_early_stop) break if (env$met_early_stop) break
} }
# When early stopping is not activated, we compute the best iteration / score ourselves by selecting the first metric and the first dataset
if (record && length(valids) > 0 && is.na(env$best_score)) {
if (env$eval_list[[1]]$higher_better[1] == TRUE) {
booster$best_iter <- unname(which.max(unlist(booster$record_evals[[2]][[1]][[1]])))
booster$best_score <- booster$record_evals[[2]][[1]][[1]][[booster$best_iter]]
} else {
booster$best_iter <- unname(which.min(unlist(booster$record_evals[[2]][[1]][[1]])))
booster$best_score <- booster$record_evals[[2]][[1]][[1]][[booster$best_iter]]
}
}
# Check for booster model conversion to predictor model # Check for booster model conversion to predictor model
if (reset_data) { if (reset_data) {
# Store temporarily model data elsewhere # Store temporarily model data elsewhere
booster_old <- list(best_iter = booster$best_iter, booster_old <- list(best_iter = booster$best_iter,
best_score = booster$best_score, best_score = booster$best_score,
record_evals = booster$record_evals) record_evals = booster$record_evals)
# Reload model # Reload model
booster <- lgb.load(model_str = booster$save_model_to_string()) booster <- lgb.load(model_str = booster$save_model_to_string())
booster$best_iter <- booster_old$best_iter booster$best_iter <- booster_old$best_iter
booster$best_score <- booster_old$best_score booster$best_score <- booster_old$best_score
booster$record_evals <- booster_old$record_evals booster$record_evals <- booster_old$record_evals
} }
# Return booster # Return booster
return(booster) return(booster)
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment