"include/vscode:/vscode.git/clone" did not exist on "6f0d7cc2dee9f0d71450287ecdbeeecc5d43791b"
Commit f70a0532 authored by Laurae's avatar Laurae Committed by James Lamb
Browse files

[R-package] Fix best_iter and best_score (#2159)

* Callback for NA handling

* lgb.Booster default score => NA

* lgb.cv default best score => NA

* Fix back callback

* lgb.train with booster check at the end

manual tests done: 
* With early stopping + with validation set
* With early stopping + without validation set
* Without early stopping + with validation set
* Without early stopping + without validation set

And with multiple metrics / validation sets.

* lgb.cv with booster check at the end

manual tests done: 
* With early stopping + with validation set
* With early stopping + without validation set
* Without early stopping + with validation set
* Without early stopping + without validation set

And with multiple metrics / validation sets.
parent 2459362a
......@@ -10,80 +10,80 @@ CB_ENV <- R6::R6Class(
eval_list = list(),
eval_err_list = list(),
best_iter = -1,
best_score = -1,
best_score = NA,
met_early_stop = FALSE
)
)
cb.reset.parameters <- function(new_params) {
# Check for parameter list
if (!is.list(new_params)) {
stop(sQuote("new_params"), " must be a list")
}
# Deparse parameter list
pnames <- gsub("\\.", "_", names(new_params))
nrounds <- NULL
# Run some checks in the beginning
init <- function(env) {
# Store boosting rounds
nrounds <<- env$end_iteration - env$begin_iteration + 1
# Check for model environment
if (is.null(env$model)) { stop("Env should have a ", sQuote("model")) }
# Some parameters are not allowed to be changed,
# since changing them would simply wreck some chaos
not_allowed <- c("num_class", "metric", "boosting_type")
if (any(pnames %in% not_allowed)) {
stop("Parameters ", paste0(pnames[pnames %in% not_allowed], collapse = ", "), " cannot be changed during boosting")
}
# Check parameter names
for (n in pnames) {
# Set name
p <- new_params[[n]]
# Check if function for parameter
if (is.function(p)) {
# Check if requires at least two arguments
if (length(formals(p)) != 2) {
stop("Parameter ", sQuote(n), " is a function but not of two arguments")
}
# Check if numeric or character
} else if (is.numeric(p) || is.character(p)) {
# Check if length is matching
if (length(p) != nrounds) {
stop("Length of ", sQuote(n), " has to be equal to length of ", sQuote("nrounds"))
}
} else {
stop("Parameter ", sQuote(n), " is not a function or a vector")
}
}
}
callback <- function(env) {
# Check if rounds is null
if (is.null(nrounds)) {
init(env)
}
# Store iteration
i <- env$iteration - env$begin_iteration
# Apply list on parameters
pars <- lapply(new_params, function(p) {
if (is.function(p)) {
......@@ -91,14 +91,14 @@ cb.reset.parameters <- function(new_params) {
}
p[i]
})
# To-do check pars
if (!is.null(env$model)) {
env$model$reset_parameter(pars)
}
}
attr(callback, "call") <- match.call()
attr(callback, "is_pre_iteration") <- TRUE
attr(callback, "name") <- "cb.reset.parameters"
......@@ -107,327 +107,328 @@ cb.reset.parameters <- function(new_params) {
# Format the evaluation metric string
format.eval.string <- function(eval_res, eval_err = NULL) {
# Check for empty evaluation string
if (is.null(eval_res) || length(eval_res) == 0) {
stop("no evaluation results")
}
# Check for empty evaluation error
if (!is.null(eval_err)) {
sprintf("%s\'s %s:%g+%g", eval_res$data_name, eval_res$name, eval_res$value, eval_err)
} else {
sprintf("%s\'s %s:%g", eval_res$data_name, eval_res$name, eval_res$value)
}
}
merge.eval.string <- function(env) {
# Check length of evaluation list
if (length(env$eval_list) <= 0) {
return("")
}
# Get evaluation
msg <- list(sprintf("[%d]:", env$iteration))
# Set if evaluation error
is_eval_err <- length(env$eval_err_list) > 0
# Loop through evaluation list
for (j in seq_along(env$eval_list)) {
# Store evaluation error
eval_err <- NULL
if (is_eval_err) {
eval_err <- env$eval_err_list[[j]]
}
# Set error message
msg <- c(msg, format.eval.string(env$eval_list[[j]], eval_err))
}
# Return tabulated separated message
paste0(msg, collapse = "\t")
}
cb.print.evaluation <- function(period = 1) {
# Create callback
callback <- function(env) {
# Check if period is at least 1 or more
if (period > 0) {
# Store iteration
i <- env$iteration
# Check if iteration matches moduo
if ((i - 1) %% period == 0 || is.element(i, c(env$begin_iteration, env$end_iteration ))) {
# Merge evaluation string
msg <- merge.eval.string(env)
# Check if message is existing
if (nchar(msg) > 0) {
cat(merge.eval.string(env), "\n")
}
}
}
}
# Store attributes
attr(callback, "call") <- match.call()
attr(callback, "name") <- "cb.print.evaluation"
# Return callback
callback
}
cb.record.evaluation <- function() {
# Create callback
callback <- function(env) {
# Return empty if empty evaluation list
if (length(env$eval_list) <= 0) {
return()
}
# Set if evaluation error
is_eval_err <- length(env$eval_err_list) > 0
# Check length of recorded evaluation
if (length(env$model$record_evals) == 0) {
# Loop through each evaluation list element
for (j in seq_along(env$eval_list)) {
# Store names
data_name <- env$eval_list[[j]]$data_name
name <- env$eval_list[[j]]$name
env$model$record_evals$start_iter <- env$begin_iteration
# Check if evaluation record exists
if (is.null(env$model$record_evals[[data_name]])) {
env$model$record_evals[[data_name]] <- list()
}
# Create dummy lists
env$model$record_evals[[data_name]][[name]] <- list()
env$model$record_evals[[data_name]][[name]]$eval <- list()
env$model$record_evals[[data_name]][[name]]$eval_err <- list()
}
}
# Loop through each evaluation list element
for (j in seq_along(env$eval_list)) {
# Get evaluation data
eval_res <- env$eval_list[[j]]
eval_err <- NULL
if (is_eval_err) {
eval_err <- env$eval_err_list[[j]]
}
# Store names
data_name <- eval_res$data_name
name <- eval_res$name
# Store evaluation data
env$model$record_evals[[data_name]][[name]]$eval <- c(env$model$record_evals[[data_name]][[name]]$eval, eval_res$value)
env$model$record_evals[[data_name]][[name]]$eval_err <- c(env$model$record_evals[[data_name]][[name]]$eval_err, eval_err)
}
}
# Store attributes
attr(callback, "call") <- match.call()
attr(callback, "name") <- "cb.record.evaluation"
# Return callback
callback
}
cb.early.stop <- function(stopping_rounds, verbose = TRUE) {
# Initialize variables
factor_to_bigger_better <- NULL
best_iter <- NULL
best_score <- NULL
best_msg <- NULL
eval_len <- NULL
# Initalization function
init <- function(env) {
# Store evaluation length
eval_len <<- length(env$eval_list)
# Early stopping cannot work without metrics
if (eval_len == 0) {
stop("For early stopping, valids must have at least one element")
}
# Check if verbose or not
if (isTRUE(verbose)) {
cat("Will train until there is no improvement in ", stopping_rounds, " rounds.\n\n", sep = "")
}
# Maximization or minimization task
factor_to_bigger_better <<- rep.int(1.0, eval_len)
best_iter <<- rep.int(-1, eval_len)
best_score <<- rep.int(-Inf, eval_len)
best_msg <<- list()
# Loop through evaluation elements
for (i in seq_len(eval_len)) {
# Prepend message
best_msg <<- c(best_msg, "")
# Check if maximization or minimization
if (!env$eval_list[[i]]$higher_better) {
factor_to_bigger_better[i] <<- -1.0
}
}
}
# Create callback
callback <- function(env, finalize = FALSE) {
# Check for empty evaluation
if (is.null(eval_len)) {
init(env)
}
# Store iteration
cur_iter <- env$iteration
# Loop through evaluation
for (i in seq_len(eval_len)) {
# Store score
score <- env$eval_list[[i]]$value * factor_to_bigger_better[i]
# Check if score is better
if (score > best_score[i]) {
# Store new scores
best_score[i] <<- score
best_iter[i] <<- cur_iter
# Prepare to print if verbose
if (verbose) {
best_msg[[i]] <<- as.character(merge.eval.string(env))
}
} else {
# Check if early stopping is required
if (cur_iter - best_iter[i] >= stopping_rounds) {
# Check if model is not null
if (!is.null(env$model)) {
env$model$best_score <- best_score[i]
env$model$best_iter <- best_iter[i]
# Check if score is better
if (score > best_score[i]) {
# Store new scores
best_score[i] <<- score
best_iter[i] <<- cur_iter
# Prepare to print if verbose
if (verbose) {
best_msg[[i]] <<- as.character(merge.eval.string(env))
}
# Print message if verbose
if (isTRUE(verbose)) {
cat("Early stopping, best iteration is:", "\n")
cat(best_msg[[i]], "\n")
} else {
# Check if early stopping is required
if (cur_iter - best_iter[i] >= stopping_rounds) {
# Check if model is not null
if (!is.null(env$model)) {
env$model$best_score <- best_score[i]
env$model$best_iter <- best_iter[i]
}
# Print message if verbose
if (isTRUE(verbose)) {
cat("Early stopping, best iteration is:", "\n")
cat(best_msg[[i]], "\n")
}
# Store best iteration and stop
env$best_iter <- best_iter[i]
env$met_early_stop <- TRUE
}
# Store best iteration and stop
env$best_iter <- best_iter[i]
env$met_early_stop <- TRUE
}
}
if (!isTRUE(env$met_early_stop) && cur_iter == env$end_iteration) {
# Check if model is not null
if (!is.null(env$model)) {
env$model$best_score <- best_score[i]
env$model$best_iter <- best_iter[i]
}
# Print message if verbose
if (isTRUE(verbose)) {
cat("Did not meet early stopping, best iteration is:", "\n")
cat(best_msg[[i]], "\n")
}
# Store best iteration and stop
env$best_iter <- best_iter[i]
env$met_early_stop <- TRUE
}
}
}
# Set attributes
attr(callback, "call") <- match.call()
attr(callback, "name") <- "cb.early.stop"
# Return callback
callback
}
# Extract callback names from the list of callbacks
callback.names <- function(cb_list) { unlist(lapply(cb_list, attr, "name")) }
add.cb <- function(cb_list, cb) {
# Combine two elements
cb_list <- c(cb_list, cb)
# Set names of elements
names(cb_list) <- callback.names(cb_list)
# Check for existence
if ("cb.early.stop" %in% names(cb_list)) {
# Concatenate existing elements
cb_list <- c(cb_list, cb_list["cb.early.stop"])
# Remove only the first one
cb_list["cb.early.stop"] <- NULL
}
# Return element
cb_list
}
categorize.callbacks <- function(cb_list) {
# Check for pre-iteration or post-iteration
list(
pre_iter = Filter(function(x) {
pre <- attr(x, "is_pre_iteration")
!is.null(pre) && pre
}, cb_list),
pre <- attr(x, "is_pre_iteration")
!is.null(pre) && pre
}, cb_list),
post_iter = Filter(function(x) {
pre <- attr(x, "is_pre_iteration")
is.null(pre) || !pre
}, cb_list)
pre <- attr(x, "is_pre_iteration")
is.null(pre) || !pre
}, cb_list)
)
}
......@@ -5,7 +5,7 @@ Booster <- R6::R6Class(
public = list(
best_iter = -1,
best_score = -1,
best_score = NA,
record_evals = list(),
# Finalize will free up the handles
......
......@@ -4,7 +4,7 @@ CVBooster <- R6::R6Class(
cloneable = FALSE,
public = list(
best_iter = -1,
best_score = -1,
best_score = NA,
record_evals = list(),
boosters = list(),
initialize = function(x) {
......@@ -90,7 +90,7 @@ lgb.cv <- function(params = list(),
callbacks = list(),
reset_data = FALSE,
...) {
# Setup temporary variables
addiction_params <- list(...)
params <- append(params, addiction_params)
......@@ -99,35 +99,35 @@ lgb.cv <- function(params = list(),
params <- lgb.check.eval(params, eval)
fobj <- NULL
feval <- NULL
if (nrounds <= 0) {
stop("nrounds should be greater than zero")
}
# Check for objective (function or not)
if (is.function(params$objective)) {
fobj <- params$objective
params$objective <- "NONE"
}
# Check for loss (function or not)
if (is.function(eval)) {
feval <- eval
}
# Check for parameters
lgb.check.params(params)
# Init predictor to empty
predictor <- NULL
# Check for boosting from a trained model
if (is.character(init_model)) {
predictor <- Predictor$new(init_model)
} else if (lgb.is.Booster(init_model)) {
predictor <- init_model$to_predictor()
}
# Set the iteration to start from / end to (and check for boosting from a trained model, again)
begin_iteration <- 1
if (!is.null(predictor)) {
......@@ -140,7 +140,7 @@ lgb.cv <- function(params = list(),
} else {
end_iteration <- begin_iteration + nrounds - 1
}
# Check for training dataset type correctness
if (!lgb.is.Dataset(data)) {
if (is.null(label)) {
......@@ -148,49 +148,49 @@ lgb.cv <- function(params = list(),
}
data <- lgb.Dataset(data, label = label)
}
# Check for weights
if (!is.null(weight)) {
data$setinfo("weight", weight)
}
# Update parameters with parsed parameters
data$update_params(params)
# Create the predictor set
data$.__enclos_env__$private$set_predictor(predictor)
# Write column names
if (!is.null(colnames)) {
data$set_colnames(colnames)
}
# Write categorical features
if (!is.null(categorical_feature)) {
data$set_categorical_feature(categorical_feature)
}
# Construct datasets, if needed
data$construct()
# Check for folds
if (!is.null(folds)) {
# Check for list of folds or for single value
if (!is.list(folds) || length(folds) < 2) {
stop(sQuote("folds"), " must be a list with 2 or more elements that are vectors of indices for each CV-fold")
}
# Set number of folds
nfold <- length(folds)
} else {
# Check fold value
if (nfold <= 1) {
stop(sQuote("nfold"), " must be > 1")
}
# Create folds
folds <- generate.cv.folds(nfold,
nrow(data),
......@@ -198,19 +198,19 @@ lgb.cv <- function(params = list(),
getinfo(data, "label"),
getinfo(data, "group"),
params)
}
# Add printing log callback
if (verbose > 0 && eval_freq > 0) {
callbacks <- add.cb(callbacks, cb.print.evaluation(eval_freq))
}
# Add evaluation log callback
if (record) {
callbacks <- add.cb(callbacks, cb.record.evaluation())
}
# Check for early stopping passed as parameter when adding early stopping callback
early_stop <- c("early_stopping_round", "early_stopping_rounds", "early_stopping")
if (any(names(params) %in% early_stop)) {
......@@ -224,10 +224,10 @@ lgb.cv <- function(params = list(),
}
}
}
# Categorize callbacks
cb <- categorize.callbacks(callbacks)
# Construct booster using a list apply, check if requires group or not
if (!is.list(folds[[1]])) {
bst_folds <- lapply(seq_along(folds), function(k) {
......@@ -256,55 +256,66 @@ lgb.cv <- function(params = list(),
list(booster = booster)
})
}
# Create new booster
cv_booster <- CVBooster$new(bst_folds)
# Callback env
env <- CB_ENV$new()
env$model <- cv_booster
env$begin_iteration <- begin_iteration
env$end_iteration <- end_iteration
# Start training model using number of iterations to start and end with
for (i in seq.int(from = begin_iteration, to = end_iteration)) {
# Overwrite iteration in environment
env$iteration <- i
env$eval_list <- list()
# Loop through "pre_iter" element
for (f in cb$pre_iter) {
f(env)
}
# Update one boosting iteration
msg <- lapply(cv_booster$boosters, function(fd) {
fd$booster$update(fobj = fobj)
fd$booster$eval_valid(feval = feval)
})
# Prepare collection of evaluation results
merged_msg <- lgb.merge.cv.result(msg)
# Write evaluation result in environment
env$eval_list <- merged_msg$eval_list
# Check for standard deviation requirement
if(showsd) {
env$eval_err_list <- merged_msg$eval_err_list
}
# Loop through env
for (f in cb$post_iter) {
f(env)
}
# Check for early stopping and break if needed
if (env$met_early_stop) break
}
if (record && is.na(env$best_score)) {
if (env$eval_list[[1]]$higher_better[1] == TRUE) {
cv_booster$best_iter <- unname(which.max(unlist(cv_booster$record_evals[[2]][[1]][[1]])))
cv_booster$best_score <- cv_booster$record_evals[[2]][[1]][[1]][[cv_booster$best_iter]]
} else {
cv_booster$best_iter <- unname(which.min(unlist(cv_booster$record_evals[[2]][[1]][[1]])))
cv_booster$best_score <- cv_booster$record_evals[[2]][[1]][[1]][[cv_booster$best_iter]]
}
}
if (reset_data) {
lapply(cv_booster$boosters, function(fd) {
# Store temporarily model data elsewhere
......@@ -318,57 +329,58 @@ lgb.cv <- function(params = list(),
fd$booster$record_evals <- booster_old$record_evals
})
}
# Return booster
return(cv_booster)
}
# Generates random (stratified if needed) CV folds
generate.cv.folds <- function(nfold, nrows, stratified, label, group, params) {
# Check for group existence
if (is.null(group)) {
# Shuffle
rnd_idx <- sample.int(nrows)
# Request stratified folds
if (isTRUE(stratified) && params$objective %in% c("binary", "multiclass") && length(label) == length(rnd_idx)) {
y <- label[rnd_idx]
y <- factor(y)
folds <- lgb.stratified.folds(y, nfold)
} else {
# Make simple non-stratified folds
folds <- list()
# Loop through each fold
for (i in seq_len(nfold)) {
kstep <- length(rnd_idx) %/% (nfold - i + 1)
folds[[i]] <- rnd_idx[seq_len(kstep)]
rnd_idx <- rnd_idx[-seq_len(kstep)]
}
}
} else {
# When doing group, stratified is not possible (only random selection)
if (nfold > length(group)) {
stop("\n\tYou requested too many folds for the number of available groups.\n")
}
# Degroup the groups
ungrouped <- inverse.rle(list(lengths = group, values = seq_along(group)))
# Can't stratify, shuffle
rnd_idx <- sample.int(length(group))
# Make simple non-stratified folds
folds <- list()
# Loop through each fold
for (i in seq_len(nfold)) {
kstep <- length(rnd_idx) %/% (nfold - i + 1)
......@@ -376,12 +388,12 @@ generate.cv.folds <- function(nfold, nrows, stratified, label, group, params) {
group = rnd_idx[seq_len(kstep)])
rnd_idx <- rnd_idx[-seq_len(kstep)]
}
}
# Return folds
return(folds)
}
# Creates CV folds stratified by the values of y.
......@@ -389,7 +401,7 @@ generate.cv.folds <- function(nfold, nrows, stratified, label, group, params) {
# by always returning an unnamed list of fold indices.
#' @importFrom stats quantile
lgb.stratified.folds <- function(y, k = 10) {
## Group the numeric data based on their magnitudes
## and sample within those groups.
## When the number of samples is low, we may have
......@@ -399,51 +411,51 @@ lgb.stratified.folds <- function(y, k = 10) {
## At most, we will use quantiles. If the sample
## is too small, we just do regular unstratified CV
if (is.numeric(y)) {
cuts <- length(y) %/% k
if (cuts < 2) { cuts <- 2 }
if (cuts > 5) { cuts <- 5 }
y <- cut(y,
unique(stats::quantile(y, probs = seq.int(0, 1, length.out = cuts))),
include.lowest = TRUE)
}
if (k < length(y)) {
## Reset levels so that the possible levels and
## the levels in the vector are the same
y <- factor(as.character(y))
numInClass <- table(y)
foldVector <- vector(mode = "integer", length(y))
## For each class, balance the fold allocation as far
## as possible, then resample the remainder.
## The final assignment of folds is also randomized.
for (i in seq_along(numInClass)) {
## Create a vector of integers from 1:k as many times as possible without
## going over the number of samples in the class. Note that if the number
## of samples in a class is less than k, nothing is producd here.
seqVector <- rep(seq_len(k), numInClass[i] %/% k)
## Add enough random integers to get length(seqVector) == numInClass[i]
if (numInClass[i] %% k > 0) {
seqVector <- c(seqVector, sample.int(k, numInClass[i] %% k))
}
## Shuffle the integers for fold assignment and assign to this classes's data
foldVector[y == dimnames(numInClass)$y[i]] <- sample(seqVector)
}
} else {
foldVector <- seq(along = y)
}
# Return data
out <- split(seq(along = y), foldVector)
names(out) <- NULL
......@@ -451,53 +463,53 @@ lgb.stratified.folds <- function(y, k = 10) {
}
lgb.merge.cv.result <- function(msg, showsd = TRUE) {
# Get CV message length
if (length(msg) == 0) {
stop("lgb.cv: size of cv result error")
}
# Get evaluation message length
eval_len <- length(msg[[1]])
# Is evaluation message empty?
if (eval_len == 0) {
stop("lgb.cv: should provide at least one metric for CV")
}
# Get evaluation results using a list apply
eval_result <- lapply(seq_len(eval_len), function(j) {
as.numeric(lapply(seq_along(msg), function(i) {
msg[[i]][[j]]$value }))
})
# Get evaluation
ret_eval <- msg[[1]]
# Go through evaluation length items
for (j in seq_len(eval_len)) {
ret_eval[[j]]$value <- mean(eval_result[[j]])
}
# Preinit evaluation error
ret_eval_err <- NULL
# Check for standard deviation
if (showsd) {
# Parse standard deviation
for (j in seq_len(eval_len)) {
ret_eval_err <- c(ret_eval_err,
sqrt(mean(eval_result[[j]] ^ 2) - mean(eval_result[[j]]) ^ 2))
}
# Convert to list
ret_eval_err <- as.list(ret_eval_err)
}
# Return errors
list(eval_list = ret_eval,
eval_err_list = ret_eval_err)
}
......@@ -62,7 +62,7 @@ lgb.train <- function(params = list(),
callbacks = list(),
reset_data = FALSE,
...) {
# Setup temporary variables
additional_params <- list(...)
params <- append(params, additional_params)
......@@ -71,35 +71,35 @@ lgb.train <- function(params = list(),
params <- lgb.check.eval(params, eval)
fobj <- NULL
feval <- NULL
if (nrounds <= 0) {
stop("nrounds should be greater than zero")
}
# Check for objective (function or not)
if (is.function(params$objective)) {
fobj <- params$objective
params$objective <- "NONE"
}
# Check for loss (function or not)
if (is.function(eval)) {
feval <- eval
}
# Check for parameters
lgb.check.params(params)
# Init predictor to empty
predictor <- NULL
# Check for boosting from a trained model
if (is.character(init_model)) {
predictor <- Predictor$new(init_model)
} else if (lgb.is.Booster(init_model)) {
predictor <- init_model$to_predictor()
}
# Set the iteration to start from / end to (and check for boosting from a trained model, again)
begin_iteration <- 1
if (!is.null(predictor)) {
......@@ -112,89 +112,89 @@ lgb.train <- function(params = list(),
} else {
end_iteration <- begin_iteration + nrounds - 1
}
# Check for training dataset type correctness
if (!lgb.is.Dataset(data)) {
stop("lgb.train: data only accepts lgb.Dataset object")
}
# Check for validation dataset type correctness
if (length(valids) > 0) {
# One or more validation dataset
# Check for list as input and type correctness by object
if (!is.list(valids) || !all(vapply(valids, lgb.is.Dataset, logical(1)))) {
stop("lgb.train: valids must be a list of lgb.Dataset elements")
}
# Attempt to get names
evnames <- names(valids)
# Check for names existance
if (is.null(evnames) || !all(nzchar(evnames))) {
stop("lgb.train: each element of the valids must have a name tag")
}
}
# Update parameters with parsed parameters
data$update_params(params)
# Create the predictor set
data$.__enclos_env__$private$set_predictor(predictor)
# Write column names
if (!is.null(colnames)) {
data$set_colnames(colnames)
}
# Write categorical features
if (!is.null(categorical_feature)) {
data$set_categorical_feature(categorical_feature)
}
# Construct datasets, if needed
data$construct()
vaild_contain_train <- FALSE
train_data_name <- "train"
reduced_valid_sets <- list()
# Parse validation datasets
if (length(valids) > 0) {
# Loop through all validation datasets using name
for (key in names(valids)) {
# Use names to get validation datasets
valid_data <- valids[[key]]
# Check for duplicate train/validation dataset
if (identical(data, valid_data)) {
vaild_contain_train <- TRUE
train_data_name <- key
next
}
# Update parameters, data
valid_data$update_params(params)
valid_data$set_reference(data)
reduced_valid_sets[[key]] <- valid_data
}
}
# Add printing log callback
if (verbose > 0 && eval_freq > 0) {
callbacks <- add.cb(callbacks, cb.print.evaluation(eval_freq))
}
# Add evaluation log callback
if (record && length(valids) > 0) {
callbacks <- add.cb(callbacks, cb.record.evaluation())
}
# Check for early stopping passed as parameter when adding early stopping callback
early_stop <- c("early_stopping_round", "early_stopping_rounds", "early_stopping")
if (any(names(params) %in% early_stop)) {
......@@ -208,83 +208,94 @@ lgb.train <- function(params = list(),
}
}
}
# "Categorize" callbacks
cb <- categorize.callbacks(callbacks)
# Construct booster with datasets
booster <- Booster$new(params = params, train_set = data)
if (vaild_contain_train) { booster$set_train_data_name(train_data_name) }
for (key in names(reduced_valid_sets)) {
booster$add_valid(reduced_valid_sets[[key]], key)
}
# Callback env
env <- CB_ENV$new()
env$model <- booster
env$begin_iteration <- begin_iteration
env$end_iteration <- end_iteration
# Start training model using number of iterations to start and end with
for (i in seq.int(from = begin_iteration, to = end_iteration)) {
# Overwrite iteration in environment
env$iteration <- i
env$eval_list <- list()
# Loop through "pre_iter" element
for (f in cb$pre_iter) {
f(env)
}
# Update one boosting iteration
booster$update(fobj = fobj)
# Prepare collection of evaluation results
eval_list <- list()
# Collection: Has validation dataset?
if (length(valids) > 0) {
# Validation has training dataset?
if (vaild_contain_train) {
eval_list <- append(eval_list, booster$eval_train(feval = feval))
}
# Has no validation dataset
eval_list <- append(eval_list, booster$eval_valid(feval = feval))
}
# Write evaluation result in environment
env$eval_list <- eval_list
# Loop through env
for (f in cb$post_iter) {
f(env)
}
# Check for early stopping and break if needed
if (env$met_early_stop) break
}
# When early stopping is not activated, we compute the best iteration / score ourselves by selecting the first metric and the first dataset
if (record && length(valids) > 0 && is.na(env$best_score)) {
if (env$eval_list[[1]]$higher_better[1] == TRUE) {
booster$best_iter <- unname(which.max(unlist(booster$record_evals[[2]][[1]][[1]])))
booster$best_score <- booster$record_evals[[2]][[1]][[1]][[booster$best_iter]]
} else {
booster$best_iter <- unname(which.min(unlist(booster$record_evals[[2]][[1]][[1]])))
booster$best_score <- booster$record_evals[[2]][[1]][[1]][[booster$best_iter]]
}
}
# Check for booster model conversion to predictor model
if (reset_data) {
# Store temporarily model data elsewhere
booster_old <- list(best_iter = booster$best_iter,
best_score = booster$best_score,
record_evals = booster$record_evals)
# Reload model
booster <- lgb.load(model_str = booster$save_model_to_string())
booster$best_iter <- booster_old$best_iter
booster$best_score <- booster_old$best_score
booster$record_evals <- booster_old$record_evals
}
# Return booster
return(booster)
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment