"include/git@developer.sourcefind.cn:tianlh/lightgbm-dcu.git" did not exist on "1444a748d7f1255a52e341e7c766822d51b7b292"
Commit 5457ef6b authored by James Lamb's avatar James Lamb Committed by Guolin Ke
Browse files

[R-package] removed horizontal whitespace (fixes #1642) (#1651)

* [R-package] removed horizontal whitespace (fixes #1642)

* [R-package] fixed missing newline in test file
parent f44b60b6
...@@ -16,74 +16,74 @@ CB_ENV <- R6::R6Class( ...@@ -16,74 +16,74 @@ CB_ENV <- R6::R6Class(
) )
cb.reset.parameters <- function(new_params) { cb.reset.parameters <- function(new_params) {
# Check for parameter list # Check for parameter list
if (!is.list(new_params)) { if (!is.list(new_params)) {
stop(sQuote("new_params"), " must be a list") stop(sQuote("new_params"), " must be a list")
} }
# Deparse parameter list # Deparse parameter list
pnames <- gsub("\\.", "_", names(new_params)) pnames <- gsub("\\.", "_", names(new_params))
nrounds <- NULL nrounds <- NULL
# Run some checks in the beginning # Run some checks in the beginning
init <- function(env) { init <- function(env) {
# Store boosting rounds # Store boosting rounds
nrounds <<- env$end_iteration - env$begin_iteration + 1 nrounds <<- env$end_iteration - env$begin_iteration + 1
# Check for model environment # Check for model environment
if (is.null(env$model)) { stop("Env should have a ", sQuote("model")) } if (is.null(env$model)) { stop("Env should have a ", sQuote("model")) }
# Some parameters are not allowed to be changed, # Some parameters are not allowed to be changed,
# since changing them would simply wreck some chaos # since changing them would simply wreck some chaos
not_allowed <- c("num_class", "metric", "boosting_type") not_allowed <- c("num_class", "metric", "boosting_type")
if (any(pnames %in% not_allowed)) { if (any(pnames %in% not_allowed)) {
stop("Parameters ", paste0(pnames[pnames %in% not_allowed], collapse = ", "), " cannot be changed during boosting") stop("Parameters ", paste0(pnames[pnames %in% not_allowed], collapse = ", "), " cannot be changed during boosting")
} }
# Check parameter names # Check parameter names
for (n in pnames) { for (n in pnames) {
# Set name # Set name
p <- new_params[[n]] p <- new_params[[n]]
# Check if function for parameter # Check if function for parameter
if (is.function(p)) { if (is.function(p)) {
# Check if requires at least two arguments # Check if requires at least two arguments
if (length(formals(p)) != 2) { if (length(formals(p)) != 2) {
stop("Parameter ", sQuote(n), " is a function but not of two arguments") stop("Parameter ", sQuote(n), " is a function but not of two arguments")
} }
# Check if numeric or character # Check if numeric or character
} else if (is.numeric(p) || is.character(p)) { } else if (is.numeric(p) || is.character(p)) {
# Check if length is matching # Check if length is matching
if (length(p) != nrounds) { if (length(p) != nrounds) {
stop("Length of ", sQuote(n), " has to be equal to length of ", sQuote("nrounds")) stop("Length of ", sQuote(n), " has to be equal to length of ", sQuote("nrounds"))
} }
} else { } else {
stop("Parameter ", sQuote(n), " is not a function or a vector") stop("Parameter ", sQuote(n), " is not a function or a vector")
} }
} }
} }
callback <- function(env) { callback <- function(env) {
# Check if rounds is null # Check if rounds is null
if (is.null(nrounds)) { if (is.null(nrounds)) {
init(env) init(env)
} }
# Store iteration # Store iteration
i <- env$iteration - env$begin_iteration i <- env$iteration - env$begin_iteration
# Apply list on parameters # Apply list on parameters
pars <- lapply(new_params, function(p) { pars <- lapply(new_params, function(p) {
if (is.function(p)) { if (is.function(p)) {
...@@ -91,14 +91,14 @@ cb.reset.parameters <- function(new_params) { ...@@ -91,14 +91,14 @@ cb.reset.parameters <- function(new_params) {
} }
p[i] p[i]
}) })
# To-do check pars # To-do check pars
if (!is.null(env$model)) { if (!is.null(env$model)) {
env$model$reset_parameter(pars) env$model$reset_parameter(pars)
} }
} }
attr(callback, "call") <- match.call() attr(callback, "call") <- match.call()
attr(callback, "is_pre_iteration") <- TRUE attr(callback, "is_pre_iteration") <- TRUE
attr(callback, "name") <- "cb.reset.parameters" attr(callback, "name") <- "cb.reset.parameters"
...@@ -107,258 +107,258 @@ cb.reset.parameters <- function(new_params) { ...@@ -107,258 +107,258 @@ cb.reset.parameters <- function(new_params) {
# Format the evaluation metric string # Format the evaluation metric string
format.eval.string <- function(eval_res, eval_err = NULL) { format.eval.string <- function(eval_res, eval_err = NULL) {
# Check for empty evaluation string # Check for empty evaluation string
if (is.null(eval_res) || length(eval_res) == 0) { if (is.null(eval_res) || length(eval_res) == 0) {
stop("no evaluation results") stop("no evaluation results")
} }
# Check for empty evaluation error # Check for empty evaluation error
if (!is.null(eval_err)) { if (!is.null(eval_err)) {
sprintf("%s\'s %s:%g+%g", eval_res$data_name, eval_res$name, eval_res$value, eval_err) sprintf("%s\'s %s:%g+%g", eval_res$data_name, eval_res$name, eval_res$value, eval_err)
} else { } else {
sprintf("%s\'s %s:%g", eval_res$data_name, eval_res$name, eval_res$value) sprintf("%s\'s %s:%g", eval_res$data_name, eval_res$name, eval_res$value)
} }
} }
merge.eval.string <- function(env) { merge.eval.string <- function(env) {
# Check length of evaluation list # Check length of evaluation list
if (length(env$eval_list) <= 0) { if (length(env$eval_list) <= 0) {
return("") return("")
} }
# Get evaluation # Get evaluation
msg <- list(sprintf("[%d]:", env$iteration)) msg <- list(sprintf("[%d]:", env$iteration))
# Set if evaluation error # Set if evaluation error
is_eval_err <- length(env$eval_err_list) > 0 is_eval_err <- length(env$eval_err_list) > 0
# Loop through evaluation list # Loop through evaluation list
for (j in seq_along(env$eval_list)) { for (j in seq_along(env$eval_list)) {
# Store evaluation error # Store evaluation error
eval_err <- NULL eval_err <- NULL
if (is_eval_err) { if (is_eval_err) {
eval_err <- env$eval_err_list[[j]] eval_err <- env$eval_err_list[[j]]
} }
# Set error message # Set error message
msg <- c(msg, format.eval.string(env$eval_list[[j]], eval_err)) msg <- c(msg, format.eval.string(env$eval_list[[j]], eval_err))
} }
# Return tabulated separated message # Return tabulated separated message
paste0(msg, collapse = "\t") paste0(msg, collapse = "\t")
} }
cb.print.evaluation <- function(period = 1) { cb.print.evaluation <- function(period = 1) {
# Create callback # Create callback
callback <- function(env) { callback <- function(env) {
# Check if period is at least 1 or more # Check if period is at least 1 or more
if (period > 0) { if (period > 0) {
# Store iteration # Store iteration
i <- env$iteration i <- env$iteration
# Check if iteration matches moduo # Check if iteration matches moduo
if ((i - 1) %% period == 0 || is.element(i, c(env$begin_iteration, env$end_iteration ))) { if ((i - 1) %% period == 0 || is.element(i, c(env$begin_iteration, env$end_iteration ))) {
# Merge evaluation string # Merge evaluation string
msg <- merge.eval.string(env) msg <- merge.eval.string(env)
# Check if message is existing # Check if message is existing
if (nchar(msg) > 0) { if (nchar(msg) > 0) {
cat(merge.eval.string(env), "\n") cat(merge.eval.string(env), "\n")
} }
} }
} }
} }
# Store attributes # Store attributes
attr(callback, "call") <- match.call() attr(callback, "call") <- match.call()
attr(callback, "name") <- "cb.print.evaluation" attr(callback, "name") <- "cb.print.evaluation"
# Return callback # Return callback
callback callback
} }
cb.record.evaluation <- function() { cb.record.evaluation <- function() {
# Create callback # Create callback
callback <- function(env) { callback <- function(env) {
# Return empty if empty evaluation list # Return empty if empty evaluation list
if (length(env$eval_list) <= 0) { if (length(env$eval_list) <= 0) {
return() return()
} }
# Set if evaluation error # Set if evaluation error
is_eval_err <- length(env$eval_err_list) > 0 is_eval_err <- length(env$eval_err_list) > 0
# Check length of recorded evaluation # Check length of recorded evaluation
if (length(env$model$record_evals) == 0) { if (length(env$model$record_evals) == 0) {
# Loop through each evaluation list element # Loop through each evaluation list element
for (j in seq_along(env$eval_list)) { for (j in seq_along(env$eval_list)) {
# Store names # Store names
data_name <- env$eval_list[[j]]$data_name data_name <- env$eval_list[[j]]$data_name
name <- env$eval_list[[j]]$name name <- env$eval_list[[j]]$name
env$model$record_evals$start_iter <- env$begin_iteration env$model$record_evals$start_iter <- env$begin_iteration
# Check if evaluation record exists # Check if evaluation record exists
if (is.null(env$model$record_evals[[data_name]])) { if (is.null(env$model$record_evals[[data_name]])) {
env$model$record_evals[[data_name]] <- list() env$model$record_evals[[data_name]] <- list()
} }
# Create dummy lists # Create dummy lists
env$model$record_evals[[data_name]][[name]] <- list() env$model$record_evals[[data_name]][[name]] <- list()
env$model$record_evals[[data_name]][[name]]$eval <- list() env$model$record_evals[[data_name]][[name]]$eval <- list()
env$model$record_evals[[data_name]][[name]]$eval_err <- list() env$model$record_evals[[data_name]][[name]]$eval_err <- list()
} }
} }
# Loop through each evaluation list element # Loop through each evaluation list element
for (j in seq_along(env$eval_list)) { for (j in seq_along(env$eval_list)) {
# Get evaluation data # Get evaluation data
eval_res <- env$eval_list[[j]] eval_res <- env$eval_list[[j]]
eval_err <- NULL eval_err <- NULL
if (is_eval_err) { if (is_eval_err) {
eval_err <- env$eval_err_list[[j]] eval_err <- env$eval_err_list[[j]]
} }
# Store names # Store names
data_name <- eval_res$data_name data_name <- eval_res$data_name
name <- eval_res$name name <- eval_res$name
# Store evaluation data # Store evaluation data
env$model$record_evals[[data_name]][[name]]$eval <- c(env$model$record_evals[[data_name]][[name]]$eval, eval_res$value) env$model$record_evals[[data_name]][[name]]$eval <- c(env$model$record_evals[[data_name]][[name]]$eval, eval_res$value)
env$model$record_evals[[data_name]][[name]]$eval_err <- c(env$model$record_evals[[data_name]][[name]]$eval_err, eval_err) env$model$record_evals[[data_name]][[name]]$eval_err <- c(env$model$record_evals[[data_name]][[name]]$eval_err, eval_err)
} }
} }
# Store attributes # Store attributes
attr(callback, "call") <- match.call() attr(callback, "call") <- match.call()
attr(callback, "name") <- "cb.record.evaluation" attr(callback, "name") <- "cb.record.evaluation"
# Return callback # Return callback
callback callback
} }
cb.early.stop <- function(stopping_rounds, verbose = TRUE) { cb.early.stop <- function(stopping_rounds, verbose = TRUE) {
# Initialize variables # Initialize variables
factor_to_bigger_better <- NULL factor_to_bigger_better <- NULL
best_iter <- NULL best_iter <- NULL
best_score <- NULL best_score <- NULL
best_msg <- NULL best_msg <- NULL
eval_len <- NULL eval_len <- NULL
# Initalization function # Initalization function
init <- function(env) { init <- function(env) {
# Store evaluation length # Store evaluation length
eval_len <<- length(env$eval_list) eval_len <<- length(env$eval_list)
# Early stopping cannot work without metrics # Early stopping cannot work without metrics
if (eval_len == 0) { if (eval_len == 0) {
stop("For early stopping, valids must have at least one element") stop("For early stopping, valids must have at least one element")
} }
# Check if verbose or not # Check if verbose or not
if (isTRUE(verbose)) { if (isTRUE(verbose)) {
cat("Will train until there is no improvement in ", stopping_rounds, " rounds.\n\n", sep = "") cat("Will train until there is no improvement in ", stopping_rounds, " rounds.\n\n", sep = "")
} }
# Maximization or minimization task # Maximization or minimization task
factor_to_bigger_better <<- rep.int(1.0, eval_len) factor_to_bigger_better <<- rep.int(1.0, eval_len)
best_iter <<- rep.int(-1, eval_len) best_iter <<- rep.int(-1, eval_len)
best_score <<- rep.int(-Inf, eval_len) best_score <<- rep.int(-Inf, eval_len)
best_msg <<- list() best_msg <<- list()
# Loop through evaluation elements # Loop through evaluation elements
for (i in seq_len(eval_len)) { for (i in seq_len(eval_len)) {
# Prepend message # Prepend message
best_msg <<- c(best_msg, "") best_msg <<- c(best_msg, "")
# Check if maximization or minimization # Check if maximization or minimization
if (!env$eval_list[[i]]$higher_better) { if (!env$eval_list[[i]]$higher_better) {
factor_to_bigger_better[i] <<- -1.0 factor_to_bigger_better[i] <<- -1.0
} }
} }
} }
# Create callback # Create callback
callback <- function(env, finalize = FALSE) { callback <- function(env, finalize = FALSE) {
# Check for empty evaluation # Check for empty evaluation
if (is.null(eval_len)) { if (is.null(eval_len)) {
init(env) init(env)
} }
# Store iteration # Store iteration
cur_iter <- env$iteration cur_iter <- env$iteration
# Loop through evaluation # Loop through evaluation
for (i in seq_len(eval_len)) { for (i in seq_len(eval_len)) {
# Store score # Store score
score <- env$eval_list[[i]]$value * factor_to_bigger_better[i] score <- env$eval_list[[i]]$value * factor_to_bigger_better[i]
# Check if score is better # Check if score is better
if (score > best_score[i]) { if (score > best_score[i]) {
# Store new scores # Store new scores
best_score[i] <<- score best_score[i] <<- score
best_iter[i] <<- cur_iter best_iter[i] <<- cur_iter
# Prepare to print if verbose # Prepare to print if verbose
if (verbose) { if (verbose) {
best_msg[[i]] <<- as.character(merge.eval.string(env)) best_msg[[i]] <<- as.character(merge.eval.string(env))
} }
} else { } else {
# Check if early stopping is required # Check if early stopping is required
if (cur_iter - best_iter[i] >= stopping_rounds) { if (cur_iter - best_iter[i] >= stopping_rounds) {
# Check if model is not null # Check if model is not null
if (!is.null(env$model)) { if (!is.null(env$model)) {
env$model$best_score <- best_score[i] env$model$best_score <- best_score[i]
env$model$best_iter <- best_iter[i] env$model$best_iter <- best_iter[i]
} }
# Print message if verbose # Print message if verbose
if (isTRUE(verbose)) { if (isTRUE(verbose)) {
cat("Early stopping, best iteration is:", "\n") cat("Early stopping, best iteration is:", "\n")
cat(best_msg[[i]], "\n") cat(best_msg[[i]], "\n")
} }
# Store best iteration and stop # Store best iteration and stop
env$best_iter <- best_iter[i] env$best_iter <- best_iter[i]
env$met_early_stop <- TRUE env$met_early_stop <- TRUE
} }
} }
if (!isTRUE(env$met_early_stop) && cur_iter == env$end_iteration) { if (!isTRUE(env$met_early_stop) && cur_iter == env$end_iteration) {
# Check if model is not null # Check if model is not null
...@@ -366,58 +366,58 @@ cb.early.stop <- function(stopping_rounds, verbose = TRUE) { ...@@ -366,58 +366,58 @@ cb.early.stop <- function(stopping_rounds, verbose = TRUE) {
env$model$best_score <- best_score[i] env$model$best_score <- best_score[i]
env$model$best_iter <- best_iter[i] env$model$best_iter <- best_iter[i]
} }
# Print message if verbose # Print message if verbose
if (isTRUE(verbose)) { if (isTRUE(verbose)) {
cat("Did not meet early stopping, best iteration is:", "\n") cat("Did not meet early stopping, best iteration is:", "\n")
cat(best_msg[[i]], "\n") cat(best_msg[[i]], "\n")
} }
# Store best iteration and stop # Store best iteration and stop
env$best_iter <- best_iter[i] env$best_iter <- best_iter[i]
env$met_early_stop <- TRUE env$met_early_stop <- TRUE
} }
} }
} }
# Set attributes # Set attributes
attr(callback, "call") <- match.call() attr(callback, "call") <- match.call()
attr(callback, "name") <- "cb.early.stop" attr(callback, "name") <- "cb.early.stop"
# Return callback # Return callback
callback callback
} }
# Extract callback names from the list of callbacks # Extract callback names from the list of callbacks
callback.names <- function(cb_list) { unlist(lapply(cb_list, attr, "name")) } callback.names <- function(cb_list) { unlist(lapply(cb_list, attr, "name")) }
add.cb <- function(cb_list, cb) { add.cb <- function(cb_list, cb) {
# Combine two elements # Combine two elements
cb_list <- c(cb_list, cb) cb_list <- c(cb_list, cb)
# Set names of elements # Set names of elements
names(cb_list) <- callback.names(cb_list) names(cb_list) <- callback.names(cb_list)
# Check for existence # Check for existence
if ("cb.early.stop" %in% names(cb_list)) { if ("cb.early.stop" %in% names(cb_list)) {
# Concatenate existing elements # Concatenate existing elements
cb_list <- c(cb_list, cb_list["cb.early.stop"]) cb_list <- c(cb_list, cb_list["cb.early.stop"])
# Remove only the first one # Remove only the first one
cb_list["cb.early.stop"] <- NULL cb_list["cb.early.stop"] <- NULL
} }
# Return element # Return element
cb_list cb_list
} }
categorize.callbacks <- function(cb_list) { categorize.callbacks <- function(cb_list) {
# Check for pre-iteration or post-iteration # Check for pre-iteration or post-iteration
list( list(
pre_iter = Filter(function(x) { pre_iter = Filter(function(x) {
...@@ -429,5 +429,5 @@ categorize.callbacks <- function(cb_list) { ...@@ -429,5 +429,5 @@ categorize.callbacks <- function(cb_list) {
is.null(pre) || !pre is.null(pre) || !pre
}, cb_list) }, cb_list)
) )
} }
...@@ -3,110 +3,110 @@ Booster <- R6::R6Class( ...@@ -3,110 +3,110 @@ Booster <- R6::R6Class(
classname = "lgb.Booster", classname = "lgb.Booster",
cloneable = FALSE, cloneable = FALSE,
public = list( public = list(
best_iter = -1, best_iter = -1,
best_score = -1, best_score = -1,
record_evals = list(), record_evals = list(),
# Finalize will free up the handles # Finalize will free up the handles
finalize = function() { finalize = function() {
# Check the need for freeing handle # Check the need for freeing handle
if (!lgb.is.null.handle(private$handle)) { if (!lgb.is.null.handle(private$handle)) {
# Freeing up handle # Freeing up handle
lgb.call("LGBM_BoosterFree_R", ret = NULL, private$handle) lgb.call("LGBM_BoosterFree_R", ret = NULL, private$handle)
private$handle <- NULL private$handle <- NULL
} }
}, },
# Initialize will create a starter booster # Initialize will create a starter booster
initialize = function(params = list(), initialize = function(params = list(),
train_set = NULL, train_set = NULL,
modelfile = NULL, modelfile = NULL,
model_str = NULL, model_str = NULL,
...) { ...) {
# Create parameters and handle # Create parameters and handle
params <- append(params, list(...)) params <- append(params, list(...))
params_str <- lgb.params2str(params) params_str <- lgb.params2str(params)
handle <- 0.0 handle <- 0.0
# Attempts to create a handle for the dataset # Attempts to create a handle for the dataset
try({ try({
# Check if training dataset is not null # Check if training dataset is not null
if (!is.null(train_set)) { if (!is.null(train_set)) {
# Check if training dataset is lgb.Dataset or not # Check if training dataset is lgb.Dataset or not
if (!lgb.check.r6.class(train_set, "lgb.Dataset")) { if (!lgb.check.r6.class(train_set, "lgb.Dataset")) {
stop("lgb.Booster: Can only use lgb.Dataset as training data") stop("lgb.Booster: Can only use lgb.Dataset as training data")
} }
# Store booster handle # Store booster handle
handle <- lgb.call("LGBM_BoosterCreate_R", ret = handle, train_set$.__enclos_env__$private$get_handle(), params_str) handle <- lgb.call("LGBM_BoosterCreate_R", ret = handle, train_set$.__enclos_env__$private$get_handle(), params_str)
# Create private booster information # Create private booster information
private$train_set <- train_set private$train_set <- train_set
private$num_dataset <- 1 private$num_dataset <- 1
private$init_predictor <- train_set$.__enclos_env__$private$predictor private$init_predictor <- train_set$.__enclos_env__$private$predictor
# Check if predictor is existing # Check if predictor is existing
if (!is.null(private$init_predictor)) { if (!is.null(private$init_predictor)) {
# Merge booster # Merge booster
lgb.call("LGBM_BoosterMerge_R", lgb.call("LGBM_BoosterMerge_R",
ret = NULL, ret = NULL,
handle, handle,
private$init_predictor$.__enclos_env__$private$handle) private$init_predictor$.__enclos_env__$private$handle)
} }
# Check current iteration # Check current iteration
private$is_predicted_cur_iter <- c(private$is_predicted_cur_iter, FALSE) private$is_predicted_cur_iter <- c(private$is_predicted_cur_iter, FALSE)
} else if (!is.null(modelfile)) { } else if (!is.null(modelfile)) {
# Do we have a model file as character? # Do we have a model file as character?
if (!is.character(modelfile)) { if (!is.character(modelfile)) {
stop("lgb.Booster: Can only use a string as model file path") stop("lgb.Booster: Can only use a string as model file path")
} }
# Create booster from model # Create booster from model
handle <- lgb.call("LGBM_BoosterCreateFromModelfile_R", handle <- lgb.call("LGBM_BoosterCreateFromModelfile_R",
ret = handle, ret = handle,
lgb.c_str(modelfile)) lgb.c_str(modelfile))
} else if (!is.null(model_str)) { } else if (!is.null(model_str)) {
# Do we have a model_str as character? # Do we have a model_str as character?
if (!is.character(model_str)) { if (!is.character(model_str)) {
stop("lgb.Booster: Can only use a string as model_str") stop("lgb.Booster: Can only use a string as model_str")
} }
# Create booster from model # Create booster from model
handle <- lgb.call("LGBM_BoosterLoadModelFromString_R", handle <- lgb.call("LGBM_BoosterLoadModelFromString_R",
ret = handle, ret = handle,
lgb.c_str(model_str)) lgb.c_str(model_str))
} else { } else {
# Booster non existent # Booster non existent
stop("lgb.Booster: Need at least either training dataset, model file, or model_str to create booster instance") stop("lgb.Booster: Need at least either training dataset, model file, or model_str to create booster instance")
} }
}) })
# Check whether the handle was created properly if it was not stopped earlier by a stop call # Check whether the handle was created properly if it was not stopped earlier by a stop call
if (lgb.is.null.handle(handle)) { if (lgb.is.null.handle(handle)) {
stop("lgb.Booster: cannot create Booster handle") stop("lgb.Booster: cannot create Booster handle")
} else { } else {
# Create class # Create class
class(handle) <- "lgb.Booster.handle" class(handle) <- "lgb.Booster.handle"
private$handle <- handle private$handle <- handle
...@@ -114,100 +114,100 @@ Booster <- R6::R6Class( ...@@ -114,100 +114,100 @@ Booster <- R6::R6Class(
private$num_class <- lgb.call("LGBM_BoosterGetNumClasses_R", private$num_class <- lgb.call("LGBM_BoosterGetNumClasses_R",
ret = private$num_class, ret = private$num_class,
private$handle) private$handle)
} }
}, },
# Set training data name # Set training data name
set_train_data_name = function(name) { set_train_data_name = function(name) {
# Set name # Set name
private$name_train_set <- name private$name_train_set <- name
return(invisible(self)) return(invisible(self))
}, },
# Add validation data # Add validation data
add_valid = function(data, name) { add_valid = function(data, name) {
# Check if data is lgb.Dataset # Check if data is lgb.Dataset
if (!lgb.check.r6.class(data, "lgb.Dataset")) { if (!lgb.check.r6.class(data, "lgb.Dataset")) {
stop("lgb.Booster.add_valid: Can only use lgb.Dataset as validation data") stop("lgb.Booster.add_valid: Can only use lgb.Dataset as validation data")
} }
# Check if predictors are identical # Check if predictors are identical
if (!identical(data$.__enclos_env__$private$predictor, private$init_predictor)) { if (!identical(data$.__enclos_env__$private$predictor, private$init_predictor)) {
stop("lgb.Booster.add_valid: Failed to add validation data; you should use the same predictor for these data") stop("lgb.Booster.add_valid: Failed to add validation data; you should use the same predictor for these data")
} }
# Check if names are character # Check if names are character
if (!is.character(name)) { if (!is.character(name)) {
stop("lgb.Booster.add_valid: Can only use characters as data name") stop("lgb.Booster.add_valid: Can only use characters as data name")
} }
# Add validation data to booster # Add validation data to booster
lgb.call("LGBM_BoosterAddValidData_R", lgb.call("LGBM_BoosterAddValidData_R",
ret = NULL, ret = NULL,
private$handle, private$handle,
data$.__enclos_env__$private$get_handle()) data$.__enclos_env__$private$get_handle())
# Store private information # Store private information
private$valid_sets <- c(private$valid_sets, data) private$valid_sets <- c(private$valid_sets, data)
private$name_valid_sets <- c(private$name_valid_sets, name) private$name_valid_sets <- c(private$name_valid_sets, name)
private$num_dataset <- private$num_dataset + 1 private$num_dataset <- private$num_dataset + 1
private$is_predicted_cur_iter <- c(private$is_predicted_cur_iter, FALSE) private$is_predicted_cur_iter <- c(private$is_predicted_cur_iter, FALSE)
# Return self # Return self
return(invisible(self)) return(invisible(self))
}, },
# Reset parameters of booster # Reset parameters of booster
reset_parameter = function(params, ...) { reset_parameter = function(params, ...) {
# Append parameters # Append parameters
params <- append(params, list(...)) params <- append(params, list(...))
params_str <- lgb.params2str(params) params_str <- lgb.params2str(params)
# Reset parameters # Reset parameters
lgb.call("LGBM_BoosterResetParameter_R", lgb.call("LGBM_BoosterResetParameter_R",
ret = NULL, ret = NULL,
private$handle, private$handle,
params_str) params_str)
# Return self # Return self
return(invisible(self)) return(invisible(self))
}, },
# Perform boosting update iteration # Perform boosting update iteration
update = function(train_set = NULL, fobj = NULL) { update = function(train_set = NULL, fobj = NULL) {
# Check if training set is not null # Check if training set is not null
if (!is.null(train_set)) { if (!is.null(train_set)) {
# Check if training set is lgb.Dataset # Check if training set is lgb.Dataset
if (!lgb.check.r6.class(train_set, "lgb.Dataset")) { if (!lgb.check.r6.class(train_set, "lgb.Dataset")) {
stop("lgb.Booster.update: Only can use lgb.Dataset as training data") stop("lgb.Booster.update: Only can use lgb.Dataset as training data")
} }
# Check if predictors are identical # Check if predictors are identical
if (!identical(train_set$predictor, private$init_predictor)) { if (!identical(train_set$predictor, private$init_predictor)) {
stop("lgb.Booster.update: Change train_set failed, you should use the same predictor for these data") stop("lgb.Booster.update: Change train_set failed, you should use the same predictor for these data")
} }
# Reset training data on booster # Reset training data on booster
lgb.call("LGBM_BoosterResetTrainingData_R", lgb.call("LGBM_BoosterResetTrainingData_R",
ret = NULL, ret = NULL,
private$handle, private$handle,
train_set$.__enclos_env__$private$get_handle()) train_set$.__enclos_env__$private$get_handle())
# Store private train set # Store private train set
private$train_set = train_set private$train_set = train_set
} }
# Check if objective is empty # Check if objective is empty
if (is.null(fobj)) { if (is.null(fobj)) {
if (private$set_objective_to_none) { if (private$set_objective_to_none) {
...@@ -215,9 +215,9 @@ Booster <- R6::R6Class( ...@@ -215,9 +215,9 @@ Booster <- R6::R6Class(
} }
# Boost iteration from known objective # Boost iteration from known objective
ret <- lgb.call("LGBM_BoosterUpdateOneIter_R", ret = NULL, private$handle) ret <- lgb.call("LGBM_BoosterUpdateOneIter_R", ret = NULL, private$handle)
} else { } else {
# Check if objective is function # Check if objective is function
if (!is.function(fobj)) { if (!is.function(fobj)) {
stop("lgb.Booster.update: fobj should be a function") stop("lgb.Booster.update: fobj should be a function")
...@@ -228,13 +228,13 @@ Booster <- R6::R6Class( ...@@ -228,13 +228,13 @@ Booster <- R6::R6Class(
} }
# Perform objective calculation # Perform objective calculation
gpair <- fobj(private$inner_predict(1), private$train_set) gpair <- fobj(private$inner_predict(1), private$train_set)
# Check for gradient and hessian as list # Check for gradient and hessian as list
if(is.null(gpair$grad) || is.null(gpair$hess)){ if(is.null(gpair$grad) || is.null(gpair$hess)){
stop("lgb.Booster.update: custom objective should stop("lgb.Booster.update: custom objective should
return a list with attributes (hess, grad)") return a list with attributes (hess, grad)")
} }
# Return custom boosting gradient/hessian # Return custom boosting gradient/hessian
ret <- lgb.call("LGBM_BoosterUpdateOneIterCustom_R", ret <- lgb.call("LGBM_BoosterUpdateOneIterCustom_R",
ret = NULL, ret = NULL,
...@@ -242,170 +242,170 @@ Booster <- R6::R6Class( ...@@ -242,170 +242,170 @@ Booster <- R6::R6Class(
gpair$grad, gpair$grad,
gpair$hess, gpair$hess,
length(gpair$grad)) length(gpair$grad))
} }
# Loop through each iteration # Loop through each iteration
for (i in seq_along(private$is_predicted_cur_iter)) { for (i in seq_along(private$is_predicted_cur_iter)) {
private$is_predicted_cur_iter[[i]] <- FALSE private$is_predicted_cur_iter[[i]] <- FALSE
} }
return(ret) return(ret)
}, },
# Return one iteration behind # Return one iteration behind
rollback_one_iter = function() { rollback_one_iter = function() {
# Return one iteration behind # Return one iteration behind
lgb.call("LGBM_BoosterRollbackOneIter_R", lgb.call("LGBM_BoosterRollbackOneIter_R",
ret = NULL, ret = NULL,
private$handle) private$handle)
# Loop through each iteration # Loop through each iteration
for (i in seq_along(private$is_predicted_cur_iter)) { for (i in seq_along(private$is_predicted_cur_iter)) {
private$is_predicted_cur_iter[[i]] <- FALSE private$is_predicted_cur_iter[[i]] <- FALSE
} }
# Return self # Return self
return(invisible(self)) return(invisible(self))
}, },
# Get current iteration # Get current iteration
current_iter = function() { current_iter = function() {
cur_iter <- 0L cur_iter <- 0L
lgb.call("LGBM_BoosterGetCurrentIteration_R", lgb.call("LGBM_BoosterGetCurrentIteration_R",
ret = cur_iter, ret = cur_iter,
private$handle) private$handle)
}, },
# Evaluate data on metrics # Evaluate data on metrics
eval = function(data, name, feval = NULL) { eval = function(data, name, feval = NULL) {
# Check if dataset is lgb.Dataset # Check if dataset is lgb.Dataset
if (!lgb.check.r6.class(data, "lgb.Dataset")) { if (!lgb.check.r6.class(data, "lgb.Dataset")) {
stop("lgb.Booster.eval: Can only use lgb.Dataset to eval") stop("lgb.Booster.eval: Can only use lgb.Dataset to eval")
} }
# Check for identical data # Check for identical data
data_idx <- 0 data_idx <- 0
if (identical(data, private$train_set)) { if (identical(data, private$train_set)) {
data_idx <- 1 data_idx <- 1
} else { } else {
# Check for validation data # Check for validation data
if (length(private$valid_sets) > 0) { if (length(private$valid_sets) > 0) {
# Loop through each validation set # Loop through each validation set
for (i in seq_along(private$valid_sets)) { for (i in seq_along(private$valid_sets)) {
# Check for identical validation data with training data # Check for identical validation data with training data
if (identical(data, private$valid_sets[[i]])) { if (identical(data, private$valid_sets[[i]])) {
# Found identical data, skip # Found identical data, skip
data_idx <- i + 1 data_idx <- i + 1
break break
} }
} }
} }
} }
# Check if evaluation was not done # Check if evaluation was not done
if (data_idx == 0) { if (data_idx == 0) {
# Add validation data by name # Add validation data by name
self$add_valid(data, name) self$add_valid(data, name)
data_idx <- private$num_dataset data_idx <- private$num_dataset
} }
# Evaluate data # Evaluate data
private$inner_eval(name, data_idx, feval) private$inner_eval(name, data_idx, feval)
}, },
# Evaluation training data # Evaluation training data
eval_train = function(feval = NULL) { eval_train = function(feval = NULL) {
private$inner_eval(private$name_train_set, 1, feval) private$inner_eval(private$name_train_set, 1, feval)
}, },
# Evaluation validation data # Evaluation validation data
eval_valid = function(feval = NULL) { eval_valid = function(feval = NULL) {
# Create ret list # Create ret list
ret = list() ret = list()
# Check if validation is empty # Check if validation is empty
if (length(private$valid_sets) <= 0) { if (length(private$valid_sets) <= 0) {
return(ret) return(ret)
} }
# Loop through each validation set # Loop through each validation set
for (i in seq_along(private$valid_sets)) { for (i in seq_along(private$valid_sets)) {
ret <- append(ret, private$inner_eval(private$name_valid_sets[[i]], i + 1, feval)) ret <- append(ret, private$inner_eval(private$name_valid_sets[[i]], i + 1, feval))
} }
# Return ret # Return ret
return(ret) return(ret)
}, },
# Save model # Save model
save_model = function(filename, num_iteration = NULL) { save_model = function(filename, num_iteration = NULL) {
# Check if number of iteration is non existent # Check if number of iteration is non existent
if (is.null(num_iteration)) { if (is.null(num_iteration)) {
num_iteration <- self$best_iter num_iteration <- self$best_iter
} }
# Save booster model # Save booster model
lgb.call("LGBM_BoosterSaveModel_R", lgb.call("LGBM_BoosterSaveModel_R",
ret = NULL, ret = NULL,
private$handle, private$handle,
as.integer(num_iteration), as.integer(num_iteration),
lgb.c_str(filename)) lgb.c_str(filename))
# Return self # Return self
return(invisible(self)) return(invisible(self))
}, },
# Save model to string # Save model to string
save_model_to_string = function(num_iteration = NULL) { save_model_to_string = function(num_iteration = NULL) {
# Check if number of iteration is non existent # Check if number of iteration is non existent
if (is.null(num_iteration)) { if (is.null(num_iteration)) {
num_iteration <- self$best_iter num_iteration <- self$best_iter
} }
# Return model string # Return model string
return(lgb.call.return.str("LGBM_BoosterSaveModelToString_R", return(lgb.call.return.str("LGBM_BoosterSaveModelToString_R",
private$handle, private$handle,
as.integer(num_iteration))) as.integer(num_iteration)))
}, },
# Dump model in memory # Dump model in memory
dump_model = function(num_iteration = NULL) { dump_model = function(num_iteration = NULL) {
# Check if number of iteration is non existent # Check if number of iteration is non existent
if (is.null(num_iteration)) { if (is.null(num_iteration)) {
num_iteration <- self$best_iter num_iteration <- self$best_iter
} }
# Return dumped model # Return dumped model
lgb.call.return.str("LGBM_BoosterDumpModel_R", lgb.call.return.str("LGBM_BoosterDumpModel_R",
private$handle, private$handle,
as.integer(num_iteration)) as.integer(num_iteration))
}, },
# Predict on new data # Predict on new data
predict = function(data, predict = function(data,
num_iteration = NULL, num_iteration = NULL,
...@@ -414,34 +414,34 @@ Booster <- R6::R6Class( ...@@ -414,34 +414,34 @@ Booster <- R6::R6Class(
predcontrib = FALSE, predcontrib = FALSE,
header = FALSE, header = FALSE,
reshape = FALSE, ...) { reshape = FALSE, ...) {
# Check if number of iteration is non existent # Check if number of iteration is non existent
if (is.null(num_iteration)) { if (is.null(num_iteration)) {
num_iteration <- self$best_iter num_iteration <- self$best_iter
} }
# Predict on new data # Predict on new data
predictor <- Predictor$new(private$handle, ...) predictor <- Predictor$new(private$handle, ...)
predictor$predict(data, num_iteration, rawscore, predleaf, predcontrib, header, reshape) predictor$predict(data, num_iteration, rawscore, predleaf, predcontrib, header, reshape)
}, },
# Transform into predictor # Transform into predictor
to_predictor = function() { to_predictor = function() {
Predictor$new(private$handle) Predictor$new(private$handle)
}, },
# Used for save # Used for save
raw = NA, raw = NA,
# Save model to temporary file for in-memory saving # Save model to temporary file for in-memory saving
save = function() { save = function() {
# Overwrite model in object # Overwrite model in object
self$raw <- self$save_model_to_string(NULL) self$raw <- self$save_model_to_string(NULL)
} }
), ),
private = list( private = list(
handle = NULL, handle = NULL,
...@@ -459,23 +459,23 @@ Booster <- R6::R6Class( ...@@ -459,23 +459,23 @@ Booster <- R6::R6Class(
set_objective_to_none = FALSE, set_objective_to_none = FALSE,
# Predict data # Predict data
inner_predict = function(idx) { inner_predict = function(idx) {
# Store data name # Store data name
data_name <- private$name_train_set data_name <- private$name_train_set
# Check for id bigger than 1 # Check for id bigger than 1
if (idx > 1) { if (idx > 1) {
data_name <- private$name_valid_sets[[idx - 1]] data_name <- private$name_valid_sets[[idx - 1]]
} }
# Check for unknown dataset (over the maximum provided range) # Check for unknown dataset (over the maximum provided range)
if (idx > private$num_dataset) { if (idx > private$num_dataset) {
stop("data_idx should not be greater than num_dataset") stop("data_idx should not be greater than num_dataset")
} }
# Check for prediction buffer # Check for prediction buffer
if (is.null(private$predict_buffer[[data_name]])) { if (is.null(private$predict_buffer[[data_name]])) {
# Store predictions # Store predictions
npred <- 0L npred <- 0L
npred <- lgb.call("LGBM_BoosterGetNumPredict_R", npred <- lgb.call("LGBM_BoosterGetNumPredict_R",
...@@ -483,12 +483,12 @@ Booster <- R6::R6Class( ...@@ -483,12 +483,12 @@ Booster <- R6::R6Class(
private$handle, private$handle,
as.integer(idx - 1)) as.integer(idx - 1))
private$predict_buffer[[data_name]] <- numeric(npred) private$predict_buffer[[data_name]] <- numeric(npred)
} }
# Check if current iteration was already predicted # Check if current iteration was already predicted
if (!private$is_predicted_cur_iter[[idx]]) { if (!private$is_predicted_cur_iter[[idx]]) {
# Use buffer # Use buffer
private$predict_buffer[[data_name]] <- lgb.call("LGBM_BoosterGetPredict_R", private$predict_buffer[[data_name]] <- lgb.call("LGBM_BoosterGetPredict_R",
ret = private$predict_buffer[[data_name]], ret = private$predict_buffer[[data_name]],
...@@ -496,65 +496,65 @@ Booster <- R6::R6Class( ...@@ -496,65 +496,65 @@ Booster <- R6::R6Class(
as.integer(idx - 1)) as.integer(idx - 1))
private$is_predicted_cur_iter[[idx]] <- TRUE private$is_predicted_cur_iter[[idx]] <- TRUE
} }
# Return prediction buffer # Return prediction buffer
return(private$predict_buffer[[data_name]]) return(private$predict_buffer[[data_name]])
}, },
# Get evaluation information # Get evaluation information
get_eval_info = function() { get_eval_info = function() {
# Check for evaluation names emptiness # Check for evaluation names emptiness
if (is.null(private$eval_names)) { if (is.null(private$eval_names)) {
# Get evaluation names # Get evaluation names
names <- lgb.call.return.str("LGBM_BoosterGetEvalNames_R", names <- lgb.call.return.str("LGBM_BoosterGetEvalNames_R",
private$handle) private$handle)
# Check names' length # Check names' length
if (nchar(names) > 0) { if (nchar(names) > 0) {
# Parse and store privately names # Parse and store privately names
names <- strsplit(names, "\t")[[1]] names <- strsplit(names, "\t")[[1]]
private$eval_names <- names private$eval_names <- names
private$higher_better_inner_eval <- grepl("^ndcg|^auc$", names) private$higher_better_inner_eval <- grepl("^ndcg|^auc$", names)
} }
} }
# Return evaluation names # Return evaluation names
return(private$eval_names) return(private$eval_names)
}, },
# Perform inner evaluation # Perform inner evaluation
inner_eval = function(data_name, data_idx, feval = NULL) { inner_eval = function(data_name, data_idx, feval = NULL) {
# Check for unknown dataset (over the maximum provided range) # Check for unknown dataset (over the maximum provided range)
if (data_idx > private$num_dataset) { if (data_idx > private$num_dataset) {
stop("data_idx should not be greater than num_dataset") stop("data_idx should not be greater than num_dataset")
} }
# Get evaluation information # Get evaluation information
private$get_eval_info() private$get_eval_info()
# Prepare return # Prepare return
ret <- list() ret <- list()
# Check evaluation names existence # Check evaluation names existence
if (length(private$eval_names) > 0) { if (length(private$eval_names) > 0) {
# Create evaluation values # Create evaluation values
tmp_vals <- numeric(length(private$eval_names)) tmp_vals <- numeric(length(private$eval_names))
tmp_vals <- lgb.call("LGBM_BoosterGetEval_R", tmp_vals <- lgb.call("LGBM_BoosterGetEval_R",
ret = tmp_vals, ret = tmp_vals,
private$handle, private$handle,
as.integer(data_idx - 1)) as.integer(data_idx - 1))
# Loop through all evaluation names # Loop through all evaluation names
for (i in seq_along(private$eval_names)) { for (i in seq_along(private$eval_names)) {
# Store evaluation and append to return # Store evaluation and append to return
res <- list() res <- list()
res$data_name <- data_name res$data_name <- data_name
...@@ -562,46 +562,46 @@ Booster <- R6::R6Class( ...@@ -562,46 +562,46 @@ Booster <- R6::R6Class(
res$value <- tmp_vals[i] res$value <- tmp_vals[i]
res$higher_better <- private$higher_better_inner_eval[i] res$higher_better <- private$higher_better_inner_eval[i]
ret <- append(ret, list(res)) ret <- append(ret, list(res))
} }
} }
# Check if there are evaluation metrics # Check if there are evaluation metrics
if (!is.null(feval)) { if (!is.null(feval)) {
# Check if evaluation metric is a function # Check if evaluation metric is a function
if (!is.function(feval)) { if (!is.function(feval)) {
stop("lgb.Booster.eval: feval should be a function") stop("lgb.Booster.eval: feval should be a function")
} }
# Prepare data # Prepare data
data <- private$train_set data <- private$train_set
# Check if data to assess is existing differently # Check if data to assess is existing differently
if (data_idx > 1) { if (data_idx > 1) {
data <- private$valid_sets[[data_idx - 1]] data <- private$valid_sets[[data_idx - 1]]
} }
# Perform function evaluation # Perform function evaluation
res <- feval(private$inner_predict(data_idx), data) res <- feval(private$inner_predict(data_idx), data)
# Check for name correctness # Check for name correctness
if(is.null(res$name) || is.null(res$value) || is.null(res$higher_better)) { if(is.null(res$name) || is.null(res$value) || is.null(res$higher_better)) {
stop("lgb.Booster.eval: custom eval function should return a stop("lgb.Booster.eval: custom eval function should return a
list with attribute (name, value, higher_better)"); list with attribute (name, value, higher_better)");
} }
# Append names and evaluation # Append names and evaluation
res$data_name <- data_name res$data_name <- data_name
ret <- append(ret, list(res)) ret <- append(ret, list(res))
} }
# Return ret # Return ret
return(ret) return(ret)
} }
) )
) )
...@@ -631,7 +631,7 @@ Booster <- R6::R6Class( ...@@ -631,7 +631,7 @@ Booster <- R6::R6Class(
#' #'
#' When \code{predleaf = TRUE}, the output is a matrix object with the #' When \code{predleaf = TRUE}, the output is a matrix object with the
#' number of columns corresponding to the number of trees. #' number of columns corresponding to the number of trees.
#' #'
#' @examples #' @examples
#' library(lightgbm) #' library(lightgbm)
#' data(agaricus.train, package = "lightgbm") #' data(agaricus.train, package = "lightgbm")
...@@ -650,7 +650,7 @@ Booster <- R6::R6Class( ...@@ -650,7 +650,7 @@ Booster <- R6::R6Class(
#' learning_rate = 1, #' learning_rate = 1,
#' early_stopping_rounds = 10) #' early_stopping_rounds = 10)
#' preds <- predict(model, test$data) #' preds <- predict(model, test$data)
#' #'
#' @rdname predict.lgb.Booster #' @rdname predict.lgb.Booster
#' @export #' @export
predict.lgb.Booster <- function(object, predict.lgb.Booster <- function(object,
...@@ -660,14 +660,14 @@ predict.lgb.Booster <- function(object, ...@@ -660,14 +660,14 @@ predict.lgb.Booster <- function(object,
predleaf = FALSE, predleaf = FALSE,
predcontrib = FALSE, predcontrib = FALSE,
header = FALSE, header = FALSE,
reshape = FALSE, reshape = FALSE,
...) { ...) {
# Check booster existence # Check booster existence
if (!lgb.is.Booster(object)) { if (!lgb.is.Booster(object)) {
stop("predict.lgb.Booster: object should be an ", sQuote("lgb.Booster")) stop("predict.lgb.Booster: object should be an ", sQuote("lgb.Booster"))
} }
# Return booster predictions # Return booster predictions
object$predict(data, object$predict(data,
num_iteration, num_iteration,
...@@ -688,7 +688,7 @@ predict.lgb.Booster <- function(object, ...@@ -688,7 +688,7 @@ predict.lgb.Booster <- function(object,
#' @param model_str a str containing the model #' @param model_str a str containing the model
#' #'
#' @return lgb.Booster #' @return lgb.Booster
#' #'
#' @examples #' @examples
#' library(lightgbm) #' library(lightgbm)
#' data(agaricus.train, package = "lightgbm") #' data(agaricus.train, package = "lightgbm")
...@@ -710,31 +710,31 @@ predict.lgb.Booster <- function(object, ...@@ -710,31 +710,31 @@ predict.lgb.Booster <- function(object,
#' load_booster <- lgb.load(filename = "model.txt") #' load_booster <- lgb.load(filename = "model.txt")
#' model_string <- model$save_model_to_string(NULL) # saves best iteration #' model_string <- model$save_model_to_string(NULL) # saves best iteration
#' load_booster_from_str <- lgb.load(model_str = model_string) #' load_booster_from_str <- lgb.load(model_str = model_string)
#' #'
#' @rdname lgb.load #' @rdname lgb.load
#' @export #' @export
lgb.load <- function(filename = NULL, model_str = NULL) {
  # At least one model source must be supplied
  if (is.null(filename) && is.null(model_str)) {
    stop("lgb.load: either filename or model_str must be given")
  }
  # Load from file when a filename is given (filename takes precedence)
  if (!is.null(filename)) {
    if (!is.character(filename)) {
      stop("lgb.load: filename should be character")
    }
    if (!file.exists(filename)) {
      stop("lgb.load: file does not exist for supplied filename")
    }
    return(invisible(Booster$new(modelfile = filename)))
  }
  # Otherwise load from an in-memory model string
  if (!is.character(model_str)) {
    stop("lgb.load: model_str should be character")
  }
  invisible(Booster$new(model_str = model_str))
}
#' Save LightGBM model #' Save LightGBM model
...@@ -746,7 +746,7 @@ lgb.load <- function(filename = NULL, model_str = NULL){ ...@@ -746,7 +746,7 @@ lgb.load <- function(filename = NULL, model_str = NULL){
#' @param num_iteration number of iteration want to predict with, NULL or <= 0 means use best iteration #' @param num_iteration number of iteration want to predict with, NULL or <= 0 means use best iteration
#' #'
#' @return lgb.Booster #' @return lgb.Booster
#' #'
#' @examples #' @examples
#' library(lightgbm) #' library(lightgbm)
#' data(agaricus.train, package = "lightgbm") #' data(agaricus.train, package = "lightgbm")
...@@ -765,24 +765,24 @@ lgb.load <- function(filename = NULL, model_str = NULL){ ...@@ -765,24 +765,24 @@ lgb.load <- function(filename = NULL, model_str = NULL){
#' learning_rate = 1, #' learning_rate = 1,
#' early_stopping_rounds = 10) #' early_stopping_rounds = 10)
#' lgb.save(model, "model.txt") #' lgb.save(model, "model.txt")
#' #'
#' @rdname lgb.save #' @rdname lgb.save
#' @export #' @export
# Save a LightGBM booster to a text file.
# booster: an lgb.Booster object; filename: path to write;
# num_iteration: iteration to save (NULL = best iteration).
lgb.save <- function(booster, filename, num_iteration = NULL) {
  # Validate the booster object
  if (!lgb.is.Booster(booster)) {
    stop("lgb.save: booster should be an ", sQuote("lgb.Booster"))
  }
  # Validate the target filename
  if (!is.character(filename)) {
    stop("lgb.save: filename should be a character")
  }
  # Write the model; return the booster invisibly for chaining
  invisible(booster$save_model(filename, num_iteration))
}
#' Dump LightGBM model to json #' Dump LightGBM model to json
...@@ -793,7 +793,7 @@ lgb.save <- function(booster, filename, num_iteration = NULL){ ...@@ -793,7 +793,7 @@ lgb.save <- function(booster, filename, num_iteration = NULL){
#' @param num_iteration number of iteration want to predict with, NULL or <= 0 means use best iteration #' @param num_iteration number of iteration want to predict with, NULL or <= 0 means use best iteration
#' #'
#' @return json format of model #' @return json format of model
#' #'
#' @examples #' @examples
#' library(lightgbm) #' library(lightgbm)
#' data(agaricus.train, package = "lightgbm") #' data(agaricus.train, package = "lightgbm")
...@@ -812,19 +812,19 @@ lgb.save <- function(booster, filename, num_iteration = NULL){ ...@@ -812,19 +812,19 @@ lgb.save <- function(booster, filename, num_iteration = NULL){
#' learning_rate = 1, #' learning_rate = 1,
#' early_stopping_rounds = 10) #' early_stopping_rounds = 10)
#' json_model <- lgb.dump(model) #' json_model <- lgb.dump(model)
#' #'
#' @rdname lgb.dump #' @rdname lgb.dump
#' @export #' @export
# Dump a LightGBM booster to a JSON string.
# num_iteration: iteration to dump (NULL = best iteration).
lgb.dump <- function(booster, num_iteration = NULL) {
  # Validate the booster object
  # (fixed: error message previously said "lgb.save:" in lgb.dump)
  if (!lgb.is.Booster(booster)) {
    stop("lgb.dump: booster should be an ", sQuote("lgb.Booster"))
  }
  # Dump the model at the requested (or best) iteration
  booster$dump_model(num_iteration)
}
#' Get record evaluation result from booster #' Get record evaluation result from booster
...@@ -835,9 +835,9 @@ lgb.dump <- function(booster, num_iteration = NULL){ ...@@ -835,9 +835,9 @@ lgb.dump <- function(booster, num_iteration = NULL){
#' @param eval_name name of evaluation #' @param eval_name name of evaluation
#' @param iters iterations, NULL will return all #' @param iters iterations, NULL will return all
#' @param is_err TRUE will return evaluation error instead #' @param is_err TRUE will return evaluation error instead
#' #'
#' @return vector of evaluation result #' @return vector of evaluation result
#' #'
#' @examples #' @examples
#' library(lightgbm) #' library(lightgbm)
#' data(agaricus.train, package = "lightgbm") #' data(agaricus.train, package = "lightgbm")
...@@ -856,49 +856,49 @@ lgb.dump <- function(booster, num_iteration = NULL){ ...@@ -856,49 +856,49 @@ lgb.dump <- function(booster, num_iteration = NULL){
#' learning_rate = 1, #' learning_rate = 1,
#' early_stopping_rounds = 10) #' early_stopping_rounds = 10)
#' lgb.get.eval.result(model, "test", "l2") #' lgb.get.eval.result(model, "test", "l2")
#' #'
#' @rdname lgb.get.eval.result #' @rdname lgb.get.eval.result
#' @export #' @export
# Extract a recorded evaluation history from a trained booster.
# booster: lgb.Booster; data_name/eval_name: which dataset and metric;
# iters: iterations to keep (NULL = all); is_err: return eval errors instead.
lgb.get.eval.result <- function(booster, data_name, eval_name, iters = NULL, is_err = FALSE) {
  # Only lgb.Booster objects carry recorded evaluations
  if (!lgb.is.Booster(booster)) {
    stop("lgb.get.eval.result: Can only use ", sQuote("lgb.Booster"), " to get eval result")
  }
  # Both lookup keys must be character
  if (!is.character(data_name) || !is.character(eval_name)) {
    stop("lgb.get.eval.result: data_name and eval_name should be characters")
  }
  # The requested dataset must have been recorded
  if (is.null(booster$record_evals[[data_name]])) {
    stop("lgb.get.eval.result: wrong data name")
  }
  # The requested metric must have been recorded for that dataset
  if (is.null(booster$record_evals[[data_name]][[eval_name]])) {
    stop("lgb.get.eval.result: wrong eval name")
  }
  # Pick either the metric values or their standard errors
  record <- booster$record_evals[[data_name]][[eval_name]]
  result <- if (is_err) record$eval_err else record$eval
  # No iteration filter: return the full history
  if (is.null(iters)) {
    return(as.numeric(result))
  }
  # Shift requested iterations by the recorded starting iteration
  offset <- booster$record_evals$start_iter - 1
  wanted <- as.integer(iters) - offset
  # Return only the requested iterations
  as.numeric(result[wanted])
}
...@@ -393,11 +393,11 @@ Dataset <- R6::R6Class( ...@@ -393,11 +393,11 @@ Dataset <- R6::R6Class(
# Check for info name and handle # Check for info name and handle
if (is.null(private$info[[name]])) { if (is.null(private$info[[name]])) {
if (lgb.is.null.handle(private$handle)){ if (lgb.is.null.handle(private$handle)){
stop("Cannot perform getinfo before constructing Dataset.") stop("Cannot perform getinfo before constructing Dataset.")
} }
# Get field size of info # Get field size of info
info_len <- 0L info_len <- 0L
info_len <- lgb.call("LGBM_DatasetGetFieldSize_R", info_len <- lgb.call("LGBM_DatasetGetFieldSize_R",
...@@ -850,7 +850,7 @@ dimnames.lgb.Dataset <- function(x) { ...@@ -850,7 +850,7 @@ dimnames.lgb.Dataset <- function(x) {
#' #'
#' Get a new \code{lgb.Dataset} containing the specified rows of #' Get a new \code{lgb.Dataset} containing the specified rows of
#' original lgb.Dataset object #' original lgb.Dataset object
#' #'
#' @param dataset Object of class "lgb.Dataset" #' @param dataset Object of class "lgb.Dataset"
#' @param idxset a integer vector of indices of rows needed #' @param idxset a integer vector of indices of rows needed
#' @param ... other parameters (currently not used) #' @param ... other parameters (currently not used)
......
...@@ -24,11 +24,11 @@ CVBooster <- R6::R6Class( ...@@ -24,11 +24,11 @@ CVBooster <- R6::R6Class(
#' @param nfold the original dataset is randomly partitioned into \code{nfold} equal size subsamples. #' @param nfold the original dataset is randomly partitioned into \code{nfold} equal size subsamples.
#' @param label vector of response values. Should be provided only when data is an R-matrix. #' @param label vector of response values. Should be provided only when data is an R-matrix.
#' @param weight vector of weights, one per observation. If not NULL, will set to dataset
#' @param obj objective function, can be character or custom objective function. Examples include #' @param obj objective function, can be character or custom objective function. Examples include
#' \code{regression}, \code{regression_l1}, \code{huber}, #' \code{regression}, \code{regression_l1}, \code{huber},
#' \code{binary}, \code{lambdarank}, \code{multiclass}, \code{multiclass} #' \code{binary}, \code{lambdarank}, \code{multiclass}, \code{multiclass}
#' @param eval evaluation function, can be (list of) character or custom eval function #' @param eval evaluation function, can be (list of) character or custom eval function
#' @param record Boolean, TRUE will record iteration message to \code{booster$record_evals} #' @param record Boolean, TRUE will record iteration message to \code{booster$record_evals}
#' @param showsd \code{boolean}, whether to show standard deviation of cross validation #' @param showsd \code{boolean}, whether to show standard deviation of cross validation
#' @param stratified a \code{boolean} indicating whether sampling of folds should be stratified #' @param stratified a \code{boolean} indicating whether sampling of folds should be stratified
#' by the values of outcome labels. #' by the values of outcome labels.
...@@ -45,15 +45,15 @@ CVBooster <- R6::R6Class( ...@@ -45,15 +45,15 @@ CVBooster <- R6::R6Class(
#' \itemize{ #' \itemize{
#' \item{boosting}{Boosting type. \code{"gbdt"} or \code{"dart"}} #' \item{boosting}{Boosting type. \code{"gbdt"} or \code{"dart"}}
#' \item{num_leaves}{number of leaves in one tree. defaults to 127} #' \item{num_leaves}{number of leaves in one tree. defaults to 127}
#' \item{max_depth}{Limit the max depth for tree model. This is used to deal with #' \item{max_depth}{Limit the max depth for tree model. This is used to deal with
#' overfit when #data is small. Tree still grow by leaf-wise.} #' overfit when #data is small. Tree still grow by leaf-wise.}
#' \item{num_threads}{Number of threads for LightGBM. For the best speed, set this to #' \item{num_threads}{Number of threads for LightGBM. For the best speed, set this to
#' the number of real CPU cores, not the number of threads (most #' the number of real CPU cores, not the number of threads (most
#' CPU using hyper-threading to generate 2 threads per CPU core).} #' CPU using hyper-threading to generate 2 threads per CPU core).}
#' } #' }
#' #'
#' @return a trained model \code{lgb.CVBooster}. #' @return a trained model \code{lgb.CVBooster}.
#' #'
#' @examples #' @examples
#' library(lightgbm) #' library(lightgbm)
#' data(agaricus.train, package = "lightgbm") #' data(agaricus.train, package = "lightgbm")
...@@ -70,7 +70,7 @@ CVBooster <- R6::R6Class( ...@@ -70,7 +70,7 @@ CVBooster <- R6::R6Class(
#' @export #' @export
lgb.cv <- function(params = list(), lgb.cv <- function(params = list(),
data, data,
nrounds = 10, nrounds = 10,
nfold = 3, nfold = 3,
label = NULL, label = NULL,
weight = NULL, weight = NULL,
...@@ -88,7 +88,7 @@ lgb.cv <- function(params = list(), ...@@ -88,7 +88,7 @@ lgb.cv <- function(params = list(),
early_stopping_rounds = NULL, early_stopping_rounds = NULL,
callbacks = list(), callbacks = list(),
...) { ...) {
# Setup temporary variables # Setup temporary variables
addiction_params <- list(...) addiction_params <- list(...)
params <- append(params, addiction_params) params <- append(params, addiction_params)
...@@ -101,31 +101,31 @@ lgb.cv <- function(params = list(), ...@@ -101,31 +101,31 @@ lgb.cv <- function(params = list(),
if (nrounds <= 0) { if (nrounds <= 0) {
stop("nrounds should be greater than zero") stop("nrounds should be greater than zero")
} }
# Check for objective (function or not) # Check for objective (function or not)
if (is.function(params$objective)) { if (is.function(params$objective)) {
fobj <- params$objective fobj <- params$objective
params$objective <- "NONE" params$objective <- "NONE"
} }
# Check for loss (function or not) # Check for loss (function or not)
if (is.function(eval)) { if (is.function(eval)) {
feval <- eval feval <- eval
} }
# Check for parameters # Check for parameters
lgb.check.params(params) lgb.check.params(params)
# Init predictor to empty # Init predictor to empty
predictor <- NULL predictor <- NULL
# Check for boosting from a trained model # Check for boosting from a trained model
if (is.character(init_model)) { if (is.character(init_model)) {
predictor <- Predictor$new(init_model) predictor <- Predictor$new(init_model)
} else if (lgb.is.Booster(init_model)) { } else if (lgb.is.Booster(init_model)) {
predictor <- init_model$to_predictor() predictor <- init_model$to_predictor()
} }
# Set the iteration to start from / end to (and check for boosting from a trained model, again) # Set the iteration to start from / end to (and check for boosting from a trained model, again)
begin_iteration <- 1 begin_iteration <- 1
if (!is.null(predictor)) { if (!is.null(predictor)) {
...@@ -138,7 +138,7 @@ lgb.cv <- function(params = list(), ...@@ -138,7 +138,7 @@ lgb.cv <- function(params = list(),
} else { } else {
end_iteration <- begin_iteration + nrounds - 1 end_iteration <- begin_iteration + nrounds - 1
} }
# Check for training dataset type correctness # Check for training dataset type correctness
if (!lgb.is.Dataset(data)) { if (!lgb.is.Dataset(data)) {
if (is.null(label)) { if (is.null(label)) {
...@@ -146,49 +146,49 @@ lgb.cv <- function(params = list(), ...@@ -146,49 +146,49 @@ lgb.cv <- function(params = list(),
} }
data <- lgb.Dataset(data, label = label) data <- lgb.Dataset(data, label = label)
} }
# Check for weights # Check for weights
if (!is.null(weight)) { if (!is.null(weight)) {
data$setinfo("weight", weight) data$setinfo("weight", weight)
} }
# Update parameters with parsed parameters # Update parameters with parsed parameters
data$update_params(params) data$update_params(params)
# Create the predictor set # Create the predictor set
data$.__enclos_env__$private$set_predictor(predictor) data$.__enclos_env__$private$set_predictor(predictor)
# Write column names # Write column names
if (!is.null(colnames)) { if (!is.null(colnames)) {
data$set_colnames(colnames) data$set_colnames(colnames)
} }
# Write categorical features # Write categorical features
if (!is.null(categorical_feature)) { if (!is.null(categorical_feature)) {
data$set_categorical_feature(categorical_feature) data$set_categorical_feature(categorical_feature)
} }
# Construct datasets, if needed # Construct datasets, if needed
data$construct() data$construct()
# Check for folds # Check for folds
if (!is.null(folds)) { if (!is.null(folds)) {
# Check for list of folds or for single value # Check for list of folds or for single value
if (!is.list(folds) || length(folds) < 2) { if (!is.list(folds) || length(folds) < 2) {
stop(sQuote("folds"), " must be a list with 2 or more elements that are vectors of indices for each CV-fold") stop(sQuote("folds"), " must be a list with 2 or more elements that are vectors of indices for each CV-fold")
} }
# Set number of folds # Set number of folds
nfold <- length(folds) nfold <- length(folds)
} else { } else {
# Check fold value # Check fold value
if (nfold <= 1) { if (nfold <= 1) {
stop(sQuote("nfold"), " must be > 1") stop(sQuote("nfold"), " must be > 1")
} }
# Create folds # Create folds
folds <- generate.cv.folds(nfold, folds <- generate.cv.folds(nfold,
nrow(data), nrow(data),
...@@ -196,19 +196,19 @@ lgb.cv <- function(params = list(), ...@@ -196,19 +196,19 @@ lgb.cv <- function(params = list(),
getinfo(data, "label"), getinfo(data, "label"),
getinfo(data, "group"), getinfo(data, "group"),
params) params)
} }
# Add printing log callback # Add printing log callback
if (verbose > 0 && eval_freq > 0) { if (verbose > 0 && eval_freq > 0) {
callbacks <- add.cb(callbacks, cb.print.evaluation(eval_freq)) callbacks <- add.cb(callbacks, cb.print.evaluation(eval_freq))
} }
# Add evaluation log callback # Add evaluation log callback
if (record) { if (record) {
callbacks <- add.cb(callbacks, cb.record.evaluation()) callbacks <- add.cb(callbacks, cb.record.evaluation())
} }
# Check for early stopping passed as parameter when adding early stopping callback # Check for early stopping passed as parameter when adding early stopping callback
early_stop <- c("early_stopping_round", "early_stopping_rounds", "early_stopping") early_stop <- c("early_stopping_round", "early_stopping_rounds", "early_stopping")
if (any(names(params) %in% early_stop)) { if (any(names(params) %in% early_stop)) {
...@@ -222,10 +222,10 @@ lgb.cv <- function(params = list(), ...@@ -222,10 +222,10 @@ lgb.cv <- function(params = list(),
} }
} }
} }
# Categorize callbacks # Categorize callbacks
cb <- categorize.callbacks(callbacks) cb <- categorize.callbacks(callbacks)
# Construct booster using a list apply, check if requires group or not # Construct booster using a list apply, check if requires group or not
if (!is.list(folds[[1]])) { if (!is.list(folds[[1]])) {
bst_folds <- lapply(seq_along(folds), function(k) { bst_folds <- lapply(seq_along(folds), function(k) {
...@@ -254,107 +254,107 @@ lgb.cv <- function(params = list(), ...@@ -254,107 +254,107 @@ lgb.cv <- function(params = list(),
list(booster = booster) list(booster = booster)
}) })
} }
# Create new booster # Create new booster
cv_booster <- CVBooster$new(bst_folds) cv_booster <- CVBooster$new(bst_folds)
# Callback env # Callback env
env <- CB_ENV$new() env <- CB_ENV$new()
env$model <- cv_booster env$model <- cv_booster
env$begin_iteration <- begin_iteration env$begin_iteration <- begin_iteration
env$end_iteration <- end_iteration env$end_iteration <- end_iteration
# Start training model using number of iterations to start and end with # Start training model using number of iterations to start and end with
for (i in seq.int(from = begin_iteration, to = end_iteration)) { for (i in seq.int(from = begin_iteration, to = end_iteration)) {
# Overwrite iteration in environment # Overwrite iteration in environment
env$iteration <- i env$iteration <- i
env$eval_list <- list() env$eval_list <- list()
# Loop through "pre_iter" element # Loop through "pre_iter" element
for (f in cb$pre_iter) { for (f in cb$pre_iter) {
f(env) f(env)
} }
# Update one boosting iteration # Update one boosting iteration
msg <- lapply(cv_booster$boosters, function(fd) { msg <- lapply(cv_booster$boosters, function(fd) {
fd$booster$update(fobj = fobj) fd$booster$update(fobj = fobj)
fd$booster$eval_valid(feval = feval) fd$booster$eval_valid(feval = feval)
}) })
# Prepare collection of evaluation results # Prepare collection of evaluation results
merged_msg <- lgb.merge.cv.result(msg) merged_msg <- lgb.merge.cv.result(msg)
# Write evaluation result in environment # Write evaluation result in environment
env$eval_list <- merged_msg$eval_list env$eval_list <- merged_msg$eval_list
# Check for standard deviation requirement # Check for standard deviation requirement
if(showsd) { if(showsd) {
env$eval_err_list <- merged_msg$eval_err_list env$eval_err_list <- merged_msg$eval_err_list
} }
# Loop through env # Loop through env
for (f in cb$post_iter) { for (f in cb$post_iter) {
f(env) f(env)
} }
# Check for early stopping and break if needed # Check for early stopping and break if needed
if (env$met_early_stop) break if (env$met_early_stop) break
} }
# Return booster # Return booster
return(cv_booster) return(cv_booster)
} }
# Generates random (stratified if needed) CV folds # Generates random (stratified if needed) CV folds
generate.cv.folds <- function(nfold, nrows, stratified, label, group, params) { generate.cv.folds <- function(nfold, nrows, stratified, label, group, params) {
# Check for group existence # Check for group existence
if (is.null(group)) { if (is.null(group)) {
# Shuffle # Shuffle
rnd_idx <- sample.int(nrows) rnd_idx <- sample.int(nrows)
# Request stratified folds # Request stratified folds
if (isTRUE(stratified) && params$objective %in% c("binary", "multiclass") && length(label) == length(rnd_idx)) { if (isTRUE(stratified) && params$objective %in% c("binary", "multiclass") && length(label) == length(rnd_idx)) {
y <- label[rnd_idx] y <- label[rnd_idx]
y <- factor(y) y <- factor(y)
folds <- lgb.stratified.folds(y, nfold) folds <- lgb.stratified.folds(y, nfold)
} else { } else {
# Make simple non-stratified folds # Make simple non-stratified folds
folds <- list() folds <- list()
# Loop through each fold # Loop through each fold
for (i in seq_len(nfold)) { for (i in seq_len(nfold)) {
kstep <- length(rnd_idx) %/% (nfold - i + 1) kstep <- length(rnd_idx) %/% (nfold - i + 1)
folds[[i]] <- rnd_idx[seq_len(kstep)] folds[[i]] <- rnd_idx[seq_len(kstep)]
rnd_idx <- rnd_idx[-seq_len(kstep)] rnd_idx <- rnd_idx[-seq_len(kstep)]
} }
} }
} else { } else {
# When doing group, stratified is not possible (only random selection) # When doing group, stratified is not possible (only random selection)
if (nfold > length(group)) { if (nfold > length(group)) {
stop("\n\tYou requested too many folds for the number of available groups.\n") stop("\n\tYou requested too many folds for the number of available groups.\n")
} }
# Degroup the groups # Degroup the groups
ungrouped <- inverse.rle(list(lengths = group, values = seq_along(group))) ungrouped <- inverse.rle(list(lengths = group, values = seq_along(group)))
# Can't stratify, shuffle # Can't stratify, shuffle
rnd_idx <- sample.int(length(group)) rnd_idx <- sample.int(length(group))
# Make simple non-stratified folds # Make simple non-stratified folds
folds <- list() folds <- list()
# Loop through each fold # Loop through each fold
for (i in seq_len(nfold)) { for (i in seq_len(nfold)) {
kstep <- length(rnd_idx) %/% (nfold - i + 1) kstep <- length(rnd_idx) %/% (nfold - i + 1)
...@@ -362,12 +362,12 @@ generate.cv.folds <- function(nfold, nrows, stratified, label, group, params) { ...@@ -362,12 +362,12 @@ generate.cv.folds <- function(nfold, nrows, stratified, label, group, params) {
group = rnd_idx[seq_len(kstep)]) group = rnd_idx[seq_len(kstep)])
rnd_idx <- rnd_idx[-seq_len(kstep)] rnd_idx <- rnd_idx[-seq_len(kstep)]
} }
} }
# Return folds # Return folds
return(folds) return(folds)
} }
# Creates CV folds stratified by the values of y. # Creates CV folds stratified by the values of y.
...@@ -375,7 +375,7 @@ generate.cv.folds <- function(nfold, nrows, stratified, label, group, params) { ...@@ -375,7 +375,7 @@ generate.cv.folds <- function(nfold, nrows, stratified, label, group, params) {
# by always returning an unnamed list of fold indices. # by always returning an unnamed list of fold indices.
#' @importFrom stats quantile #' @importFrom stats quantile
lgb.stratified.folds <- function(y, k = 10) { lgb.stratified.folds <- function(y, k = 10) {
## Group the numeric data based on their magnitudes ## Group the numeric data based on their magnitudes
## and sample within those groups. ## and sample within those groups.
## When the number of samples is low, we may have ## When the number of samples is low, we may have
...@@ -385,51 +385,51 @@ lgb.stratified.folds <- function(y, k = 10) { ...@@ -385,51 +385,51 @@ lgb.stratified.folds <- function(y, k = 10) {
## At most, we will use quantiles. If the sample ## At most, we will use quantiles. If the sample
## is too small, we just do regular unstratified CV ## is too small, we just do regular unstratified CV
if (is.numeric(y)) { if (is.numeric(y)) {
cuts <- length(y) %/% k cuts <- length(y) %/% k
if (cuts < 2) { cuts <- 2 } if (cuts < 2) { cuts <- 2 }
if (cuts > 5) { cuts <- 5 } if (cuts > 5) { cuts <- 5 }
y <- cut(y, y <- cut(y,
unique(stats::quantile(y, probs = seq.int(0, 1, length.out = cuts))), unique(stats::quantile(y, probs = seq.int(0, 1, length.out = cuts))),
include.lowest = TRUE) include.lowest = TRUE)
} }
if (k < length(y)) { if (k < length(y)) {
## Reset levels so that the possible levels and ## Reset levels so that the possible levels and
## the levels in the vector are the same ## the levels in the vector are the same
y <- factor(as.character(y)) y <- factor(as.character(y))
numInClass <- table(y) numInClass <- table(y)
foldVector <- vector(mode = "integer", length(y)) foldVector <- vector(mode = "integer", length(y))
## For each class, balance the fold allocation as far ## For each class, balance the fold allocation as far
## as possible, then resample the remainder. ## as possible, then resample the remainder.
## The final assignment of folds is also randomized. ## The final assignment of folds is also randomized.
for (i in seq_along(numInClass)) { for (i in seq_along(numInClass)) {
## Create a vector of integers from 1:k as many times as possible without ## Create a vector of integers from 1:k as many times as possible without
## going over the number of samples in the class. Note that if the number ## going over the number of samples in the class. Note that if the number
## of samples in a class is less than k, nothing is producd here. ## of samples in a class is less than k, nothing is producd here.
seqVector <- rep(seq_len(k), numInClass[i] %/% k) seqVector <- rep(seq_len(k), numInClass[i] %/% k)
## Add enough random integers to get length(seqVector) == numInClass[i] ## Add enough random integers to get length(seqVector) == numInClass[i]
if (numInClass[i] %% k > 0) { if (numInClass[i] %% k > 0) {
seqVector <- c(seqVector, sample.int(k, numInClass[i] %% k)) seqVector <- c(seqVector, sample.int(k, numInClass[i] %% k))
} }
## Shuffle the integers for fold assignment and assign to this classes's data ## Shuffle the integers for fold assignment and assign to this classes's data
foldVector[y == dimnames(numInClass)$y[i]] <- sample(seqVector) foldVector[y == dimnames(numInClass)$y[i]] <- sample(seqVector)
} }
} else { } else {
foldVector <- seq(along = y) foldVector <- seq(along = y)
} }
# Return data # Return data
out <- split(seq(along = y), foldVector) out <- split(seq(along = y), foldVector)
names(out) <- NULL names(out) <- NULL
...@@ -437,53 +437,53 @@ lgb.stratified.folds <- function(y, k = 10) { ...@@ -437,53 +437,53 @@ lgb.stratified.folds <- function(y, k = 10) {
} }
lgb.merge.cv.result <- function(msg, showsd = TRUE) {
  # Merge per-fold evaluation results into a single mean (and optionally
  # standard deviation) per metric.
  #
  # msg:    list with one element per CV fold; each element is a list of metric
  #         records, each record a list carrying at least a $value field.
  # showsd: when TRUE, also compute the per-metric standard deviation across
  #         folds.
  #
  # Returns a list with:
  #   eval_list     - the first fold's metric records with $value replaced by
  #                   the cross-fold mean
  #   eval_err_list - list of per-metric population standard deviations across
  #                   folds, or NULL when showsd is FALSE
  if (length(msg) == 0) {
    stop("lgb.cv: size of cv result error")
  }
  # Number of metrics, taken from the first fold
  eval_len <- length(msg[[1]])
  if (eval_len == 0) {
    stop("lgb.cv: should provide at least one metric for CV")
  }
  # One numeric vector per metric: that metric's value in every fold
  # (vapply enforces one numeric value per fold and avoids the
  # as.numeric(lapply(...)) round-trip of the previous implementation)
  eval_result <- lapply(seq_len(eval_len), function(j) {
    vapply(msg, function(fold_msg) fold_msg[[j]]$value, numeric(1))
  })
  # Reuse the first fold's records as a template, overwriting each value
  # with the cross-fold mean
  ret_eval <- msg[[1]]
  for (j in seq_len(eval_len)) {
    ret_eval[[j]]$value <- mean(eval_result[[j]])
  }
  ret_eval_err <- NULL
  if (showsd) {
    # Population standard deviation per metric. Computed as the root mean of
    # squared deviations rather than sqrt(E[x^2] - E[x]^2): the latter can
    # suffer floating-point cancellation, yielding a tiny negative variance
    # and hence NaN from sqrt().
    ret_eval_err <- lapply(eval_result, function(vals) {
      sqrt(mean((vals - mean(vals)) ^ 2))
    })
  }
  # Return means and (optional) errors
  list(eval_list = ret_eval,
       eval_err_list = ret_eval_err)
}
#' Compute feature importance in a model #' Compute feature importance in a model
#' #'
#' Creates a \code{data.table} of feature importances in a model. #' Creates a \code{data.table} of feature importances in a model.
#' #'
#' @param model object of class \code{lgb.Booster}. #' @param model object of class \code{lgb.Booster}.
#' @param percentage whether to show importance in relative percentage. #' @param percentage whether to show importance in relative percentage.
#' #'
#' @return #' @return
#' #'
#' For a tree model, a \code{data.table} with the following columns: #' For a tree model, a \code{data.table} with the following columns:
#' \itemize{ #' \itemize{
#' \item \code{Feature} Feature names in the model. #' \item \code{Feature} Feature names in the model.
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
#' \item \code{Cover} The number of observation related to this feature. #' \item \code{Cover} The number of observation related to this feature.
#' \item \code{Frequency} The number of times a feature splited in trees. #' \item \code{Frequency} The number of times a feature splited in trees.
#' } #' }
#' #'
#' @examples #' @examples
#' library(lightgbm) #' library(lightgbm)
#' data(agaricus.train, package = "lightgbm") #' data(agaricus.train, package = "lightgbm")
...@@ -29,20 +29,20 @@ ...@@ -29,20 +29,20 @@
#' #'
#' tree_imp1 <- lgb.importance(model, percentage = TRUE) #' tree_imp1 <- lgb.importance(model, percentage = TRUE)
#' tree_imp2 <- lgb.importance(model, percentage = FALSE) #' tree_imp2 <- lgb.importance(model, percentage = FALSE)
#' #'
#' @importFrom magrittr %>% %T>% extract #' @importFrom magrittr %>% %T>% extract
#' @importFrom data.table := #' @importFrom data.table :=
#' @export #' @export
lgb.importance <- function(model, percentage = TRUE) { lgb.importance <- function(model, percentage = TRUE) {
# Check if model is a lightgbm model # Check if model is a lightgbm model
if (!inherits(model, "lgb.Booster")) { if (!inherits(model, "lgb.Booster")) {
stop("'model' has to be an object of class lgb.Booster") stop("'model' has to be an object of class lgb.Booster")
} }
# Setup importance # Setup importance
tree_dt <- lgb.model.dt.tree(model) tree_dt <- lgb.model.dt.tree(model)
# Extract elements # Extract elements
tree_imp <- tree_dt %>% tree_imp <- tree_dt %>%
magrittr::extract(., magrittr::extract(.,
...@@ -51,15 +51,15 @@ lgb.importance <- function(model, percentage = TRUE) { ...@@ -51,15 +51,15 @@ lgb.importance <- function(model, percentage = TRUE) {
by = "split_feature") %T>% by = "split_feature") %T>%
data.table::setnames(., old = "split_feature", new = "Feature") %>% data.table::setnames(., old = "split_feature", new = "Feature") %>%
magrittr::extract(., i = order(Gain, decreasing = TRUE)) magrittr::extract(., i = order(Gain, decreasing = TRUE))
# Check if relative values are requested # Check if relative values are requested
if (percentage) { if (percentage) {
tree_imp[, ":="(Gain = Gain / sum(Gain), tree_imp[, ":="(Gain = Gain / sum(Gain),
Cover = Cover / sum(Cover), Cover = Cover / sum(Cover),
Frequency = Frequency / sum(Frequency))] Frequency = Frequency / sum(Frequency))]
} }
# Return importance table # Return importance table
return(tree_imp) return(tree_imp)
} }
#' Compute feature contribution of prediction #' Compute feature contribution of prediction
#' #'
#' Computes feature contribution components of rawscore prediction. #' Computes feature contribution components of rawscore prediction.
#' #'
#' @param model object of class \code{lgb.Booster}. #' @param model object of class \code{lgb.Booster}.
#' @param data a matrix object or a dgCMatrix object. #' @param data a matrix object or a dgCMatrix object.
#' @param idxset a integer vector of indices of rows needed. #' @param idxset a integer vector of indices of rows needed.
#' @param num_iteration number of iteration want to predict with, NULL or <= 0 means use best iteration. #' @param num_iteration number of iteration want to predict with, NULL or <= 0 means use best iteration.
#' #'
#' @return #' @return
#' #'
#' For regression, binary classification and lambdarank model, a \code{list} of \code{data.table} with the following columns: #' For regression, binary classification and lambdarank model, a \code{list} of \code{data.table} with the following columns:
#' \itemize{ #' \itemize{
#' \item \code{Feature} Feature names in the model. #' \item \code{Feature} Feature names in the model.
#' \item \code{Contribution} The total contribution of this feature's splits. #' \item \code{Contribution} The total contribution of this feature's splits.
#' } #' }
#' For multiclass classification, a \code{list} of \code{data.table} with the Feature column and Contribution columns to each class. #' For multiclass classification, a \code{list} of \code{data.table} with the Feature column and Contribution columns to each class.
#' #'
#' @examples #' @examples
#' Sigmoid <- function(x) 1 / (1 + exp(-x)) #' Sigmoid <- function(x) 1 / (1 + exp(-x))
#' Logit <- function(x) log(x / (1 - x)) #' Logit <- function(x) log(x / (1 - x))
...@@ -25,7 +25,7 @@ ...@@ -25,7 +25,7 @@
#' setinfo(dtrain, "init_score", rep(Logit(mean(train$label)), length(train$label))) #' setinfo(dtrain, "init_score", rep(Logit(mean(train$label)), length(train$label)))
#' data(agaricus.test, package = "lightgbm") #' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test #' test <- agaricus.test
#' #'
#' params <- list( #' params <- list(
#' objective = "binary" #' objective = "binary"
#' , learning_rate = 0.01 #' , learning_rate = 0.01
...@@ -35,9 +35,9 @@ ...@@ -35,9 +35,9 @@
#' , min_sum_hessian_in_leaf = 1 #' , min_sum_hessian_in_leaf = 1
#' ) #' )
#' model <- lgb.train(params, dtrain, 20) #' model <- lgb.train(params, dtrain, 20)
#' #'
#' tree_interpretation <- lgb.interprete(model, test$data, 1:5) #' tree_interpretation <- lgb.interprete(model, test$data, 1:5)
#' #'
#' @importFrom data.table as.data.table #' @importFrom data.table as.data.table
#' @importFrom magrittr %>% %T>% #' @importFrom magrittr %>% %T>%
#' @export #' @export
...@@ -45,16 +45,16 @@ lgb.interprete <- function(model, ...@@ -45,16 +45,16 @@ lgb.interprete <- function(model,
data, data,
idxset, idxset,
num_iteration = NULL) { num_iteration = NULL) {
# Get tree model # Get tree model
tree_dt <- lgb.model.dt.tree(model, num_iteration) tree_dt <- lgb.model.dt.tree(model, num_iteration)
# Check number of classes # Check number of classes
num_class <- model$.__enclos_env__$private$num_class num_class <- model$.__enclos_env__$private$num_class
# Get vector list # Get vector list
tree_interpretation_dt_list <- vector(mode = "list", length = length(idxset)) tree_interpretation_dt_list <- vector(mode = "list", length = length(idxset))
# Get parsed predictions of data # Get parsed predictions of data
leaf_index_mat_list <- model$predict(data[idxset, , drop = FALSE], leaf_index_mat_list <- model$predict(data[idxset, , drop = FALSE],
num_iteration = num_iteration, num_iteration = num_iteration,
...@@ -62,63 +62,63 @@ lgb.interprete <- function(model, ...@@ -62,63 +62,63 @@ lgb.interprete <- function(model,
t(.) %>% t(.) %>%
data.table::as.data.table(.) %>% data.table::as.data.table(.) %>%
lapply(., FUN = function(x) matrix(x, ncol = num_class, byrow = TRUE)) lapply(., FUN = function(x) matrix(x, ncol = num_class, byrow = TRUE))
# Get list of trees # Get list of trees
tree_index_mat_list <- lapply(leaf_index_mat_list, tree_index_mat_list <- lapply(leaf_index_mat_list,
FUN = function(x) matrix(seq_len(length(x)) - 1, ncol = num_class, byrow = TRUE)) FUN = function(x) matrix(seq_len(length(x)) - 1, ncol = num_class, byrow = TRUE))
# Sequence over idxset # Sequence over idxset
for (i in seq_along(idxset)) { for (i in seq_along(idxset)) {
tree_interpretation_dt_list[[i]] <- single.row.interprete(tree_dt, num_class, tree_index_mat_list[[i]], leaf_index_mat_list[[i]]) tree_interpretation_dt_list[[i]] <- single.row.interprete(tree_dt, num_class, tree_index_mat_list[[i]], leaf_index_mat_list[[i]])
} }
# Return interpretation list # Return interpretation list
return(tree_interpretation_dt_list) return(tree_interpretation_dt_list)
} }
#' @importFrom data.table data.table
# Trace the path from one leaf of one tree back to the root, and return the
# contribution of each split encountered along that path.
single.tree.interprete <- function(tree_dt,
                                   tree_id,
                                   leaf_id) {
  # Restrict the model table to the rows of the requested tree
  single_tree_dt <- tree_dt[tree_index == tree_id, ]
  # The single leaf record we start from: its parent node and output value
  leaf_dt <- single_tree_dt[leaf_index == leaf_id, .(leaf_index, leaf_parent, leaf_value)]
  # All internal (split) nodes of the tree
  node_dt <- single_tree_dt[!is.na(split_index), .(split_index, split_feature, node_parent, internal_value)]
  # Accumulators, prepended to while climbing leaf -> root so they end up
  # ordered root -> leaf
  feature_seq <- character(0)
  value_seq <- numeric(0)
  # Recursively climb from a node to the root; each call prepends the current
  # value (and, for internal nodes, the split feature) to the accumulators
  # via <<- in the enclosing environment.
  leaf_to_root <- function(parent_id, current_value) {
    # Store value
    value_seq <<- c(current_value, value_seq)
    # An NA parent id means the root has been reached; stop recursing
    if (!is.na(parent_id)) {
      # Non-NA means an existing internal node: record its split feature and
      # keep climbing through its own parent
      this_node <- node_dt[split_index == parent_id, ]
      feature_seq <<- c(this_node[["split_feature"]], feature_seq)
      leaf_to_root(this_node[["node_parent"]], this_node[["internal_value"]])
    }
  }
  # Perform the leaf-to-root walk starting from the chosen leaf
  leaf_to_root(leaf_dt[["leaf_parent"]], leaf_dt[["leaf_value"]])
  # Contribution of each split = change in value between consecutive nodes on
  # the root -> leaf path (diff of value_seq), paired with the split features
  data.table::data.table(Feature = feature_seq, Contribution = diff.default(value_seq))
}
#' @importFrom data.table rbindlist #' @importFrom data.table rbindlist
...@@ -126,7 +126,7 @@ single.tree.interprete <- function(tree_dt, ...@@ -126,7 +126,7 @@ single.tree.interprete <- function(tree_dt,
multiple.tree.interprete <- function(tree_dt, multiple.tree.interprete <- function(tree_dt,
tree_index, tree_index,
leaf_index) { leaf_index) {
# Apply each trees # Apply each trees
mapply(single.tree.interprete, mapply(single.tree.interprete,
tree_id = tree_index, leaf_id = leaf_index, tree_id = tree_index, leaf_id = leaf_index,
...@@ -135,52 +135,52 @@ multiple.tree.interprete <- function(tree_dt, ...@@ -135,52 +135,52 @@ multiple.tree.interprete <- function(tree_dt,
data.table::rbindlist(., use.names = TRUE) %>% data.table::rbindlist(., use.names = TRUE) %>%
magrittr::extract(., j = .(Contribution = sum(Contribution)), by = "Feature") %>% magrittr::extract(., j = .(Contribution = sum(Contribution)), by = "Feature") %>%
magrittr::extract(., i = order(abs(Contribution), decreasing = TRUE)) magrittr::extract(., i = order(abs(Contribution), decreasing = TRUE))
} }
#' @importFrom data.table set setnames
# Build the per-feature contribution table for a single observation, combining
# the interpretations of all trees (one contribution column per class when the
# model is multiclass).
single.row.interprete <- function(tree_dt, num_class, tree_index_mat, leaf_index_mat) {
  # One contribution table per class
  tree_interpretation <- vector(mode = "list", length = num_class)
  # Interpret the trees of each class separately
  for (i in seq_len(num_class)) {
    tree_interpretation[[i]] <- multiple.tree.interprete(tree_dt, tree_index_mat[,i], leaf_index_mat[,i]) %T>% {
      # For multiclass models, rename the contribution column to "Class <i-1>"
      # so the per-class tables can be merged side by side below
      if (num_class > 1) {
        data.table::setnames(., old = "Contribution", new = paste("Class", i - 1))
      }
    }
  }
  # Check for number of classes larger than 1
  if (num_class == 1) {
    # Binary/regression: the single per-class table is the result
    tree_interpretation_dt <- tree_interpretation[[1]]
  } else {
    # Multiclass: outer-merge all per-class tables on Feature
    tree_interpretation_dt <- Reduce(f = function(x, y) merge(x, y, by = "Feature", all = TRUE),
                                     x = tree_interpretation)
    # Features absent from some class's trees come out of the merge as NA;
    # replace those with a zero contribution, in place, column by column
    # (column 1 is Feature, hence the loop starts at 2)
    for (j in 2:ncol(tree_interpretation_dt)) {
      data.table::set(tree_interpretation_dt,
                      i = which(is.na(tree_interpretation_dt[[j]])),
                      j = j,
                      value = 0)
    }
  }
  # Return the merged interpretation table
  return(tree_interpretation_dt)
}
#' Data preparator for LightGBM datasets (numeric) #' Data preparator for LightGBM datasets (numeric)
#' #'
#' Attempts to prepare a clean dataset to prepare to put in a lgb.Dataset. Factors and characters are converted to numeric without integers. Please use \code{lgb.prepare_rules} if you want to apply this transformation to other datasets. #' Attempts to prepare a clean dataset to prepare to put in a lgb.Dataset. Factors and characters are converted to numeric without integers. Please use \code{lgb.prepare_rules} if you want to apply this transformation to other datasets.
#' #'
#' @param data A data.frame or data.table to prepare. #' @param data A data.frame or data.table to prepare.
#' #'
#' @return The cleaned dataset. It must be converted to a matrix format (\code{as.matrix}) for input in lgb.Dataset. #' @return The cleaned dataset. It must be converted to a matrix format (\code{as.matrix}) for input in lgb.Dataset.
#' #'
#' @examples #' @examples
#' library(lightgbm) #' library(lightgbm)
#' data(iris) #' data(iris)
#' #'
#' str(iris) #' str(iris)
#' # 'data.frame': 150 obs. of 5 variables: #' # 'data.frame': 150 obs. of 5 variables:
#' # $ Sepal.Length: num 5.1 4.9 4.7 4.6 5 5.4 4.6 5 4.4 4.9 ... #' # $ Sepal.Length: num 5.1 4.9 4.7 4.6 5 5.4 4.6 5 4.4 4.9 ...
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
#' # $ Petal.Length: num 1.4 1.4 1.3 1.5 1.4 1.7 1.4 1.5 1.4 1.5 ... #' # $ Petal.Length: num 1.4 1.4 1.3 1.5 1.4 1.7 1.4 1.5 1.4 1.5 ...
#' # $ Petal.Width : num 0.2 0.2 0.2 0.2 0.2 0.4 0.3 0.2 0.2 0.1 ... #' # $ Petal.Width : num 0.2 0.2 0.2 0.2 0.2 0.4 0.3 0.2 0.2 0.1 ...
#' # $ Species : Factor w/ 3 levels "setosa","versicolor",..: 1 1 1 1 ... #' # $ Species : Factor w/ 3 levels "setosa","versicolor",..: 1 1 1 1 ...
#' #'
#' str(lgb.prepare(data = iris)) # Convert all factors/chars to numeric #' str(lgb.prepare(data = iris)) # Convert all factors/chars to numeric
#' # 'data.frame': 150 obs. of 5 variables: #' # 'data.frame': 150 obs. of 5 variables:
#' # $ Sepal.Length: num 5.1 4.9 4.7 4.6 5 5.4 4.6 5 4.4 4.9 ... #' # $ Sepal.Length: num 5.1 4.9 4.7 4.6 5 5.4 4.6 5 4.4 4.9 ...
...@@ -25,7 +25,7 @@ ...@@ -25,7 +25,7 @@
#' # $ Petal.Length: num 1.4 1.4 1.3 1.5 1.4 1.7 1.4 1.5 1.4 1.5 ... #' # $ Petal.Length: num 1.4 1.4 1.3 1.5 1.4 1.7 1.4 1.5 1.4 1.5 ...
#' # $ Petal.Width : num 0.2 0.2 0.2 0.2 0.2 0.4 0.3 0.2 0.2 0.1 ... #' # $ Petal.Width : num 0.2 0.2 0.2 0.2 0.2 0.4 0.3 0.2 0.2 0.1 ...
#' # $ Species : num 1 1 1 1 1 1 1 1 1 1 ... #' # $ Species : num 1 1 1 1 1 1 1 1 1 1 ...
#' #'
#' # When lightgbm package is installed, and you do not want to load it #' # When lightgbm package is installed, and you do not want to load it
#' # You can still use the function! #' # You can still use the function!
#' lgb.unloader() #' lgb.unloader()
...@@ -36,57 +36,57 @@ ...@@ -36,57 +36,57 @@
#' # $ Petal.Length: num 1.4 1.4 1.3 1.5 1.4 1.7 1.4 1.5 1.4 1.5 ... #' # $ Petal.Length: num 1.4 1.4 1.3 1.5 1.4 1.7 1.4 1.5 1.4 1.5 ...
#' # $ Petal.Width : num 0.2 0.2 0.2 0.2 0.2 0.4 0.3 0.2 0.2 0.1 ... #' # $ Petal.Width : num 0.2 0.2 0.2 0.2 0.2 0.4 0.3 0.2 0.2 0.1 ...
#' # $ Species : num 1 1 1 1 1 1 1 1 1 1 ... #' # $ Species : num 1 1 1 1 1 1 1 1 1 1 ...
#' #'
#' @export #' @export
lgb.prepare <- function(data) {
  # Coerce every character and factor column of `data` to numeric so the
  # result can be fed through as.matrix() into lgb.Dataset. Characters are
  # first turned into factors, then into their numeric level codes.
  #
  # data: a data.frame or data.table. NOTE: a data.table is modified by
  #       reference (data.table `:=` semantics); a data.frame is copied.
  #
  # Returns the cleaned dataset with character/factor columns numeric.
  if (inherits(data, "data.table")) {
    # data.table does not behave like data.frame, so it gets its own
    # by-reference routine
    list_classes <- sapply(data, class)
    # Characters go straight to numeric factor codes, in place
    is_char <- which(list_classes == "character")
    if (length(is_char) > 0) {
      data[, (is_char) := lapply(.SD, function(x) {as.numeric(as.factor(x))}), .SDcols = is_char]
    }
    # Factors to numeric codes; re-touching the already-converted character
    # columns is a harmless no-op (as.numeric on numeric)
    is_fact <- c(which(list_classes == "factor"), is_char)
    if (length(is_fact) > 0) {
      data[, (is_fact) := lapply(.SD, function(x) {as.numeric(x)}), .SDcols = is_fact]
    }
  } else if (inherits(data, "data.frame")) {
    # Default routine (data.frame): copy-on-modify column replacement
    list_classes <- sapply(data, class)
    # Convert characters via factors to their numeric codes
    is_char <- which(list_classes == "character")
    if (length(is_char) > 0) {
      data[is_char] <- lapply(data[is_char], function(x) {as.numeric(as.factor(x))})
    }
    # Convert factors to their numeric codes
    is_fact <- which(list_classes == "factor")
    if (length(is_fact) > 0) {
      data[is_fact] <- lapply(data[is_fact], function(x) {as.numeric(x)})
    }
  } else {
    # Unsupported input class. (Bug fix: the message previously named
    # "lgb.prepare2" although this function is lgb.prepare.)
    stop("lgb.prepare: you provided ", paste(class(data), collapse = " & "), " but data should have class data.frame")
  }
  return(data)
}
#' Data preparator for LightGBM datasets (integer)
#'
#' Attempts to prepare a clean dataset to prepare to put in a lgb.Dataset. Factors and characters are converted to numeric (specifically: integer). Please use \code{lgb.prepare_rules2} if you want to apply this transformation to other datasets. This is useful if you have a specific need for integer dataset instead of numeric dataset. Note that there are programs which do not support integer-only input. Consider this as a half memory technique which is dangerous, especially for LightGBM.
#'
#' @param data A data.frame or data.table to prepare.
#'
#' @return The cleaned dataset. It must be converted to a matrix format (\code{as.matrix}) for input in lgb.Dataset.
#'
#' @examples
#' library(lightgbm)
#' data(iris)
#'
#' str(iris)
#' # 'data.frame': 150 obs. of 5 variables:
#' # $ Sepal.Length: num 5.1 4.9 4.7 4.6 5 5.4 4.6 5 4.4 4.9 ...
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
#' # $ Petal.Length: num 1.4 1.4 1.3 1.5 1.4 1.7 1.4 1.5 1.4 1.5 ... #' # $ Petal.Length: num 1.4 1.4 1.3 1.5 1.4 1.7 1.4 1.5 1.4 1.5 ...
#' # $ Petal.Width : num 0.2 0.2 0.2 0.2 0.2 0.4 0.3 0.2 0.2 0.1 ... #' # $ Petal.Width : num 0.2 0.2 0.2 0.2 0.2 0.4 0.3 0.2 0.2 0.1 ...
#' # $ Species : Factor w/ 3 levels "setosa","versicolor",..: 1 1 1 1 ... #' # $ Species : Factor w/ 3 levels "setosa","versicolor",..: 1 1 1 1 ...
#' #'
#' # Convert all factors/chars to integer #' # Convert all factors/chars to integer
#' str(lgb.prepare2(data = iris)) #' str(lgb.prepare2(data = iris))
#' # 'data.frame': 150 obs. of 5 variables: #' # 'data.frame': 150 obs. of 5 variables:
...@@ -26,7 +26,7 @@ ...@@ -26,7 +26,7 @@
#' # $ Petal.Length: num 1.4 1.4 1.3 1.5 1.4 1.7 1.4 1.5 1.4 1.5 ... #' # $ Petal.Length: num 1.4 1.4 1.3 1.5 1.4 1.7 1.4 1.5 1.4 1.5 ...
#' # $ Petal.Width : num 0.2 0.2 0.2 0.2 0.2 0.4 0.3 0.2 0.2 0.1 ... #' # $ Petal.Width : num 0.2 0.2 0.2 0.2 0.2 0.4 0.3 0.2 0.2 0.1 ...
#' # $ Species : int 1 1 1 1 1 1 1 1 1 1 ... #' # $ Species : int 1 1 1 1 1 1 1 1 1 1 ...
#' #'
#' # When lightgbm package is installed, and you do not want to load it #' # When lightgbm package is installed, and you do not want to load it
#' # You can still use the function! #' # You can still use the function!
#' lgb.unloader() #' lgb.unloader()
...@@ -37,57 +37,57 @@ ...@@ -37,57 +37,57 @@
#' # $ Petal.Length: num 1.4 1.4 1.3 1.5 1.4 1.7 1.4 1.5 1.4 1.5 ... #' # $ Petal.Length: num 1.4 1.4 1.3 1.5 1.4 1.7 1.4 1.5 1.4 1.5 ...
#' # $ Petal.Width : num 0.2 0.2 0.2 0.2 0.2 0.4 0.3 0.2 0.2 0.1 ... #' # $ Petal.Width : num 0.2 0.2 0.2 0.2 0.2 0.4 0.3 0.2 0.2 0.1 ...
#' # $ Species : int 1 1 1 1 1 1 1 1 1 1 ... #' # $ Species : int 1 1 1 1 1 1 1 1 1 1 ...
#' #'
#' @export #' @export
lgb.prepare2 <- function(data) {
  # Convert character and factor columns of a data.frame/data.table to
  # integer codes, so the result can be passed through as.matrix() into
  # lgb.Dataset. Numeric columns are left untouched.
  #
  # Args:
  #   data: a data.frame or data.table. Anything else raises an error.
  # Returns:
  #   The same object with character/factor columns replaced by integers.
  #   data.table input is modified by reference; data.frame input is copied.
  # data.table not behaving like data.frame: must assign by reference with :=
  if (inherits(data, "data.table")) {
    # Get data classes
    list_classes <- vapply(data, class, character(1))
    # Convert characters to factors first, then to their integer codes
    is_char <- which(list_classes == "character")
    if (length(is_char) > 0) {
      data[, (is_char) := lapply(.SD, function(x) {as.integer(as.factor(x))}), .SDcols = is_char]
    }
    # Convert factors to integer codes (integer is more memory-efficient).
    # is_char is included to mirror the original pass; those columns are
    # already integer, so re-applying as.integer() is a no-op.
    is_fact <- c(which(list_classes == "factor"), is_char)
    if (length(is_fact) > 0) {
      data[, (is_fact) := lapply(.SD, function(x) {as.integer(x)}), .SDcols = is_fact]
    }
  } else {
    # Default routine (data.frame)
    if (inherits(data, "data.frame")) {
      # Get data classes
      list_classes <- vapply(data, class, character(1))
      # Convert characters to factors, then to their integer codes
      is_char <- which(list_classes == "character")
      if (length(is_char) > 0) {
        data[is_char] <- lapply(data[is_char], function(x) {as.integer(as.factor(x))})
      }
      # Convert factors to integer codes (integer is more memory-efficient)
      is_fact <- which(list_classes == "factor")
      if (length(is_fact) > 0) {
        data[is_fact] <- lapply(data[is_fact], function(x) {as.integer(x)})
      }
    } else {
      # Unsupported input type: reject explicitly.
      # Fixed message to name this function (was "lgb.prepare:", a
      # copy-paste inconsistency with the sibling lgb.prepare).
      stop("lgb.prepare2: you provided ", paste(class(data), collapse = " & "), " but data should have class data.frame")
    }
  }
  return(data)
}
#' Data preparator for LightGBM datasets with rules (numeric)
#'
#' Attempts to prepare a clean dataset to prepare to put in a lgb.Dataset. Factors and characters are converted to numeric. In addition, keeps rules created so you can convert other datasets using this converter.
#'
#' @param data A data.frame or data.table to prepare.
#' @param rules A set of rules from the data preparator, if already used.
#'
#' @return A list with the cleaned dataset (\code{data}) and the rules (\code{rules}). The data must be converted to a matrix format (\code{as.matrix}) for input in lgb.Dataset.
#'
#' @examples
#' library(lightgbm)
#' data(iris)
#'
#' str(iris)
#' # 'data.frame': 150 obs. of 5 variables:
#' # $ Sepal.Length: num 5.1 4.9 4.7 4.6 5 5.4 4.6 5 4.4 4.9 ...
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
#' # $ Petal.Length: num 1.4 1.4 1.3 1.5 1.4 1.7 1.4 1.5 1.4 1.5 ... #' # $ Petal.Length: num 1.4 1.4 1.3 1.5 1.4 1.7 1.4 1.5 1.4 1.5 ...
#' # $ Petal.Width : num 0.2 0.2 0.2 0.2 0.2 0.4 0.3 0.2 0.2 0.1 ... #' # $ Petal.Width : num 0.2 0.2 0.2 0.2 0.2 0.4 0.3 0.2 0.2 0.1 ...
#' # $ Species : Factor w/ 3 levels "setosa","versicolor",..: 1 1 1 1 ... #' # $ Species : Factor w/ 3 levels "setosa","versicolor",..: 1 1 1 1 ...
#' #'
#' new_iris <- lgb.prepare_rules(data = iris) # Autoconverter #' new_iris <- lgb.prepare_rules(data = iris) # Autoconverter
#' str(new_iris$data) #' str(new_iris$data)
#' # 'data.frame': 150 obs. of 5 variables: #' # 'data.frame': 150 obs. of 5 variables:
...@@ -27,31 +27,31 @@ ...@@ -27,31 +27,31 @@
#' # $ Petal.Length: num 1.4 1.4 1.3 1.5 1.4 1.7 1.4 1.5 1.4 1.5 ... #' # $ Petal.Length: num 1.4 1.4 1.3 1.5 1.4 1.7 1.4 1.5 1.4 1.5 ...
#' # $ Petal.Width : num 0.2 0.2 0.2 0.2 0.2 0.4 0.3 0.2 0.2 0.1 ... #' # $ Petal.Width : num 0.2 0.2 0.2 0.2 0.2 0.4 0.3 0.2 0.2 0.1 ...
#' # $ Species : num 1 1 1 1 1 1 1 1 1 1 ... #' # $ Species : num 1 1 1 1 1 1 1 1 1 1 ...
#' #'
#' data(iris) # Erase iris dataset #' data(iris) # Erase iris dataset
#' iris$Species[1] <- "NEW FACTOR" # Introduce junk factor (NA) #' iris$Species[1] <- "NEW FACTOR" # Introduce junk factor (NA)
#' # Warning message: #' # Warning message:
#' # In `[<-.factor`(`*tmp*`, 1, value = c(NA, 1L, 1L, 1L, 1L, 1L, 1L, : #' # In `[<-.factor`(`*tmp*`, 1, value = c(NA, 1L, 1L, 1L, 1L, 1L, 1L, :
#' # invalid factor level, NA generated #' # invalid factor level, NA generated
#' #'
#' # Use conversion using known rules #' # Use conversion using known rules
#' # Unknown factors become 0, excellent for sparse datasets #' # Unknown factors become 0, excellent for sparse datasets
#' newer_iris <- lgb.prepare_rules(data = iris, rules = new_iris$rules) #' newer_iris <- lgb.prepare_rules(data = iris, rules = new_iris$rules)
#' #'
#' # Unknown factor is now zero, perfect for sparse datasets #' # Unknown factor is now zero, perfect for sparse datasets
#' newer_iris$data[1, ] # Species became 0 as it is an unknown factor #' newer_iris$data[1, ] # Species became 0 as it is an unknown factor
#' # Sepal.Length Sepal.Width Petal.Length Petal.Width Species #' # Sepal.Length Sepal.Width Petal.Length Petal.Width Species
#' # 1 5.1 3.5 1.4 0.2 0 #' # 1 5.1 3.5 1.4 0.2 0
#' #'
#' newer_iris$data[1, 5] <- 1 # Put back real initial value #' newer_iris$data[1, 5] <- 1 # Put back real initial value
#' #'
#' # Is the newly created dataset equal? YES! #' # Is the newly created dataset equal? YES!
#' all.equal(new_iris$data, newer_iris$data) #' all.equal(new_iris$data, newer_iris$data)
#' # [1] TRUE #' # [1] TRUE
#' #'
#' # Can we test our own rules? #' # Can we test our own rules?
#' data(iris) # Erase iris dataset #' data(iris) # Erase iris dataset
#' #'
#' # We remapped values differently #' # We remapped values differently
#' personal_rules <- list(Species = c("setosa" = 3, #' personal_rules <- list(Species = c("setosa" = 3,
#' "versicolor" = 2, #' "versicolor" = 2,
...@@ -64,43 +64,43 @@ ...@@ -64,43 +64,43 @@
#' # $ Petal.Length: num 1.4 1.4 1.3 1.5 1.4 1.7 1.4 1.5 1.4 1.5 ... #' # $ Petal.Length: num 1.4 1.4 1.3 1.5 1.4 1.7 1.4 1.5 1.4 1.5 ...
#' # $ Petal.Width : num 0.2 0.2 0.2 0.2 0.2 0.4 0.3 0.2 0.2 0.1 ... #' # $ Petal.Width : num 0.2 0.2 0.2 0.2 0.2 0.4 0.3 0.2 0.2 0.1 ...
#' # $ Species : num 3 3 3 3 3 3 3 3 3 3 ... #' # $ Species : num 3 3 3 3 3 3 3 3 3 3 ...
#' #'
#' @importFrom data.table set #' @importFrom data.table set
#' @export #' @export
lgb.prepare_rules <- function(data, rules = NULL) { lgb.prepare_rules <- function(data, rules = NULL) {
# data.table not behaving like data.frame # data.table not behaving like data.frame
if (inherits(data, "data.table")) { if (inherits(data, "data.table")) {
# Must use existing rules # Must use existing rules
if (!is.null(rules)) { if (!is.null(rules)) {
# Loop through rules # Loop through rules
for (i in names(rules)) { for (i in names(rules)) {
data.table::set(data, j = i, value = unname(rules[[i]][data[[i]]])) data.table::set(data, j = i, value = unname(rules[[i]][data[[i]]]))
data[[i]][is.na(data[[i]])] <- 0 # Overwrite NAs by 0s data[[i]][is.na(data[[i]])] <- 0 # Overwrite NAs by 0s
} }
} else { } else {
# Get data classes # Get data classes
list_classes <- vapply(data, class, character(1)) list_classes <- vapply(data, class, character(1))
# Map characters/factors # Map characters/factors
is_fix <- which(list_classes %in% c("character", "factor")) is_fix <- which(list_classes %in% c("character", "factor"))
rules <- list() rules <- list()
# Need to create rules? # Need to create rules?
if (length(is_fix) > 0) { if (length(is_fix) > 0) {
# Go through all characters/factors # Go through all characters/factors
for (i in is_fix) { for (i in is_fix) {
# Store column elsewhere # Store column elsewhere
mini_data <- data[[i]] mini_data <- data[[i]]
# Get unique values # Get unique values
if (is.factor(mini_data)) { if (is.factor(mini_data)) {
mini_unique <- levels(mini_data) # Factor mini_unique <- levels(mini_data) # Factor
...@@ -110,55 +110,55 @@ lgb.prepare_rules <- function(data, rules = NULL) { ...@@ -110,55 +110,55 @@ lgb.prepare_rules <- function(data, rules = NULL) {
mini_unique <- as.factor(unique(mini_data)) # Character mini_unique <- as.factor(unique(mini_data)) # Character
mini_numeric <- as.numeric(mini_unique) # No respect of ordinality mini_numeric <- as.numeric(mini_unique) # No respect of ordinality
} }
# Create rules # Create rules
indexed <- colnames(data)[i] # Index value indexed <- colnames(data)[i] # Index value
rules[[indexed]] <- mini_numeric # Numeric content rules[[indexed]] <- mini_numeric # Numeric content
names(rules[[indexed]]) <- mini_unique # Character equivalent names(rules[[indexed]]) <- mini_unique # Character equivalent
# Apply to real data column # Apply to real data column
data.table::set(data, j = i, value = unname(rules[[indexed]][mini_data])) data.table::set(data, j = i, value = unname(rules[[indexed]][mini_data]))
} }
} }
} }
} else { } else {
# Must use existing rules # Must use existing rules
if (!is.null(rules)) { if (!is.null(rules)) {
# Loop through rules # Loop through rules
for (i in names(rules)) { for (i in names(rules)) {
data[[i]] <- unname(rules[[i]][data[[i]]]) data[[i]] <- unname(rules[[i]][data[[i]]])
data[[i]][is.na(data[[i]])] <- 0 # Overwrite NAs by 0s data[[i]][is.na(data[[i]])] <- 0 # Overwrite NAs by 0s
} }
} else { } else {
# Default routine (data.frame) # Default routine (data.frame)
if (inherits(data, "data.frame")) { if (inherits(data, "data.frame")) {
# Get data classes # Get data classes
list_classes <- vapply(data, class, character(1)) list_classes <- vapply(data, class, character(1))
# Map characters/factors # Map characters/factors
is_fix <- which(list_classes %in% c("character", "factor")) is_fix <- which(list_classes %in% c("character", "factor"))
rules <- list() rules <- list()
# Need to create rules? # Need to create rules?
if (length(is_fix) > 0) { if (length(is_fix) > 0) {
# Go through all characters/factors # Go through all characters/factors
for (i in is_fix) { for (i in is_fix) {
# Store column elsewhere # Store column elsewhere
mini_data <- data[[i]] mini_data <- data[[i]]
# Get unique values # Get unique values
if (is.factor(mini_data)) { if (is.factor(mini_data)) {
mini_unique <- levels(mini_data) # Factor mini_unique <- levels(mini_data) # Factor
...@@ -168,30 +168,30 @@ lgb.prepare_rules <- function(data, rules = NULL) { ...@@ -168,30 +168,30 @@ lgb.prepare_rules <- function(data, rules = NULL) {
mini_unique <- as.factor(unique(mini_data)) # Character mini_unique <- as.factor(unique(mini_data)) # Character
mini_numeric <- as.numeric(mini_unique) # No respect of ordinality mini_numeric <- as.numeric(mini_unique) # No respect of ordinality
} }
# Create rules # Create rules
indexed <- colnames(data)[i] # Index value indexed <- colnames(data)[i] # Index value
rules[[indexed]] <- mini_numeric # Numeric content rules[[indexed]] <- mini_numeric # Numeric content
names(rules[[indexed]]) <- mini_unique # Character equivalent names(rules[[indexed]]) <- mini_unique # Character equivalent
# Apply to real data column # Apply to real data column
data[[i]] <- unname(rules[[indexed]][mini_data]) data[[i]] <- unname(rules[[indexed]][mini_data])
} }
} }
} else { } else {
# What do you think you are doing here? Throw error. # What do you think you are doing here? Throw error.
stop("lgb.prepare: you provided ", paste(class(data), collapse = " & "), " but data should have class data.frame") stop("lgb.prepare: you provided ", paste(class(data), collapse = " & "), " but data should have class data.frame")
} }
} }
} }
return(list(data = data, rules = rules)) return(list(data = data, rules = rules))
} }
#' Data preparator for LightGBM datasets with rules (integer) #' Data preparator for LightGBM datasets with rules (integer)
#' #'
#' Attempts to prepare a clean dataset to prepare to put in a lgb.Dataset. Factors and characters are converted to numeric (specifically: integer). In addition, keeps rules created so you can convert other datasets using this converter. This is useful if you have a specific need for integer dataset instead of numeric dataset. Note that there are programs which do not support integer-only input. Consider this as a half memory technique which is dangerous, especially for LightGBM. #' Attempts to prepare a clean dataset to prepare to put in a lgb.Dataset. Factors and characters are converted to numeric (specifically: integer). In addition, keeps rules created so you can convert other datasets using this converter. This is useful if you have a specific need for integer dataset instead of numeric dataset. Note that there are programs which do not support integer-only input. Consider this as a half memory technique which is dangerous, especially for LightGBM.
#' #'
#' @param data A data.frame or data.table to prepare. #' @param data A data.frame or data.table to prepare.
#' @param rules A set of rules from the data preparator, if already used. #' @param rules A set of rules from the data preparator, if already used.
#' #'
#' @return A list with the cleaned dataset (\code{data}) and the rules (\code{rules}). The data must be converted to a matrix format (\code{as.matrix}) for input in lgb.Dataset. #' @return A list with the cleaned dataset (\code{data}) and the rules (\code{rules}). The data must be converted to a matrix format (\code{as.matrix}) for input in lgb.Dataset.
#' #'
#' @examples #' @examples
#' library(lightgbm) #' library(lightgbm)
#' data(iris) #' data(iris)
#' #'
#' str(iris) #' str(iris)
#' # 'data.frame': 150 obs. of 5 variables: #' # 'data.frame': 150 obs. of 5 variables:
#' # $ Sepal.Length: num 5.1 4.9 4.7 4.6 5 5.4 4.6 5 4.4 4.9 ... #' # $ Sepal.Length: num 5.1 4.9 4.7 4.6 5 5.4 4.6 5 4.4 4.9 ...
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
#' # $ Petal.Length: num 1.4 1.4 1.3 1.5 1.4 1.7 1.4 1.5 1.4 1.5 ... #' # $ Petal.Length: num 1.4 1.4 1.3 1.5 1.4 1.7 1.4 1.5 1.4 1.5 ...
#' # $ Petal.Width : num 0.2 0.2 0.2 0.2 0.2 0.4 0.3 0.2 0.2 0.1 ... #' # $ Petal.Width : num 0.2 0.2 0.2 0.2 0.2 0.4 0.3 0.2 0.2 0.1 ...
#' # $ Species : Factor w/ 3 levels "setosa","versicolor",..: 1 1 1 1 ... #' # $ Species : Factor w/ 3 levels "setosa","versicolor",..: 1 1 1 1 ...
#' #'
#' new_iris <- lgb.prepare_rules2(data = iris) # Autoconverter #' new_iris <- lgb.prepare_rules2(data = iris) # Autoconverter
#' str(new_iris$data) #' str(new_iris$data)
#' # 'data.frame': 150 obs. of 5 variables: #' # 'data.frame': 150 obs. of 5 variables:
...@@ -27,31 +27,31 @@ ...@@ -27,31 +27,31 @@
#' # $ Petal.Length: num 1.4 1.4 1.3 1.5 1.4 1.7 1.4 1.5 1.4 1.5 ... #' # $ Petal.Length: num 1.4 1.4 1.3 1.5 1.4 1.7 1.4 1.5 1.4 1.5 ...
#' # $ Petal.Width : num 0.2 0.2 0.2 0.2 0.2 0.4 0.3 0.2 0.2 0.1 ... #' # $ Petal.Width : num 0.2 0.2 0.2 0.2 0.2 0.4 0.3 0.2 0.2 0.1 ...
#' # $ Species : int 1 1 1 1 1 1 1 1 1 1 ... #' # $ Species : int 1 1 1 1 1 1 1 1 1 1 ...
#' #'
#' data(iris) # Erase iris dataset #' data(iris) # Erase iris dataset
#' iris$Species[1] <- "NEW FACTOR" # Introduce junk factor (NA) #' iris$Species[1] <- "NEW FACTOR" # Introduce junk factor (NA)
#' # Warning message: #' # Warning message:
#' # In `[<-.factor`(`*tmp*`, 1, value = c(NA, 1L, 1L, 1L, 1L, 1L, 1L, : #' # In `[<-.factor`(`*tmp*`, 1, value = c(NA, 1L, 1L, 1L, 1L, 1L, 1L, :
#' # invalid factor level, NA generated #' # invalid factor level, NA generated
#' #'
#' # Use conversion using known rules #' # Use conversion using known rules
#' # Unknown factors become 0, excellent for sparse datasets #' # Unknown factors become 0, excellent for sparse datasets
#' newer_iris <- lgb.prepare_rules2(data = iris, rules = new_iris$rules) #' newer_iris <- lgb.prepare_rules2(data = iris, rules = new_iris$rules)
#' #'
#' # Unknown factor is now zero, perfect for sparse datasets #' # Unknown factor is now zero, perfect for sparse datasets
#' newer_iris$data[1, ] # Species became 0 as it is an unknown factor #' newer_iris$data[1, ] # Species became 0 as it is an unknown factor
#' # Sepal.Length Sepal.Width Petal.Length Petal.Width Species #' # Sepal.Length Sepal.Width Petal.Length Petal.Width Species
#' # 1 5.1 3.5 1.4 0.2 0 #' # 1 5.1 3.5 1.4 0.2 0
#' #'
#' newer_iris$data[1, 5] <- 1 # Put back real initial value #' newer_iris$data[1, 5] <- 1 # Put back real initial value
#' #'
#' # Is the newly created dataset equal? YES! #' # Is the newly created dataset equal? YES!
#' all.equal(new_iris$data, newer_iris$data) #' all.equal(new_iris$data, newer_iris$data)
#' # [1] TRUE #' # [1] TRUE
#' #'
#' # Can we test our own rules? #' # Can we test our own rules?
#' data(iris) # Erase iris dataset #' data(iris) # Erase iris dataset
#' #'
#' # We remapped values differently #' # We remapped values differently
#' personal_rules <- list(Species = c("setosa" = 3L, #' personal_rules <- list(Species = c("setosa" = 3L,
#' "versicolor" = 2L, #' "versicolor" = 2L,
...@@ -64,43 +64,43 @@ ...@@ -64,43 +64,43 @@
#' # $ Petal.Length: num 1.4 1.4 1.3 1.5 1.4 1.7 1.4 1.5 1.4 1.5 ... #' # $ Petal.Length: num 1.4 1.4 1.3 1.5 1.4 1.7 1.4 1.5 1.4 1.5 ...
#' # $ Petal.Width : num 0.2 0.2 0.2 0.2 0.2 0.4 0.3 0.2 0.2 0.1 ... #' # $ Petal.Width : num 0.2 0.2 0.2 0.2 0.2 0.4 0.3 0.2 0.2 0.1 ...
#' # $ Species : int 3 3 3 3 3 3 3 3 3 3 ... #' # $ Species : int 3 3 3 3 3 3 3 3 3 3 ...
#' #'
#' @importFrom data.table set #' @importFrom data.table set
#' @export #' @export
lgb.prepare_rules2 <- function(data, rules = NULL) { lgb.prepare_rules2 <- function(data, rules = NULL) {
# data.table not behaving like data.frame # data.table not behaving like data.frame
if (inherits(data, "data.table")) { if (inherits(data, "data.table")) {
# Must use existing rules # Must use existing rules
if (!is.null(rules)) { if (!is.null(rules)) {
# Loop through rules # Loop through rules
for (i in names(rules)) { for (i in names(rules)) {
data.table::set(data, j = i, value = unname(rules[[i]][data[[i]]])) data.table::set(data, j = i, value = unname(rules[[i]][data[[i]]]))
data[[i]][is.na(data[[i]])] <- 0L # Overwrite NAs by 0s as integer data[[i]][is.na(data[[i]])] <- 0L # Overwrite NAs by 0s as integer
} }
} else { } else {
# Get data classes # Get data classes
list_classes <- vapply(data, class, character(1)) list_classes <- vapply(data, class, character(1))
# Map characters/factors # Map characters/factors
is_fix <- which(list_classes %in% c("character", "factor")) is_fix <- which(list_classes %in% c("character", "factor"))
rules <- list() rules <- list()
# Need to create rules? # Need to create rules?
if (length(is_fix) > 0) { if (length(is_fix) > 0) {
# Go through all characters/factors # Go through all characters/factors
for (i in is_fix) { for (i in is_fix) {
# Store column elsewhere # Store column elsewhere
mini_data <- data[[i]] mini_data <- data[[i]]
# Get unique values # Get unique values
if (is.factor(mini_data)) { if (is.factor(mini_data)) {
mini_unique <- levels(mini_data) # Factor mini_unique <- levels(mini_data) # Factor
...@@ -109,55 +109,55 @@ lgb.prepare_rules2 <- function(data, rules = NULL) { ...@@ -109,55 +109,55 @@ lgb.prepare_rules2 <- function(data, rules = NULL) {
mini_unique <- as.factor(unique(mini_data)) # Character mini_unique <- as.factor(unique(mini_data)) # Character
mini_numeric <- as.integer(mini_unique) # No respect of ordinality mini_numeric <- as.integer(mini_unique) # No respect of ordinality
} }
# Create rules # Create rules
indexed <- colnames(data)[i] # Index value indexed <- colnames(data)[i] # Index value
rules[[indexed]] <- mini_numeric # Numeric content rules[[indexed]] <- mini_numeric # Numeric content
names(rules[[indexed]]) <- mini_unique # Character equivalent names(rules[[indexed]]) <- mini_unique # Character equivalent
# Apply to real data column # Apply to real data column
data.table::set(data, j = i, value = unname(rules[[indexed]][mini_data])) data.table::set(data, j = i, value = unname(rules[[indexed]][mini_data]))
} }
} }
} }
} else { } else {
# Must use existing rules # Must use existing rules
if (!is.null(rules)) { if (!is.null(rules)) {
# Loop through rules # Loop through rules
for (i in names(rules)) { for (i in names(rules)) {
data[[i]] <- unname(rules[[i]][data[[i]]]) data[[i]] <- unname(rules[[i]][data[[i]]])
data[[i]][is.na(data[[i]])] <- 0L # Overwrite NAs by 0s as integer data[[i]][is.na(data[[i]])] <- 0L # Overwrite NAs by 0s as integer
} }
} else { } else {
# Default routine (data.frame) # Default routine (data.frame)
if (inherits(data, "data.frame")) { if (inherits(data, "data.frame")) {
# Get data classes # Get data classes
list_classes <- vapply(data, class, character(1)) list_classes <- vapply(data, class, character(1))
# Map characters/factors # Map characters/factors
is_fix <- which(list_classes %in% c("character", "factor")) is_fix <- which(list_classes %in% c("character", "factor"))
rules <- list() rules <- list()
# Need to create rules? # Need to create rules?
if (length(is_fix) > 0) { if (length(is_fix) > 0) {
# Go through all characters/factors # Go through all characters/factors
for (i in is_fix) { for (i in is_fix) {
# Store column elsewhere # Store column elsewhere
mini_data <- data[[i]] mini_data <- data[[i]]
# Get unique values # Get unique values
if (is.factor(mini_data)) { if (is.factor(mini_data)) {
mini_unique <- levels(mini_data) # Factor mini_unique <- levels(mini_data) # Factor
...@@ -166,30 +166,30 @@ lgb.prepare_rules2 <- function(data, rules = NULL) { ...@@ -166,30 +166,30 @@ lgb.prepare_rules2 <- function(data, rules = NULL) {
mini_unique <- as.factor(unique(mini_data)) # Character mini_unique <- as.factor(unique(mini_data)) # Character
mini_numeric <- as.integer(mini_unique) # No respect of ordinality mini_numeric <- as.integer(mini_unique) # No respect of ordinality
} }
# Create rules # Create rules
indexed <- colnames(data)[i] # Index value indexed <- colnames(data)[i] # Index value
rules[[indexed]] <- mini_numeric # Numeric content rules[[indexed]] <- mini_numeric # Numeric content
names(rules[[indexed]]) <- mini_unique # Character equivalent names(rules[[indexed]]) <- mini_unique # Character equivalent
# Apply to real data column # Apply to real data column
data[[i]] <- unname(rules[[indexed]][mini_data]) data[[i]] <- unname(rules[[indexed]][mini_data])
} }
} }
} else { } else {
# What do you think you are doing here? Throw error. # What do you think you are doing here? Throw error.
stop("lgb.prepare: you provided ", paste(class(data), collapse = " & "), " but data should have class data.frame") stop("lgb.prepare: you provided ", paste(class(data), collapse = " & "), " but data should have class data.frame")
} }
} }
} }
return(list(data = data, rules = rules)) return(list(data = data, rules = rules))
} }
...@@ -3,11 +3,11 @@ ...@@ -3,11 +3,11 @@
#' @description Logic to train with LightGBM #' @description Logic to train with LightGBM
#' @inheritParams lgb_shared_params #' @inheritParams lgb_shared_params
#' @param valids a list of \code{lgb.Dataset} objects, used for validation #' @param valids a list of \code{lgb.Dataset} objects, used for validation
#' @param obj objective function, can be character or custom objective function. Examples include #' @param obj objective function, can be character or custom objective function. Examples include
#' \code{regression}, \code{regression_l1}, \code{huber}, #' \code{regression}, \code{regression_l1}, \code{huber},
#' \code{binary}, \code{lambdarank}, \code{multiclass}, \code{multiclass} #' \code{binary}, \code{lambdarank}, \code{multiclass}, \code{multiclass}
#' @param eval evaluation function, can be (a list of) character or custom eval function #' @param eval evaluation function, can be (a list of) character or custom eval function
#' @param record Boolean, TRUE will record iteration message to \code{booster$record_evals} #' @param record Boolean, TRUE will record iteration message to \code{booster$record_evals}
#' @param colnames feature names, if not null, will use this to overwrite the names in dataset #' @param colnames feature names, if not null, will use this to overwrite the names in dataset
#' @param categorical_feature list of str or int #' @param categorical_feature list of str or int
#' type int represents index, #' type int represents index,
...@@ -17,14 +17,14 @@ ...@@ -17,14 +17,14 @@
#' \itemize{ #' \itemize{
#' \item{boosting}{Boosting type. \code{"gbdt"} or \code{"dart"}} #' \item{boosting}{Boosting type. \code{"gbdt"} or \code{"dart"}}
#' \item{num_leaves}{number of leaves in one tree. defaults to 127} #' \item{num_leaves}{number of leaves in one tree. defaults to 127}
#' \item{max_depth}{Limit the max depth for tree model. This is used to deal with #' \item{max_depth}{Limit the max depth for tree model. This is used to deal with
#' overfit when #data is small. Tree still grow by leaf-wise.} #' overfit when #data is small. Tree still grow by leaf-wise.}
#' \item{num_threads}{Number of threads for LightGBM. For the best speed, set this to #' \item{num_threads}{Number of threads for LightGBM. For the best speed, set this to
#' the number of real CPU cores, not the number of threads (most #' the number of real CPU cores, not the number of threads (most
#' CPU using hyper-threading to generate 2 threads per CPU core).} #' CPU using hyper-threading to generate 2 threads per CPU core).}
#' } #' }
#' @return a trained booster model \code{lgb.Booster}. #' @return a trained booster model \code{lgb.Booster}.
#' #'
#' @examples #' @examples
#' library(lightgbm) #' library(lightgbm)
#' data(agaricus.train, package = "lightgbm") #' data(agaricus.train, package = "lightgbm")
...@@ -42,7 +42,7 @@ ...@@ -42,7 +42,7 @@
#' min_data = 1, #' min_data = 1,
#' learning_rate = 1, #' learning_rate = 1,
#' early_stopping_rounds = 10) #' early_stopping_rounds = 10)
#' #'
#' @export #' @export
lgb.train <- function(params = list(), lgb.train <- function(params = list(),
data, data,
...@@ -60,7 +60,7 @@ lgb.train <- function(params = list(), ...@@ -60,7 +60,7 @@ lgb.train <- function(params = list(),
callbacks = list(), callbacks = list(),
reset_data = FALSE, reset_data = FALSE,
...) { ...) {
# Setup temporary variables # Setup temporary variables
additional_params <- list(...) additional_params <- list(...)
params <- append(params, additional_params) params <- append(params, additional_params)
...@@ -69,7 +69,7 @@ lgb.train <- function(params = list(), ...@@ -69,7 +69,7 @@ lgb.train <- function(params = list(),
params <- lgb.check.eval(params, eval) params <- lgb.check.eval(params, eval)
fobj <- NULL fobj <- NULL
feval <- NULL feval <- NULL
if (nrounds <= 0) { if (nrounds <= 0) {
stop("nrounds should be greater than zero") stop("nrounds should be greater than zero")
} }
...@@ -79,25 +79,25 @@ lgb.train <- function(params = list(), ...@@ -79,25 +79,25 @@ lgb.train <- function(params = list(),
fobj <- params$objective fobj <- params$objective
params$objective <- "NONE" params$objective <- "NONE"
} }
# Check for loss (function or not) # Check for loss (function or not)
if (is.function(eval)) { if (is.function(eval)) {
feval <- eval feval <- eval
} }
# Check for parameters # Check for parameters
lgb.check.params(params) lgb.check.params(params)
# Init predictor to empty # Init predictor to empty
predictor <- NULL predictor <- NULL
# Check for boosting from a trained model # Check for boosting from a trained model
if (is.character(init_model)) { if (is.character(init_model)) {
predictor <- Predictor$new(init_model) predictor <- Predictor$new(init_model)
} else if (lgb.is.Booster(init_model)) { } else if (lgb.is.Booster(init_model)) {
predictor <- init_model$to_predictor() predictor <- init_model$to_predictor()
} }
# Set the iteration to start from / end to (and check for boosting from a trained model, again) # Set the iteration to start from / end to (and check for boosting from a trained model, again)
begin_iteration <- 1 begin_iteration <- 1
if (!is.null(predictor)) { if (!is.null(predictor)) {
...@@ -110,89 +110,89 @@ lgb.train <- function(params = list(), ...@@ -110,89 +110,89 @@ lgb.train <- function(params = list(),
} else { } else {
end_iteration <- begin_iteration + nrounds - 1 end_iteration <- begin_iteration + nrounds - 1
} }
# Check for training dataset type correctness # Check for training dataset type correctness
if (!lgb.is.Dataset(data)) { if (!lgb.is.Dataset(data)) {
stop("lgb.train: data only accepts lgb.Dataset object") stop("lgb.train: data only accepts lgb.Dataset object")
} }
# Check for validation dataset type correctness # Check for validation dataset type correctness
if (length(valids) > 0) { if (length(valids) > 0) {
# One or more validation dataset # One or more validation dataset
# Check for list as input and type correctness by object # Check for list as input and type correctness by object
if (!is.list(valids) || !all(vapply(valids, lgb.is.Dataset, logical(1)))) { if (!is.list(valids) || !all(vapply(valids, lgb.is.Dataset, logical(1)))) {
stop("lgb.train: valids must be a list of lgb.Dataset elements") stop("lgb.train: valids must be a list of lgb.Dataset elements")
} }
# Attempt to get names # Attempt to get names
evnames <- names(valids) evnames <- names(valids)
# Check for names existance # Check for names existance
if (is.null(evnames) || !all(nzchar(evnames))) { if (is.null(evnames) || !all(nzchar(evnames))) {
stop("lgb.train: each element of the valids must have a name tag") stop("lgb.train: each element of the valids must have a name tag")
} }
} }
# Update parameters with parsed parameters # Update parameters with parsed parameters
data$update_params(params) data$update_params(params)
# Create the predictor set # Create the predictor set
data$.__enclos_env__$private$set_predictor(predictor) data$.__enclos_env__$private$set_predictor(predictor)
# Write column names # Write column names
if (!is.null(colnames)) { if (!is.null(colnames)) {
data$set_colnames(colnames) data$set_colnames(colnames)
} }
# Write categorical features # Write categorical features
if (!is.null(categorical_feature)) { if (!is.null(categorical_feature)) {
data$set_categorical_feature(categorical_feature) data$set_categorical_feature(categorical_feature)
} }
# Construct datasets, if needed # Construct datasets, if needed
data$construct() data$construct()
vaild_contain_train <- FALSE vaild_contain_train <- FALSE
train_data_name <- "train" train_data_name <- "train"
reduced_valid_sets <- list() reduced_valid_sets <- list()
# Parse validation datasets # Parse validation datasets
if (length(valids) > 0) { if (length(valids) > 0) {
# Loop through all validation datasets using name # Loop through all validation datasets using name
for (key in names(valids)) { for (key in names(valids)) {
# Use names to get validation datasets # Use names to get validation datasets
valid_data <- valids[[key]] valid_data <- valids[[key]]
# Check for duplicate train/validation dataset # Check for duplicate train/validation dataset
if (identical(data, valid_data)) { if (identical(data, valid_data)) {
vaild_contain_train <- TRUE vaild_contain_train <- TRUE
train_data_name <- key train_data_name <- key
next next
} }
# Update parameters, data # Update parameters, data
valid_data$update_params(params) valid_data$update_params(params)
valid_data$set_reference(data) valid_data$set_reference(data)
reduced_valid_sets[[key]] <- valid_data reduced_valid_sets[[key]] <- valid_data
} }
} }
# Add printing log callback # Add printing log callback
if (verbose > 0 && eval_freq > 0) { if (verbose > 0 && eval_freq > 0) {
callbacks <- add.cb(callbacks, cb.print.evaluation(eval_freq)) callbacks <- add.cb(callbacks, cb.print.evaluation(eval_freq))
} }
# Add evaluation log callback # Add evaluation log callback
if (record && length(valids) > 0) { if (record && length(valids) > 0) {
callbacks <- add.cb(callbacks, cb.record.evaluation()) callbacks <- add.cb(callbacks, cb.record.evaluation())
} }
# Check for early stopping passed as parameter when adding early stopping callback # Check for early stopping passed as parameter when adding early stopping callback
early_stop <- c("early_stopping_round", "early_stopping_rounds", "early_stopping") early_stop <- c("early_stopping_round", "early_stopping_rounds", "early_stopping")
if (any(names(params) %in% early_stop)) { if (any(names(params) %in% early_stop)) {
...@@ -206,83 +206,83 @@ lgb.train <- function(params = list(), ...@@ -206,83 +206,83 @@ lgb.train <- function(params = list(),
} }
} }
} }
# "Categorize" callbacks # "Categorize" callbacks
cb <- categorize.callbacks(callbacks) cb <- categorize.callbacks(callbacks)
# Construct booster with datasets # Construct booster with datasets
booster <- Booster$new(params = params, train_set = data) booster <- Booster$new(params = params, train_set = data)
if (vaild_contain_train) { booster$set_train_data_name(train_data_name) } if (vaild_contain_train) { booster$set_train_data_name(train_data_name) }
for (key in names(reduced_valid_sets)) { for (key in names(reduced_valid_sets)) {
booster$add_valid(reduced_valid_sets[[key]], key) booster$add_valid(reduced_valid_sets[[key]], key)
} }
# Callback env # Callback env
env <- CB_ENV$new() env <- CB_ENV$new()
env$model <- booster env$model <- booster
env$begin_iteration <- begin_iteration env$begin_iteration <- begin_iteration
env$end_iteration <- end_iteration env$end_iteration <- end_iteration
# Start training model using number of iterations to start and end with # Start training model using number of iterations to start and end with
for (i in seq.int(from = begin_iteration, to = end_iteration)) { for (i in seq.int(from = begin_iteration, to = end_iteration)) {
# Overwrite iteration in environment # Overwrite iteration in environment
env$iteration <- i env$iteration <- i
env$eval_list <- list() env$eval_list <- list()
# Loop through "pre_iter" element # Loop through "pre_iter" element
for (f in cb$pre_iter) { for (f in cb$pre_iter) {
f(env) f(env)
} }
# Update one boosting iteration # Update one boosting iteration
booster$update(fobj = fobj) booster$update(fobj = fobj)
# Prepare collection of evaluation results # Prepare collection of evaluation results
eval_list <- list() eval_list <- list()
# Collection: Has validation dataset? # Collection: Has validation dataset?
if (length(valids) > 0) { if (length(valids) > 0) {
# Validation has training dataset? # Validation has training dataset?
if (vaild_contain_train) { if (vaild_contain_train) {
eval_list <- append(eval_list, booster$eval_train(feval = feval)) eval_list <- append(eval_list, booster$eval_train(feval = feval))
} }
# Has no validation dataset # Has no validation dataset
eval_list <- append(eval_list, booster$eval_valid(feval = feval)) eval_list <- append(eval_list, booster$eval_valid(feval = feval))
} }
# Write evaluation result in environment # Write evaluation result in environment
env$eval_list <- eval_list env$eval_list <- eval_list
# Loop through env # Loop through env
for (f in cb$post_iter) { for (f in cb$post_iter) {
f(env) f(env)
} }
# Check for early stopping and break if needed # Check for early stopping and break if needed
if (env$met_early_stop) break if (env$met_early_stop) break
} }
# Check for booster model conversion to predictor model # Check for booster model conversion to predictor model
if (reset_data) { if (reset_data) {
# Store temporarily model data elsewhere # Store temporarily model data elsewhere
booster_old <- list(best_iter = booster$best_iter, booster_old <- list(best_iter = booster$best_iter,
best_score = booster$best_score, best_score = booster$best_score,
record_evals = booster$record_evals) record_evals = booster$record_evals)
# Reload model # Reload model
booster <- lgb.load(model_str = booster$save_model_to_string()) booster <- lgb.load(model_str = booster$save_model_to_string())
booster$best_iter <- booster_old$best_iter booster$best_iter <- booster_old$best_iter
booster$best_score <- booster_old$best_score booster$best_score <- booster_old$best_score
booster$record_evals <- booster_old$record_evals booster$record_evals <- booster_old$record_evals
} }
# Return booster # Return booster
return(booster) return(booster)
} }
#' LightGBM unloading error fix #' LightGBM unloading error fix
#' #'
#' Attempts to unload LightGBM packages so you can remove objects cleanly without having to restart R. This is useful for instance if an object becomes stuck for no apparent reason and you do not want to restart R to fix the lost object. #' Attempts to unload LightGBM packages so you can remove objects cleanly without having to restart R. This is useful for instance if an object becomes stuck for no apparent reason and you do not want to restart R to fix the lost object.
#' #'
#' @param restore Whether to reload \code{LightGBM} immediately after detaching from R. Defaults to \code{TRUE} which means automatically reload \code{LightGBM} once unloading is performed. #' @param restore Whether to reload \code{LightGBM} immediately after detaching from R. Defaults to \code{TRUE} which means automatically reload \code{LightGBM} once unloading is performed.
#' @param wipe Whether to wipe all \code{lgb.Dataset} and \code{lgb.Booster} from the global environment. Defaults to \code{FALSE} which means to not remove them. #' @param wipe Whether to wipe all \code{lgb.Dataset} and \code{lgb.Booster} from the global environment. Defaults to \code{FALSE} which means to not remove them.
#' @param envir The environment to perform wiping on if \code{wipe == TRUE}. Defaults to \code{.GlobalEnv} which is the global environment. #' @param envir The environment to perform wiping on if \code{wipe == TRUE}. Defaults to \code{.GlobalEnv} which is the global environment.
#' #'
#' @return NULL invisibly. #' @return NULL invisibly.
#' #'
#' @examples #' @examples
#' library(lightgbm) #' library(lightgbm)
#' data(agaricus.train, package = "lightgbm") #' data(agaricus.train, package = "lightgbm")
...@@ -28,16 +28,16 @@ ...@@ -28,16 +28,16 @@
#' lgb.unloader(restore = FALSE, wipe = FALSE, envir = .GlobalEnv) #' lgb.unloader(restore = FALSE, wipe = FALSE, envir = .GlobalEnv)
#' rm(model, dtrain, dtest) # Not needed if wipe = TRUE #' rm(model, dtrain, dtest) # Not needed if wipe = TRUE
#' gc() # Not needed if wipe = TRUE #' gc() # Not needed if wipe = TRUE
#' #'
#' library(lightgbm) #' library(lightgbm)
#' # Do whatever you want again with LightGBM without object clashing #' # Do whatever you want again with LightGBM without object clashing
#' #'
#' @export #' @export
lgb.unloader <- function(restore = TRUE, wipe = FALSE, envir = .GlobalEnv) { lgb.unloader <- function(restore = TRUE, wipe = FALSE, envir = .GlobalEnv) {
# Unload package # Unload package
try(detach("package:lightgbm", unload = TRUE), silent = TRUE) try(detach("package:lightgbm", unload = TRUE), silent = TRUE)
# Should we wipe variables? (lgb.Booster, lgb.Dataset) # Should we wipe variables? (lgb.Booster, lgb.Dataset)
if (wipe) { if (wipe) {
boosters <- Filter(function(x) inherits(get(x, envir = envir), "lgb.Booster"), ls(envir = envir)) boosters <- Filter(function(x) inherits(get(x, envir = envir), "lgb.Booster"), ls(envir = envir))
...@@ -45,12 +45,12 @@ lgb.unloader <- function(restore = TRUE, wipe = FALSE, envir = .GlobalEnv) { ...@@ -45,12 +45,12 @@ lgb.unloader <- function(restore = TRUE, wipe = FALSE, envir = .GlobalEnv) {
rm(list = c(boosters, datasets), envir = envir) rm(list = c(boosters, datasets), envir = envir)
gc(verbose = FALSE) gc(verbose = FALSE)
} }
# Load package back? # Load package back?
if (restore) { if (restore) {
library(lightgbm) library(lightgbm)
} }
invisible() invisible()
} }
...@@ -29,21 +29,21 @@ NULL ...@@ -29,21 +29,21 @@ NULL
#' @param ... Additional arguments passed to \code{\link{lgb.train}}. For example #' @param ... Additional arguments passed to \code{\link{lgb.train}}. For example
#' \itemize{ #' \itemize{
#' \item{valids}{a list of \code{lgb.Dataset} objects, used for validation} #' \item{valids}{a list of \code{lgb.Dataset} objects, used for validation}
#' \item{obj}{objective function, can be character or custom objective function. Examples include #' \item{obj}{objective function, can be character or custom objective function. Examples include
#' \code{regression}, \code{regression_l1}, \code{huber}, #' \code{regression}, \code{regression_l1}, \code{huber},
#' \code{binary}, \code{lambdarank}, \code{multiclass}, \code{multiclass}} #' \code{binary}, \code{lambdarank}, \code{multiclass}, \code{multiclass}}
#' \item{eval}{evaluation function, can be (a list of) character or custom eval function} #' \item{eval}{evaluation function, can be (a list of) character or custom eval function}
#' \item{record}{Boolean, TRUE will record iteration message to \code{booster$record_evals}} #' \item{record}{Boolean, TRUE will record iteration message to \code{booster$record_evals}}
#' \item{colnames}{feature names, if not null, will use this to overwrite the names in dataset} #' \item{colnames}{feature names, if not null, will use this to overwrite the names in dataset}
#' \item{categorical_feature}{list of str or int. type int represents index, type str represents feature names} #' \item{categorical_feature}{list of str or int. type int represents index, type str represents feature names}
#' \item{reset_data}{Boolean, setting it to TRUE (not the default value) will transform the booster model #' \item{reset_data}{Boolean, setting it to TRUE (not the default value) will transform the booster model
#' into a predictor model which frees up memory and the original datasets} #' into a predictor model which frees up memory and the original datasets}
#' \item{boosting}{Boosting type. \code{"gbdt"} or \code{"dart"}} #' \item{boosting}{Boosting type. \code{"gbdt"} or \code{"dart"}}
#' \item{num_leaves}{number of leaves in one tree. defaults to 127} #' \item{num_leaves}{number of leaves in one tree. defaults to 127}
#' \item{max_depth}{Limit the max depth for tree model. This is used to deal with #' \item{max_depth}{Limit the max depth for tree model. This is used to deal with
#' overfit when #data is small. Tree still grow by leaf-wise.} #' overfit when #data is small. Tree still grow by leaf-wise.}
#' \item{num_threads}{Number of threads for LightGBM. For the best speed, set this to #' \item{num_threads}{Number of threads for LightGBM. For the best speed, set this to
#' the number of real CPU cores, not the number of threads (most #' the number of real CPU cores, not the number of threads (most
#' CPU using hyper-threading to generate 2 threads per CPU core).} #' CPU using hyper-threading to generate 2 threads per CPU core).}
#' } #' }
#' @export #' @export
...@@ -59,7 +59,7 @@ lightgbm <- function(data, ...@@ -59,7 +59,7 @@ lightgbm <- function(data,
init_model = NULL, init_model = NULL,
callbacks = list(), callbacks = list(),
...) { ...) {
# Set data to a temporary variable # Set data to a temporary variable
dtrain <- data dtrain <- data
if (nrounds <= 0) { if (nrounds <= 0) {
...@@ -75,15 +75,15 @@ lightgbm <- function(data, ...@@ -75,15 +75,15 @@ lightgbm <- function(data,
if (verbose > 0) { if (verbose > 0) {
valids$train = dtrain valids$train = dtrain
} }
# Train a model using the regular way # Train a model using the regular way
bst <- lgb.train(params, dtrain, nrounds, valids, verbose = verbose, eval_freq = eval_freq, bst <- lgb.train(params, dtrain, nrounds, valids, verbose = verbose, eval_freq = eval_freq,
early_stopping_rounds = early_stopping_rounds, early_stopping_rounds = early_stopping_rounds,
init_model = init_model, callbacks = callbacks, ...) init_model = init_model, callbacks = callbacks, ...)
# Store model under a specific name # Store model under a specific name
bst$save_model(save_name) bst$save_model(save_name)
# Return booster # Return booster
return(bst) return(bst)
} }
...@@ -152,7 +152,7 @@ NULL ...@@ -152,7 +152,7 @@ NULL
#' #'
#' @references #' @references
#' http://archive.ics.uci.edu/ml/datasets/Bank+Marketing #' http://archive.ics.uci.edu/ml/datasets/Bank+Marketing
#' #'
#' S. Moro, P. Cortez and P. Rita. (2014) #' S. Moro, P. Cortez and P. Rita. (2014)
#' A Data-Driven Approach to Predict the Success of Bank Telemarketing. Decision Support Systems #' A Data-Driven Approach to Predict the Success of Bank Telemarketing. Decision Support Systems
#' #'
......
#' readRDS for lgb.Booster models #' readRDS for lgb.Booster models
#' #'
#' Attempts to load a model using RDS. #' Attempts to load a model using RDS.
#' #'
#' @param file a connection or the name of the file where the R object is saved to or read from. #' @param file a connection or the name of the file where the R object is saved to or read from.
#' @param refhook a hook function for handling reference objects. #' @param refhook a hook function for handling reference objects.
#' #'
#' @return lgb.Booster. #' @return lgb.Booster.
#' #'
#' @examples #' @examples
#' library(lightgbm) #' library(lightgbm)
#' data(agaricus.train, package = "lightgbm") #' data(agaricus.train, package = "lightgbm")
...@@ -26,31 +26,31 @@ ...@@ -26,31 +26,31 @@
#' early_stopping_rounds = 10) #' early_stopping_rounds = 10)
#' saveRDS.lgb.Booster(model, "model.rds") #' saveRDS.lgb.Booster(model, "model.rds")
#' new_model <- readRDS.lgb.Booster("model.rds") #' new_model <- readRDS.lgb.Booster("model.rds")
#' #'
#' @export #' @export
readRDS.lgb.Booster <- function(file = "", refhook = NULL) { readRDS.lgb.Booster <- function(file = "", refhook = NULL) {
# Read RDS file # Read RDS file
object <- readRDS(file = file, refhook = refhook) object <- readRDS(file = file, refhook = refhook)
# Check if object has the model stored # Check if object has the model stored
if (!is.na(object$raw)) { if (!is.na(object$raw)) {
# Create temporary model for the model loading # Create temporary model for the model loading
object2 <- lgb.load(model_str = object$raw) object2 <- lgb.load(model_str = object$raw)
# Restore best iteration and recorded evaluations # Restore best iteration and recorded evaluations
object2$best_iter <- object$best_iter object2$best_iter <- object$best_iter
object2$record_evals <- object$record_evals object2$record_evals <- object$record_evals
# Return newly loaded object # Return newly loaded object
return(object2) return(object2)
} else { } else {
# Return RDS loaded object # Return RDS loaded object
return(object) return(object)
} }
} }
#' saveRDS for lgb.Booster models #' saveRDS for lgb.Booster models
#' #'
#' Attempts to save a model using RDS. Has an additional parameter (\code{raw}) which decides whether to save the raw model or not. #' Attempts to save a model using RDS. Has an additional parameter (\code{raw}) which decides whether to save the raw model or not.
#' #'
#' @param object R object to serialize. #' @param object R object to serialize.
#' @param file a connection or the name of the file where the R object is saved to or read from. #' @param file a connection or the name of the file where the R object is saved to or read from.
#' @param ascii a logical. If TRUE or NA, an ASCII representation is written; otherwise (default), a binary one is used. See the comments in the help for save. #' @param ascii a logical. If TRUE or NA, an ASCII representation is written; otherwise (default), a binary one is used. See the comments in the help for save.
...@@ -9,9 +9,9 @@ ...@@ -9,9 +9,9 @@
#' @param compress a logical specifying whether saving to a named file is to use "gzip" compression, or one of \code{"gzip"}, \code{"bzip2"} or \code{"xz"} to indicate the type of compression to be used. Ignored if file is a connection. #' @param compress a logical specifying whether saving to a named file is to use "gzip" compression, or one of \code{"gzip"}, \code{"bzip2"} or \code{"xz"} to indicate the type of compression to be used. Ignored if file is a connection.
#' @param refhook a hook function for handling reference objects. #' @param refhook a hook function for handling reference objects.
#' @param raw whether to save the model in a raw variable or not, recommended to leave it to \code{TRUE}. #' @param raw whether to save the model in a raw variable or not, recommended to leave it to \code{TRUE}.
#' #'
#' @return NULL invisibly. #' @return NULL invisibly.
#' #'
#' @examples #' @examples
#' library(lightgbm) #' library(lightgbm)
#' data(agaricus.train, package = "lightgbm") #' data(agaricus.train, package = "lightgbm")
...@@ -40,13 +40,13 @@ saveRDS.lgb.Booster <- function(object, ...@@ -40,13 +40,13 @@ saveRDS.lgb.Booster <- function(object,
compress = TRUE, compress = TRUE,
refhook = NULL, refhook = NULL,
raw = TRUE) { raw = TRUE) {
# Check if object has a raw value (and if the user wants to store the raw) # Check if object has a raw value (and if the user wants to store the raw)
if (is.na(object$raw) && raw) { if (is.na(object$raw) && raw) {
# Save model # Save model
object$save() object$save()
# Save RDS # Save RDS
saveRDS(object, saveRDS(object,
file = file, file = file,
...@@ -54,12 +54,12 @@ saveRDS.lgb.Booster <- function(object, ...@@ -54,12 +54,12 @@ saveRDS.lgb.Booster <- function(object,
version = version, version = version,
compress = compress, compress = compress,
refhook = refhook) refhook = refhook)
# Free model from memory # Free model from memory
object$raw <- NA object$raw <- NA
} else { } else {
# Save as usual # Save as usual
saveRDS(object, saveRDS(object,
file = file, file = file,
...@@ -67,7 +67,7 @@ saveRDS.lgb.Booster <- function(object, ...@@ -67,7 +67,7 @@ saveRDS.lgb.Booster <- function(object,
version = version, version = version,
compress = compress, compress = compress,
refhook = refhook) refhook = refhook)
} }
} }
...@@ -11,18 +11,18 @@ lgb.is.null.handle <- function(x) { ...@@ -11,18 +11,18 @@ lgb.is.null.handle <- function(x) {
} }
lgb.encode.char <- function(arr, len) { lgb.encode.char <- function(arr, len) {
if (!is.raw(arr)) { if (!is.raw(arr)) {
stop("lgb.encode.char: Can only encode from raw type") # Not an object of type raw stop("lgb.encode.char: Can only encode from raw type") # Not an object of type raw
} }
rawToChar(arr[seq_len(len)]) # Return the conversion of raw type to character type rawToChar(arr[seq_len(len)]) # Return the conversion of raw type to character type
} }
lgb.call <- function(fun_name, ret, ...) { lgb.call <- function(fun_name, ret, ...) {
# Set call state to a zero value # Set call state to a zero value
call_state <- 0L call_state <- 0L
# Check for a ret call # Check for a ret call
if (!is.null(ret)) { if (!is.null(ret)) {
call_state <- .Call(fun_name, ..., ret, call_state, PACKAGE = "lib_lightgbm") # Call with ret call_state <- .Call(fun_name, ..., ret, call_state, PACKAGE = "lib_lightgbm") # Call with ret
...@@ -38,7 +38,7 @@ lgb.call <- function(fun_name, ret, ...) { ...@@ -38,7 +38,7 @@ lgb.call <- function(fun_name, ret, ...) {
act_len <- 0L act_len <- 0L
err_msg <- raw(buf_len) err_msg <- raw(buf_len)
err_msg <- .Call("LGBM_GetLastError_R", buf_len, act_len, err_msg, PACKAGE = "lib_lightgbm") err_msg <- .Call("LGBM_GetLastError_R", buf_len, act_len, err_msg, PACKAGE = "lib_lightgbm")
# Check error buffer # Check error buffer
if (act_len > buf_len) { if (act_len > buf_len) {
buf_len <- act_len buf_len <- act_len
...@@ -49,169 +49,169 @@ lgb.call <- function(fun_name, ret, ...) { ...@@ -49,169 +49,169 @@ lgb.call <- function(fun_name, ret, ...) {
err_msg, err_msg,
PACKAGE = "lib_lightgbm") PACKAGE = "lib_lightgbm")
} }
# Return error # Return error
stop("api error: ", lgb.encode.char(err_msg, act_len)) stop("api error: ", lgb.encode.char(err_msg, act_len))
} }
return(ret) return(ret)
} }
lgb.call.return.str <- function(fun_name, ...) { lgb.call.return.str <- function(fun_name, ...) {
# Create buffer # Create buffer
buf_len <- as.integer(1024 * 1024) buf_len <- as.integer(1024 * 1024)
act_len <- 0L act_len <- 0L
buf <- raw(buf_len) buf <- raw(buf_len)
# Call buffer # Call buffer
buf <- lgb.call(fun_name, ret = buf, ..., buf_len, act_len) buf <- lgb.call(fun_name, ret = buf, ..., buf_len, act_len)
# Check for buffer content # Check for buffer content
if (act_len > buf_len) { if (act_len > buf_len) {
buf_len <- act_len buf_len <- act_len
buf <- raw(buf_len) buf <- raw(buf_len)
buf <- lgb.call(fun_name, ret = buf, ..., buf_len, act_len) buf <- lgb.call(fun_name, ret = buf, ..., buf_len, act_len)
} }
# Return encoded character # Return encoded character
return(lgb.encode.char(buf, act_len)) return(lgb.encode.char(buf, act_len))
} }
lgb.params2str <- function(params, ...) { lgb.params2str <- function(params, ...) {
# Check for a list as input # Check for a list as input
if (!is.list(params)) { if (!is.list(params)) {
stop("params must be a list") stop("params must be a list")
} }
# Split parameter names # Split parameter names
names(params) <- gsub("\\.", "_", names(params)) names(params) <- gsub("\\.", "_", names(params))
# Merge parameters from the params and the dots-expansion # Merge parameters from the params and the dots-expansion
dot_params <- list(...) dot_params <- list(...)
names(dot_params) <- gsub("\\.", "_", names(dot_params)) names(dot_params) <- gsub("\\.", "_", names(dot_params))
# Check for identical parameters # Check for identical parameters
if (length(intersect(names(params), names(dot_params))) > 0) { if (length(intersect(names(params), names(dot_params))) > 0) {
stop("Same parameters in ", sQuote("params"), " and in the call are not allowed. Please check your ", sQuote("params"), " list") stop("Same parameters in ", sQuote("params"), " and in the call are not allowed. Please check your ", sQuote("params"), " list")
} }
# Merge parameters # Merge parameters
params <- c(params, dot_params) params <- c(params, dot_params)
# Setup temporary variable # Setup temporary variable
ret <- list() ret <- list()
# Perform key value join # Perform key value join
for (key in names(params)) { for (key in names(params)) {
# Join multi value first # Join multi value first
val <- paste0(format(params[[key]], scientific = FALSE), collapse = ",") val <- paste0(format(params[[key]], scientific = FALSE), collapse = ",")
if (nchar(val) <= 0) next # Skip join if (nchar(val) <= 0) next # Skip join
# Join key value # Join key value
pair <- paste0(c(key, val), collapse = "=") pair <- paste0(c(key, val), collapse = "=")
ret <- c(ret, pair) ret <- c(ret, pair)
} }
# Check ret length # Check ret length
if (length(ret) == 0) { if (length(ret) == 0) {
# Return empty string # Return empty string
lgb.c_str("") lgb.c_str("")
} else { } else {
# Return string separated by a space per element # Return string separated by a space per element
lgb.c_str(paste0(ret, collapse = " ")) lgb.c_str(paste0(ret, collapse = " "))
} }
} }
lgb.c_str <- function(x) { lgb.c_str <- function(x) {
# Perform character to raw conversion # Perform character to raw conversion
ret <- charToRaw(as.character(x)) ret <- charToRaw(as.character(x))
ret <- c(ret, as.raw(0)) ret <- c(ret, as.raw(0))
ret ret
} }
lgb.check.r6.class <- function(object, name) { lgb.check.r6.class <- function(object, name) {
# Check for non-existence of R6 class or named class # Check for non-existence of R6 class or named class
all(c("R6", name) %in% class(object)) all(c("R6", name) %in% class(object))
} }
lgb.check.params <- function(params) { lgb.check.params <- function(params) {
# To-do # To-do
params # Currently return params because this is not finalized params # Currently return params because this is not finalized
} }
lgb.check.obj <- function(params, obj) { lgb.check.obj <- function(params, obj) {
# List known objectives in a vector # List known objectives in a vector
OBJECTIVES <- c("regression", "regression_l1", "regression_l2", "mean_squared_error", "mse", "l2_root", "root_mean_squared_error", "rmse", OBJECTIVES <- c("regression", "regression_l1", "regression_l2", "mean_squared_error", "mse", "l2_root", "root_mean_squared_error", "rmse",
"mean_absolute_error", "mae", "quantile", "mean_absolute_error", "mae", "quantile",
"huber", "fair", "poisson", "binary", "lambdarank", "huber", "fair", "poisson", "binary", "lambdarank",
"multiclass", "softmax", "multiclassova", "multiclass_ova", "ova", "ovr", "multiclass", "softmax", "multiclassova", "multiclass_ova", "ova", "ovr",
"xentropy", "cross_entropy", "xentlambda", "cross_entropy_lambda", "mean_absolute_percentage_error", "mape", "xentropy", "cross_entropy", "xentlambda", "cross_entropy_lambda", "mean_absolute_percentage_error", "mape",
"gamma", "tweedie") "gamma", "tweedie")
# Check whether the objective is empty or not, and take it from params if needed # Check whether the objective is empty or not, and take it from params if needed
if (!is.null(obj)) { params$objective <- obj } if (!is.null(obj)) { params$objective <- obj }
# Check whether the objective is a character # Check whether the objective is a character
if (is.character(params$objective)) { if (is.character(params$objective)) {
# If the objective is a character, check if it is a known objective # If the objective is a character, check if it is a known objective
if (!(params$objective %in% OBJECTIVES)) { if (!(params$objective %in% OBJECTIVES)) {
# Interrupt on unknown objective name # Interrupt on unknown objective name
stop("lgb.check.obj: objective name error should be one of (", paste0(OBJECTIVES, collapse = ", "), ")") stop("lgb.check.obj: objective name error should be one of (", paste0(OBJECTIVES, collapse = ", "), ")")
} }
} else if (!is.function(params$objective)) { } else if (!is.function(params$objective)) {
# If objective is not a character nor a function, then stop # If objective is not a character nor a function, then stop
stop("lgb.check.obj: objective should be a character or a function") stop("lgb.check.obj: objective should be a character or a function")
} }
# Return parameters # Return parameters
return(params) return(params)
} }
lgb.check.eval <- function(params, eval) { lgb.check.eval <- function(params, eval) {
# Check if metric is null, if yes put a list instead # Check if metric is null, if yes put a list instead
if (is.null(params$metric)) { if (is.null(params$metric)) {
params$metric <- list() params$metric <- list()
} }
# Check if evaluation metric is null, if not then append it # Check if evaluation metric is null, if not then append it
if (!is.null(eval)) { if (!is.null(eval)) {
# Append metric if character or list # Append metric if character or list
if (is.character(eval) || is.list(eval)) { if (is.character(eval) || is.list(eval)) {
# Append metrics # Append metrics
params$metric <- append(params$metric, eval) params$metric <- append(params$metric, eval)
} }
} }
# Return parameters # Return parameters
return(params) return(params)
} }
...@@ -5,7 +5,7 @@ library(data.table) ...@@ -5,7 +5,7 @@ library(data.table)
library(lightgbm) library(lightgbm)
# Load data and look at the structure # Load data and look at the structure
# #
# Classes 'data.table' and 'data.frame': 4521 obs. of 17 variables: # Classes 'data.table' and 'data.frame': 4521 obs. of 17 variables:
# $ age : int 30 33 35 30 59 35 36 39 41 43 ... # $ age : int 30 33 35 30 59 35 36 39 41 43 ...
# $ job : chr "unemployed" "services" "management" "management" ... # $ job : chr "unemployed" "services" "management" "management" ...
...@@ -30,7 +30,7 @@ str(bank) ...@@ -30,7 +30,7 @@ str(bank)
# We must now transform the data to fit in LightGBM # We must now transform the data to fit in LightGBM
# For this task, we use lgb.prepare # For this task, we use lgb.prepare
# The function transforms the data into a fittable data # The function transforms the data into a fittable data
# #
# Classes 'data.table' and 'data.frame': 4521 obs. of 17 variables: # Classes 'data.table' and 'data.frame': 4521 obs. of 17 variables:
# $ age : int 30 33 35 30 59 35 36 39 41 43 ... # $ age : int 30 33 35 30 59 35 36 39 41 43 ...
# $ job : chr "unemployed" "services" "management" "management" ... # $ job : chr "unemployed" "services" "management" "management" ...
......
...@@ -5,7 +5,7 @@ library(data.table) ...@@ -5,7 +5,7 @@ library(data.table)
library(lightgbm) library(lightgbm)
# Load data and look at the structure # Load data and look at the structure
# #
# Classes 'data.table' and 'data.frame': 4521 obs. of 17 variables: # Classes 'data.table' and 'data.frame': 4521 obs. of 17 variables:
# $ age : int 30 33 35 30 59 35 36 39 41 43 ... # $ age : int 30 33 35 30 59 35 36 39 41 43 ...
# $ job : chr "unemployed" "services" "management" "management" ... # $ job : chr "unemployed" "services" "management" "management" ...
...@@ -34,7 +34,7 @@ bank_test <- bank[4001:4521, ] ...@@ -34,7 +34,7 @@ bank_test <- bank[4001:4521, ]
# We must now transform the data to fit in LightGBM # We must now transform the data to fit in LightGBM
# For this task, we use lgb.prepare # For this task, we use lgb.prepare
# The function transforms the data into a fittable data # The function transforms the data into a fittable data
# #
# Classes 'data.table' and 'data.frame': 521 obs. of 17 variables: # Classes 'data.table' and 'data.frame': 521 obs. of 17 variables:
# $ age : int 53 36 58 26 34 55 55 34 41 38 ... # $ age : int 53 36 58 26 34 55 55 34 41 38 ...
# $ job : num 1 10 10 9 10 2 2 3 3 4 ... # $ job : num 1 10 10 9 10 2 2 3 3 4 ...
......
...@@ -36,35 +36,35 @@ preds_builtin <- predict(model_builtin, test[, 1:4], rawscore = TRUE) ...@@ -36,35 +36,35 @@ preds_builtin <- predict(model_builtin, test[, 1:4], rawscore = TRUE)
# User defined objective function, given prediction, return gradient and second order gradient # User defined objective function, given prediction, return gradient and second order gradient
custom_multiclass_obj = function(preds, dtrain) { custom_multiclass_obj = function(preds, dtrain) {
labels = getinfo(dtrain, "label") labels = getinfo(dtrain, "label")
# preds is a matrix with rows corresponding to samples and colums corresponding to choices # preds is a matrix with rows corresponding to samples and colums corresponding to choices
preds = matrix(preds, nrow = length(labels)) preds = matrix(preds, nrow = length(labels))
# to prevent overflow, normalize preds by row # to prevent overflow, normalize preds by row
preds = preds - apply(preds, 1, max) preds = preds - apply(preds, 1, max)
prob = exp(preds) / rowSums(exp(preds)) prob = exp(preds) / rowSums(exp(preds))
# compute gradient # compute gradient
grad = prob grad = prob
grad[cbind(1:length(labels), labels + 1)] = grad[cbind(1:length(labels), labels + 1)] - 1 grad[cbind(1:length(labels), labels + 1)] = grad[cbind(1:length(labels), labels + 1)] - 1
# compute hessian (approximation) # compute hessian (approximation)
hess = 2 * prob * (1 - prob) hess = 2 * prob * (1 - prob)
return(list(grad = grad, hess = hess)) return(list(grad = grad, hess = hess))
} }
# define custom metric # define custom metric
custom_multiclass_metric = function(preds, dtrain) { custom_multiclass_metric = function(preds, dtrain) {
labels = getinfo(dtrain, "label") labels = getinfo(dtrain, "label")
preds = matrix(preds, nrow = length(labels)) preds = matrix(preds, nrow = length(labels))
preds = preds - apply(preds, 1, max) preds = preds - apply(preds, 1, max)
prob = exp(preds) / rowSums(exp(preds)) prob = exp(preds) / rowSums(exp(preds))
return(list(name = "error", return(list(name = "error",
value = -mean(log(prob[cbind(1:length(labels), labels + 1)])), value = -mean(log(prob[cbind(1:length(labels), labels + 1)])),
higher_better = FALSE)) higher_better = FALSE))
} }
model_custom <- lgb.train(list(), model_custom <- lgb.train(list(),
......
...@@ -23,7 +23,7 @@ lgb.cv(params = list(), data, nrounds = 10, nfold = 3, label = NULL, ...@@ -23,7 +23,7 @@ lgb.cv(params = list(), data, nrounds = 10, nfold = 3, label = NULL,
\item{weight}{vector of response values. If not NULL, will set to dataset} \item{weight}{vector of response values. If not NULL, will set to dataset}
\item{obj}{objective function, can be character or custom objective function. Examples include \item{obj}{objective function, can be character or custom objective function. Examples include
\code{regression}, \code{regression_l1}, \code{huber}, \code{regression}, \code{regression_l1}, \code{huber},
\code{binary}, \code{lambdarank}, \code{multiclass}, \code{multiclass}} \code{binary}, \code{lambdarank}, \code{multiclass}, \code{multiclass}}
...@@ -66,10 +66,10 @@ List of callback functions that are applied at each iteration.} ...@@ -66,10 +66,10 @@ List of callback functions that are applied at each iteration.}
\itemize{ \itemize{
\item{boosting}{Boosting type. \code{"gbdt"} or \code{"dart"}} \item{boosting}{Boosting type. \code{"gbdt"} or \code{"dart"}}
\item{num_leaves}{number of leaves in one tree. defaults to 127} \item{num_leaves}{number of leaves in one tree. defaults to 127}
\item{max_depth}{Limit the max depth for tree model. This is used to deal with \item{max_depth}{Limit the max depth for tree model. This is used to deal with
overfit when #data is small. Tree still grow by leaf-wise.} overfit when #data is small. Tree still grow by leaf-wise.}
\item{num_threads}{Number of threads for LightGBM. For the best speed, set this to \item{num_threads}{Number of threads for LightGBM. For the best speed, set this to
the number of real CPU cores, not the number of threads (most the number of real CPU cores, not the number of threads (most
CPU using hyper-threading to generate 2 threads per CPU core).} CPU using hyper-threading to generate 2 threads per CPU core).}
}} }}
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment