Commit 2035c54b authored by Kirill Sevastyanenko's avatar Kirill Sevastyanenko Committed by Guolin Ke
Browse files

Stylistic changes r package (#184)

* src & callbacks

* lgb.Booster and utils

* cv

* wip lgb.Dataset

* lgb.Dataset

* lgb.Predictor

* lgb.train

* typos

* add flags to facilitate macosx compilation

* fix basic_string template error with clang

* most unfortunate mode of development

* fixup tests

* last test

* roxygen

* roxygen v5.x.x
parent c2ba086c
...@@ -363,5 +363,8 @@ ENV/ ...@@ -363,5 +363,8 @@ ENV/
# Rope project settings # Rope project settings
.ropeproject .ropeproject
# R testing artefact
lightgbm.model
# macOS # macOS
.DS_Store .DS_Store
CB_ENV <- R6Class( CB_ENV <- R6Class(
"lgb.cb_env", "lgb.cb_env",
cloneable=FALSE, cloneable = FALSE,
public = list( public = list(
model=NULL, model = NULL,
iteration=NULL, iteration = NULL,
begin_iteration=NULL, begin_iteration = NULL,
end_iteration=NULL, end_iteration = NULL,
eval_list=list(), eval_list = list(),
eval_err_list=list(), eval_err_list = list(),
best_iter=-1, best_iter = -1,
met_early_stop=FALSE met_early_stop = FALSE
) )
) )
cb.reset.parameters <- function(new_params) { cb.reset.parameters <- function(new_params) {
if (typeof(new_params) != "list") if (!is.list(new_params)) { stop(sQuote("new_params"), " must be a list") }
stop("'new_params' must be a list")
pnames <- gsub("\\.", "_", names(new_params)) pnames <- gsub("\\.", "_", names(new_params))
nrounds <- NULL nrounds <- NULL
...@@ -23,113 +22,96 @@ cb.reset.parameters <- function(new_params) { ...@@ -23,113 +22,96 @@ cb.reset.parameters <- function(new_params) {
init <- function(env) { init <- function(env) {
nrounds <<- env$end_iteration - env$begin_iteration + 1 nrounds <<- env$end_iteration - env$begin_iteration + 1
if (is.null(env$model)) if (is.null(env$model)) { stop("Env should have a ", sQuote("model")) }
stop("Env should has 'model'")
# Some parameters are not allowed to be changed, # Some parameters are not allowed to be changed,
# since changing them would simply wreck some chaos # since changing them would simply wreck some chaos
not_allowed <- pnames %in% not_allowed <- c("num_class", "metric", "boosting_type")
c('num_class', 'metric', 'boosting_type') if (any(pnames %in% not_allowed)) {
if (any(not_allowed)) stop("Parameters ", paste0(pnames[pnames %in% not_allowed], collapse = ", "), " cannot be changed during boosting")
stop('Parameters ', paste(pnames[not_allowed]), " cannot be changed during boosting.") }
for (n in pnames) { for (n in pnames) {
p <- new_params[[n]] p <- new_params[[n]]
if (is.function(p)) { if (is.function(p)) {
if (length(formals(p)) != 2) if (length(formals(p)) != 2)
stop("Parameter '", n, "' is a function but not of two arguments") stop("Parameter ", sQuote(n), " is a function but not of two arguments")
} else if (is.numeric(p) || is.character(p)) { } else if (is.numeric(p) || is.character(p)) {
if (length(p) != nrounds) if (length(p) != nrounds)
stop("Length of '", n, "' has to be equal to 'nrounds'") stop("Length of ", sQuote(n), " has to be equal to length of ", sQuote("nrounds"))
} else { } else {
stop("Parameter '", n, "' is not a function or a vector") stop("Parameter ", sQuote(n), " is not a function or a vector")
} }
} }
} }
callback <- function(env) { callback <- function(env) {
if (is.null(nrounds)) if (is.null(nrounds)) { init(env) }
init(env)
i <- env$iteration - env$begin_iteration i <- env$iteration - env$begin_iteration
pars <- lapply(new_params, function(p) { pars <- lapply(new_params, function(p) {
if (is.function(p)) if (is.function(p)) { return(p(i, nrounds)) }
return(p(i, nrounds))
p[i] p[i]
}) })
# to-do check pars # to-do check pars
if (!is.null(env$model)) { if (!is.null(env$model)) { env$model$reset_parameter(pars) }
env$model$reset_parameter(pars)
}
} }
attr(callback, 'call') <- match.call() attr(callback, 'call') <- match.call()
attr(callback, 'is_pre_iteration') <- TRUE attr(callback, 'is_pre_iteration') <- TRUE
attr(callback, 'name') <- 'cb.reset.parameters' attr(callback, 'name') <- 'cb.reset.parameters'
return(callback) callback
} }
# Format the evaluation metric string # Format the evaluation metric string
format.eval.string <- function(eval_res, eval_err=NULL) { format.eval.string <- function(eval_res, eval_err = NULL) {
if (is.null(eval_res)) if (is.null(eval_res) || length(eval_res) == 0) { stop('no evaluation results') }
stop('no evaluation results')
if (length(eval_res) == 0)
stop('no evaluation results')
if (!is.null(eval_err)) { if (!is.null(eval_err)) {
res <- sprintf('%s\'s %s:%g+%g', eval_res$data_name, eval_res$name, eval_res$value, eval_err) sprintf('%s\'s %s:%g+%g', eval_res$data_name, eval_res$name, eval_res$value, eval_err)
} else { } else {
res <- sprintf('%s\'s %s:%g', eval_res$data_name, eval_res$name, eval_res$value) sprintf('%s\'s %s:%g', eval_res$data_name, eval_res$name, eval_res$value)
} }
return(res)
} }
merge.eval.string <- function(env){ merge.eval.string <- function(env) {
if(length(env$eval_list) <= 0){ if (length(env$eval_list) <= 0) { return("") }
return("") msg <- list(sprintf('[%d]:', env$iteration))
}
msg <- list(sprintf('[%d]:',env$iteration))
is_eval_err <- FALSE is_eval_err <- FALSE
if(length(env$eval_err_list) > 0){ if (length(env$eval_err_list) > 0) { is_eval_err <- TRUE }
is_eval_err <- TRUE for (j in seq_along(env$eval_list)) {
}
for(j in 1:length(env$eval_list)) {
eval_err <- NULL eval_err <- NULL
if(is_eval_err){ if (is_eval_err) { eval_err <- env$eval_err_list[[j]] }
eval_err <- env$eval_err_list[[j]] msg <- c(msg, format.eval.string(env$eval_list[[j]], eval_err))
} }
msg <- c(msg, format.eval.string(env$eval_list[[j]],eval_err)) paste0(msg, collapse='\t')
}
return(paste0(msg, collapse='\t'))
} }
cb.print.evaluation <- function(period=1){ cb.print.evaluation <- function(period = 1){
callback <- function(env){ callback <- function(env) {
if(period > 0){ if (period > 0) {
i <- env$iteration i <- env$iteration
if( (i - 1) %% period == 0 if ( (i - 1) %% period == 0
| i == env$begin_iteration | i == env$begin_iteration
| i == env$end_iteration ){ | i == env$end_iteration ) {
cat(merge.eval.string(env), "\n") cat(merge.eval.string(env), "\n")
} }
} }
} }
attr(callback, 'call') <- match.call() attr(callback, 'call') <- match.call()
attr(callback, 'name') <- 'cb.print.evaluation' attr(callback, 'name') <- 'cb.print.evaluation'
return(callback) callback
} }
cb.record.evaluation <- function() { cb.record.evaluation <- function() {
callback <- function(env){ callback <- function(env) {
if(length(env$eval_list) <= 0) return() if (length(env$eval_list) <= 0) { return() }
is_eval_err <- FALSE is_eval_err <- FALSE
if(length(env$eval_err_list) > 0){ if (length(env$eval_err_list) > 0) { is_eval_err <- TRUE }
is_eval_err <- TRUE if (length(env$model$record_evals) == 0) {
} for (j in seq_along(env$eval_list)) {
if(length(env$model$record_evals) == 0){
for(j in 1:length(env$eval_list)) {
data_name <- env$eval_list[[j]]$data_name data_name <- env$eval_list[[j]]$data_name
name <- env$eval_list[[j]]$name name <- env$eval_list[[j]]$name
env$model$record_evals$start_iter <- env$begin_iteration env$model$record_evals$start_iter <- env$begin_iteration
if(is.null(env$model$record_evals[[data_name]])){ if (is.null(env$model$record_evals[[data_name]])) {
env$model$record_evals[[data_name]] <- list() env$model$record_evals[[data_name]] <- list()
} }
env$model$record_evals[[data_name]][[name]] <- list() env$model$record_evals[[data_name]][[name]] <- list()
...@@ -137,25 +119,22 @@ cb.record.evaluation <- function() { ...@@ -137,25 +119,22 @@ cb.record.evaluation <- function() {
env$model$record_evals[[data_name]][[name]]$eval_err <- list() env$model$record_evals[[data_name]][[name]]$eval_err <- list()
} }
} }
for(j in 1:length(env$eval_list)) { for (j in seq_along(env$eval_list)) {
eval_res <- env$eval_list[[j]] eval_res <- env$eval_list[[j]]
eval_err <- NULL eval_err <- NULL
if(is_eval_err){ if (is_eval_err) { eval_err <- env$eval_err_list[[j]] }
eval_err <- env$eval_err_list[[j]]
}
data_name <- eval_res$data_name data_name <- eval_res$data_name
name <- eval_res$name name <- eval_res$name
env$model$record_evals[[data_name]][[name]]$eval <- c(env$model$record_evals[[data_name]][[name]]$eval, eval_res$value) env$model$record_evals[[data_name]][[name]]$eval <- c(env$model$record_evals[[data_name]][[name]]$eval, eval_res$value)
env$model$record_evals[[data_name]][[name]]$eval_err <- c(env$model$record_evals[[data_name]][[name]]$eval_err, eval_err) env$model$record_evals[[data_name]][[name]]$eval_err <- c(env$model$record_evals[[data_name]][[name]]$eval_err, eval_err)
} }
} }
attr(callback, 'call') <- match.call() attr(callback, 'call') <- match.call()
attr(callback, 'name') <- 'cb.record.evaluation' attr(callback, 'name') <- 'cb.record.evaluation'
return(callback) callback
} }
cb.early.stop <- function(stopping_rounds, verbose=TRUE) { cb.early.stop <- function(stopping_rounds, verbose = TRUE) {
# state variables # state variables
factor_to_bigger_better <- NULL factor_to_bigger_better <- NULL
best_iter <- NULL best_iter <- NULL
...@@ -164,45 +143,43 @@ cb.early.stop <- function(stopping_rounds, verbose=TRUE) { ...@@ -164,45 +143,43 @@ cb.early.stop <- function(stopping_rounds, verbose=TRUE) {
eval_len <- NULL eval_len <- NULL
init <- function(env) { init <- function(env) {
eval_len <<- length(env$eval_list) eval_len <<- length(env$eval_list)
if (eval_len == 0) if (eval_len == 0) {
stop("For early stopping, valids must have at least one element") stop("For early stopping, valids must have at least one element")
}
if (verbose) if (isTRUE(verbose)) {
cat("Will train until hasn't improved in ", cat("Will train until there is no improvement in ", stopping_rounds, " rounds.\n\n", sep = '')
stopping_rounds, " rounds.\n\n", sep = '') }
factor_to_bigger_better <<- rep(1.0, eval_len) factor_to_bigger_better <<- rep(1.0, eval_len)
best_iter <<- rep(-1, eval_len) best_iter <<- rep(-1, eval_len)
best_score <<- rep(-Inf, eval_len) best_score <<- rep(-Inf, eval_len)
best_msg <<- list() best_msg <<- list()
for(i in 1:eval_len){ for (i in seq_len(eval_len)) {
best_msg <<- c(best_msg, "") best_msg <<- c(best_msg, "")
if(!env$eval_list[[i]]$higher_better){ if (!env$eval_list[[i]]$higher_better) {
factor_to_bigger_better[i] <<- -1.0 factor_to_bigger_better[i] <<- -1.0
} }
} }
} }
callback <- function(env, finalize = FALSE) { callback <- function(env, finalize = FALSE) {
if (is.null(eval_len)) if (is.null(eval_len)) { init(env) }
init(env)
cur_iter <- env$iteration cur_iter <- env$iteration
for(i in 1:eval_len){ for (i in seq_len(eval_len)) {
score <- env$eval_list[[i]]$value * factor_to_bigger_better[i] score <- env$eval_list[[i]]$value * factor_to_bigger_better[i]
if(score > best_score[i]){ if (score > best_score[i]) {
best_score[i] <<- score best_score[i] <<- score
best_iter[i] <<- cur_iter best_iter[i] <<- cur_iter
if(verbose){ if (verbose) {
best_msg[[i]] <<- as.character(merge.eval.string(env)) best_msg[[i]] <<- as.character(merge.eval.string(env))
} }
} else { } else {
if(cur_iter - best_iter[i] >= stopping_rounds){ if (cur_iter - best_iter[i] >= stopping_rounds) {
if(!is.null(env$model)){ if (!is.null(env$model)) { env$model$best_iter <- best_iter[i] }
env$model$best_iter <- best_iter[i] if (isTRUE(verbose)) {
} cat("Early stopping, best iteration is:", "\n")
if(verbose){ cat(best_msg[[i]], "\n")
cat('Early stopping, best iteration is:',"\n")
cat(best_msg[[i]],"\n")
} }
env$best_iter <- best_iter[i] env$best_iter <- best_iter[i]
env$met_early_stop <- TRUE env$met_early_stop <- TRUE
...@@ -212,13 +189,11 @@ cb.early.stop <- function(stopping_rounds, verbose=TRUE) { ...@@ -212,13 +189,11 @@ cb.early.stop <- function(stopping_rounds, verbose=TRUE) {
} }
attr(callback, 'call') <- match.call() attr(callback, 'call') <- match.call()
attr(callback, 'name') <- 'cb.early.stop' attr(callback, 'name') <- 'cb.early.stop'
return(callback) callback
} }
# Extract callback names from the list of callbacks # Extract callback names from the list of callbacks
callback.names <- function(cb_list) { callback.names <- function(cb_list) { unlist(lapply(cb_list, attr, "name")) }
unlist(lapply(cb_list, function(x) attr(x, 'name')))
}
add.cb <- function(cb_list, cb) { add.cb <- function(cb_list, cb) {
cb_list <- c(cb_list, cb) cb_list <- c(cb_list, cb)
......
Booster <- R6Class( Booster <- R6Class(
"lgb.Booster", "lgb.Booster",
cloneable=FALSE, cloneable = FALSE,
public = list( public = list(
best_iter = -1, best_iter = -1,
record_evals = list(), record_evals = list(),
finalize = function() { finalize = function() {
if(!lgb.is.null.handle(private$handle)){ if (!lgb.is.null.handle(private$handle)) {
print("free booster handle") cat("freeing booster handle\n")
lgb.call("LGBM_BoosterFree_R", ret=NULL, private$handle) lgb.call("LGBM_BoosterFree_R", ret = NULL, private$handle)
private$handle <- NULL private$handle <- NULL
} }
}, },
...@@ -20,70 +20,72 @@ Booster <- R6Class( ...@@ -20,70 +20,72 @@ Booster <- R6Class(
handle <- lgb.new.handle() handle <- lgb.new.handle()
if (!is.null(train_set)) { if (!is.null(train_set)) {
if (!lgb.check.r6.class(train_set, "lgb.Dataset")) { if (!lgb.check.r6.class(train_set, "lgb.Dataset")) {
stop("lgb.Booster: Only can use lgb.Dataset as training data") stop("lgb.Booster: Can only use lgb.Dataset as training data")
} }
handle <- handle <-
lgb.call("LGBM_BoosterCreate_R", ret=handle, train_set$.__enclos_env__$private$get_handle(), params_str) lgb.call("LGBM_BoosterCreate_R", ret = handle, train_set$.__enclos_env__$private$get_handle(), params_str)
private$train_set <- train_set private$train_set <- train_set
private$num_dataset <- 1 private$num_dataset <- 1
private$init_predictor <- train_set$.__enclos_env__$private$predictor private$init_predictor <- train_set$.__enclos_env__$private$predictor
if (!is.null(private$init_predictor)) { if (!is.null(private$init_predictor)) {
lgb.call("LGBM_BoosterMerge_R", ret=NULL, lgb.call("LGBM_BoosterMerge_R", ret = NULL,
handle, handle,
private$init_predictor$.__enclos_env__$private$handle) private$init_predictor$.__enclos_env__$private$handle)
} }
private$is_predicted_cur_iter <- private$is_predicted_cur_iter <- c(private$is_predicted_cur_iter, FALSE)
c(private$is_predicted_cur_iter, FALSE)
} else if (!is.null(modelfile)) { } else if (!is.null(modelfile)) {
if (!is.character(modelfile)) { if (!is.character(modelfile)) {
stop("lgb.Booster: Only can use string as model file path") stop("lgb.Booster: Can only use a string as model file path")
} }
handle <- handle <-
lgb.call("LGBM_BoosterCreateFromModelfile_R", lgb.call("LGBM_BoosterCreateFromModelfile_R",
ret=handle, ret = handle,
lgb.c_str(modelfile)) lgb.c_str(modelfile))
} else { } else {
stop( stop(
"lgb.Booster: Need at least one training dataset or model file to create booster instance" "lgb.Booster: Need at least either training dataset or model file to create booster instance"
) )
} }
class(handle) <- "lgb.Booster.handle" class(handle) <- "lgb.Booster.handle"
private$handle <- handle private$handle <- handle
private$num_class <- as.integer(1) private$num_class <- as.integer(1)
private$num_class <- private$num_class <-
lgb.call("LGBM_BoosterGetNumClasses_R", ret=private$num_class, private$handle) lgb.call("LGBM_BoosterGetNumClasses_R", ret = private$num_class, private$handle)
}, },
set_train_data_name = function(name) { set_train_data_name = function(name) {
private$name_train_set <- name private$name_train_set <- name
return(self) self
}, },
add_valid = function(data, name) { add_valid = function(data, name) {
if (!lgb.check.r6.class(data, "lgb.Dataset")) { if (!lgb.check.r6.class(data, "lgb.Dataset")) {
stop("lgb.Booster.add_valid: Only can use lgb.Dataset as validation data") stop("lgb.Booster.add_valid: Can only use lgb.Dataset as validation data")
} }
if (!identical(data$.__enclos_env__$private$predictor, private$init_predictor)) { if (!identical(data$.__enclos_env__$private$predictor, private$init_predictor)) {
stop( stop(
"lgb.Booster.add_valid: Add validation data failed, you should use same predictor for these data" "lgb.Booster.add_valid: Failed to add validation data; you should use the same predictor for these data"
) )
} }
if(!is.character(name)){ if (!is.character(name)) {
stop("only can use character as data name") stop("lgb.Booster.add_valid: Can only use characters as data name")
} }
lgb.call("LGBM_BoosterAddValidData_R", ret=NULL, private$handle, data$.__enclos_env__$private$get_handle()) lgb.call("LGBM_BoosterAddValidData_R", ret = NULL, private$handle, data$.__enclos_env__$private$get_handle())
private$valid_sets <- c(private$valid_sets, data) private$valid_sets <- c(private$valid_sets, data)
private$name_valid_sets <- c(private$name_valid_sets, name) private$name_valid_sets <- c(private$name_valid_sets, name)
private$num_dataset <- private$num_dataset + 1 private$num_dataset <- private$num_dataset + 1
private$is_predicted_cur_iter <- private$is_predicted_cur_iter <-
c(private$is_predicted_cur_iter, FALSE) c(private$is_predicted_cur_iter, FALSE)
return(self)
self
}, },
reset_parameter = function(params, ...) { reset_parameter = function(params, ...) {
params <- append(params, list(...)) params <- append(params, list(...))
params_str <- algb.params2str(params) params_str <- algb.params2str(params)
lgb.call("LGBM_BoosterResetParameter_R", ret=NULL, lgb.call("LGBM_BoosterResetParameter_R", ret = NULL,
private$handle, private$handle,
params_str) params_str)
return(self) self
}, },
update = function(train_set = NULL, fobj = NULL) { update = function(train_set = NULL, fobj = NULL) {
if (!is.null(train_set)) { if (!is.null(train_set)) {
...@@ -92,57 +94,51 @@ Booster <- R6Class( ...@@ -92,57 +94,51 @@ Booster <- R6Class(
} }
if (!identical(train_set$predictor, private$init_predictor)) { if (!identical(train_set$predictor, private$init_predictor)) {
stop( stop(
"lgb.Booster.update: Change train_set failed, you should use same predictor for these data" "lgb.Booster.update: Change train_set failed, you should use the same predictor for these data"
) )
} }
lgb.call("LGBM_BoosterResetTrainingData_R", ret=NULL, lgb.call("LGBM_BoosterResetTrainingData_R", ret = NULL,
private$handle, private$handle,
train_set$.__enclos_env__$private$get_handle()) train_set$.__enclos_env__$private$get_handle())
private$train_set = train_set private$train_set = train_set
} }
if (is.null(fobj)) { if (is.null(fobj)) {
ret <- ret <- lgb.call("LGBM_BoosterUpdateOneIter_R", ret = NULL, private$handle)
lgb.call("LGBM_BoosterUpdateOneIter_R", ret=NULL, private$handle)
} else { } else {
if (typeof(fobj) != 'closure') { if (!is.function(fobj)) { stop("lgb.Booster.update: fobj should be a function") }
stop("lgb.Booster.update: fobj should be a function")
}
gpair <- fobj(private$inner_predict(1), private$train_set) gpair <- fobj(private$inner_predict(1), private$train_set)
ret <- ret <- lgb.call(
lgb.call( "LGBM_BoosterUpdateOneIterCustom_R", ret = NULL,
"LGBM_BoosterUpdateOneIterCustom_R", ret=NULL,
private$handle, private$handle,
gpair$grad, gpair$grad,
gpair$hess, gpair$hess,
length(gpair$grad) length(gpair$grad)
) )
} }
for (i in 1:length(private$is_predicted_cur_iter)) { for (i in seq_along(private$is_predicted_cur_iter)) {
private$is_predicted_cur_iter[[i]] <- FALSE private$is_predicted_cur_iter[[i]] <- FALSE
} }
return(ret) ret
}, },
rollback_one_iter = function() { rollback_one_iter = function() {
lgb.call("LGBM_BoosterRollbackOneIter_R", ret=NULL, private$handle) lgb.call("LGBM_BoosterRollbackOneIter_R", ret = NULL, private$handle)
for (i in 1:length(private$is_predicted_cur_iter)) { for (i in seq_along(private$is_predicted_cur_iter)) {
private$is_predicted_cur_iter[[i]] <- FALSE private$is_predicted_cur_iter[[i]] <- FALSE
} }
return(self) self
}, },
current_iter = function() { current_iter = function() {
cur_iter <- as.integer(0) cur_iter <- as.integer(0)
return(lgb.call("LGBM_BoosterGetCurrentIteration_R", ret=cur_iter, private$handle)) lgb.call("LGBM_BoosterGetCurrentIteration_R", ret = cur_iter, private$handle)
}, },
eval = function(data, name, feval = NULL) { eval = function(data, name, feval = NULL) {
if (!lgb.check.r6.class(data, "lgb.Dataset")) { if (!lgb.check.r6.class(data, "lgb.Dataset")) {
stop("lgb.Booster.eval: only can use lgb.Dataset to eval") stop("lgb.Booster.eval: Can only use lgb.Dataset to eval")
} }
data_idx <- 0 data_idx <- 0
if (identical(data, private$train_set)) { if (identical(data, private$train_set)) { data_idx <- 1 } else {
data_idx <- 1 if (length(private$valid_sets) > 0) {
} else { for (i in seq_along(private$valid_sets)) {
if(length(private$valid_sets) > 0){
for (i in 1:length(private$valid_sets)) {
if (identical(data, private$valid_sets[[i]])) { if (identical(data, private$valid_sets[[i]])) {
data_idx <- i + 1 data_idx <- i + 1
break break
...@@ -154,24 +150,21 @@ Booster <- R6Class( ...@@ -154,24 +150,21 @@ Booster <- R6Class(
self$add_valid(data, name) self$add_valid(data, name)
data_idx <- private$num_dataset data_idx <- private$num_dataset
} }
return(private$inner_eval(name, data_idx, feval)) private$inner_eval(name, data_idx, feval)
}, },
eval_train = function(feval = NULL) { eval_train = function(feval = NULL) {
return(private$inner_eval(private$name_train_set, 1, feval)) private$inner_eval(private$name_train_set, 1, feval)
}, },
eval_valid = function(feval = NULL) { eval_valid = function(feval = NULL) {
ret = list() ret = list()
if(length(private$valid_sets) <= 0) return(ret) if (length(private$valid_sets) <= 0) { return(ret) }
for (i in 1:length(private$valid_sets)) { for (i in seq_along(private$valid_sets)) {
ret <- ret <- append(ret, private$inner_eval(private$name_valid_sets[[i]], i + 1, feval))
append(ret, private$inner_eval(private$name_valid_sets[[i]], i + 1, feval))
} }
return(ret) ret
}, },
save_model = function(filename, num_iteration = NULL) { save_model = function(filename, num_iteration = NULL) {
if (is.null(num_iteration)) { if (is.null(num_iteration)) { num_iteration <- self$best_iter }
num_iteration <- self$best_iter
}
lgb.call( lgb.call(
"LGBM_BoosterSaveModel_R", "LGBM_BoosterSaveModel_R",
ret = NULL, ret = NULL,
...@@ -179,19 +172,15 @@ Booster <- R6Class( ...@@ -179,19 +172,15 @@ Booster <- R6Class(
as.integer(num_iteration), as.integer(num_iteration),
lgb.c_str(filename) lgb.c_str(filename)
) )
return(self) self
}, },
dump_model = function(num_iteration = NULL) { dump_model = function(num_iteration = NULL) {
if (is.null(num_iteration)) { if (is.null(num_iteration)) { num_iteration <- self$best_iter }
num_iteration <- self$best_iter
}
return(
lgb.call.return.str( lgb.call.return.str(
"LGBM_BoosterDumpModel_R", "LGBM_BoosterDumpModel_R",
private$handle, private$handle,
as.integer(num_iteration) as.integer(num_iteration)
) )
)
}, },
predict = function(data, predict = function(data,
num_iteration = NULL, num_iteration = NULL,
...@@ -199,15 +188,11 @@ Booster <- R6Class( ...@@ -199,15 +188,11 @@ Booster <- R6Class(
predleaf = FALSE, predleaf = FALSE,
header = FALSE, header = FALSE,
reshape = FALSE) { reshape = FALSE) {
if (is.null(num_iteration)) { if (is.null(num_iteration)) { num_iteration <- self$best_iter }
num_iteration <- self$best_iter
}
predictor <- Predictor$new(private$handle) predictor <- Predictor$new(private$handle)
return(predictor$predict(data, num_iteration, rawscore, predleaf, header, reshape)) predictor$predict(data, num_iteration, rawscore, predleaf, header, reshape)
}, },
to_predictor = function() { to_predictor = function() { Predictor$new(private$handle) }
Predictor$new(private$handle)
}
), ),
private = list( private = list(
handle = NULL, handle = NULL,
...@@ -224,52 +209,45 @@ Booster <- R6Class( ...@@ -224,52 +209,45 @@ Booster <- R6Class(
higher_better_inner_eval = NULL, higher_better_inner_eval = NULL,
inner_predict = function(idx) { inner_predict = function(idx) {
data_name <- private$name_train_set data_name <- private$name_train_set
if(idx > 1){ if (idx > 1) { data_name <- private$name_valid_sets[[idx - 1]] }
data_name <- private$name_valid_sets[[idx - 1]]
}
if (idx > private$num_dataset) { if (idx > private$num_dataset) {
stop("data_idx should not be greater than num_dataset") stop("data_idx should not be greater than num_dataset")
} }
if (is.null(private$predict_buffer[[data_name]])) { if (is.null(private$predict_buffer[[data_name]])) {
npred <- as.integer(0) npred <- as.integer(0)
npred <- npred <- lgb.call("LGBM_BoosterGetNumPredict_R",
lgb.call("LGBM_BoosterGetNumPredict_R",
ret = npred, ret = npred,
private$handle, private$handle,
as.integer(idx - 1)) as.integer(idx - 1))
private$predict_buffer[[data_name]] <- rep(0.0, npred) private$predict_buffer[[data_name]] <- rep(0.0, npred)
} }
if (!private$is_predicted_cur_iter[[idx]]) { if (!private$is_predicted_cur_iter[[idx]]) {
private$predict_buffer[[data_name]] <- private$predict_buffer[[data_name]] <- lgb.call(
lgb.call(
"LGBM_BoosterGetPredict_R", "LGBM_BoosterGetPredict_R",
ret=private$predict_buffer[[data_name]], ret = private$predict_buffer[[data_name]],
private$handle, private$handle,
as.integer(idx - 1) as.integer(idx - 1)
) )
private$is_predicted_cur_iter[[idx]] <- TRUE private$is_predicted_cur_iter[[idx]] <- TRUE
} }
return(private$predict_buffer[[data_name]]) private$predict_buffer[[data_name]]
}, },
get_eval_info = function() { get_eval_info = function() {
if (is.null(private$eval_names)) { if (is.null(private$eval_names)) {
names <- names <- lgb.call.return.str("LGBM_BoosterGetEvalNames_R", private$handle)
lgb.call.return.str("LGBM_BoosterGetEvalNames_R", private$handle) if (nchar(names) > 0) {
if(nchar(names) > 0){
names <- strsplit(names, "\t")[[1]] names <- strsplit(names, "\t")[[1]]
private$eval_names <- names private$eval_names <- names
private$higher_better_inner_eval <- private$higher_better_inner_eval <- rep(FALSE, length(names))
rep(FALSE, length(names)) for (i in seq_along(names)) {
for (i in 1:length(names)) {
if (startsWith(names[i], "auc") | if (startsWith(names[i], "auc") |
startsWith(names[i], "ndcg")) { startsWith(names[i], "ndcg")) {
private$higher_better_inner_eval[i] <- TRUE private$higher_better_inner_eval[i] <- TRUE
} }
} }
} }
} }
return(private$eval_names) private$eval_names
}, },
inner_eval = function(data_name, data_idx, feval = NULL) { inner_eval = function(data_name, data_idx, feval = NULL) {
if (data_idx > private$num_dataset) { if (data_idx > private$num_dataset) {
...@@ -279,11 +257,10 @@ Booster <- R6Class( ...@@ -279,11 +257,10 @@ Booster <- R6Class(
ret <- list() ret <- list()
if (length(private$eval_names) > 0) { if (length(private$eval_names) > 0) {
tmp_vals <- rep(0.0, length(private$eval_names)) tmp_vals <- rep(0.0, length(private$eval_names))
tmp_vals <- tmp_vals <- lgb.call("LGBM_BoosterGetEval_R", ret = tmp_vals,
lgb.call("LGBM_BoosterGetEval_R", ret=tmp_vals,
private$handle, private$handle,
as.integer(data_idx - 1)) as.integer(data_idx - 1))
for (i in 1:length(private$eval_names)) { for (i in seq_along(private$eval_names)) {
res <- list() res <- list()
res$data_name <- data_name res$data_name <- data_name
res$name <- private$eval_names[i] res$name <- private$eval_names[i]
...@@ -293,30 +270,20 @@ Booster <- R6Class( ...@@ -293,30 +270,20 @@ Booster <- R6Class(
} }
} }
if (!is.null(feval)) { if (!is.null(feval)) {
if (typeof(feval) != 'closure') { if (!is.function(feval)) {
stop("lgb.Booster.eval: feval should be a function") stop("lgb.Booster.eval: feval should be a function")
} }
data <- private$train_set data <- private$train_set
if (data_idx > 1) { if (data_idx > 1) { data <- private$valid_sets[[data_idx - 1]] }
data <- private$valid_sets[[data_idx - 1]]
}
res <- feval(private$inner_predict(data_idx), data) res <- feval(private$inner_predict(data_idx), data)
res$data_name <- data_name res$data_name <- data_name
ret <- append(ret, list(res)) ret <- append(ret, list(res))
} }
return(ret) ret
} }
) )
) )
# internal helper method
lgb.is.Booster <- function(x){
if(lgb.check.r6.class(x, "lgb.Booster")){
return(TRUE)
} else{
return(FALSE)
}
}
#' Predict method for LightGBM model #' Predict method for LightGBM model
#' #'
...@@ -342,6 +309,7 @@ lgb.is.Booster <- function(x){ ...@@ -342,6 +309,7 @@ lgb.is.Booster <- function(x){
#' When \code{predleaf = TRUE}, the output is a matrix object with the #' When \code{predleaf = TRUE}, the output is a matrix object with the
#' number of columns corresponding to the number of trees. #' number of columns corresponding to the number of trees.
#' @examples #' @examples
#' \dontrun{
#' library(lightgbm) #' library(lightgbm)
#' data(agaricus.train, package='lightgbm') #' data(agaricus.train, package='lightgbm')
#' train <- agaricus.train #' train <- agaricus.train
...@@ -353,18 +321,17 @@ lgb.is.Booster <- function(x){ ...@@ -353,18 +321,17 @@ lgb.is.Booster <- function(x){
#' valids <- list(test=dtest) #' valids <- list(test=dtest)
#' model <- lgb.train(params, dtrain, 100, valids, min_data=1, learning_rate=1, early_stopping_rounds=10) #' model <- lgb.train(params, dtrain, 100, valids, min_data=1, learning_rate=1, early_stopping_rounds=10)
#' preds <- predict(model, test$data) #' preds <- predict(model, test$data)
#' #' }
#' @rdname predict.lgb.Booster #' @rdname predict.lgb.Booster
#' @export #' @export
predict.lgb.Booster <- function(object, predict.lgb.Booster <- function(object, data,
data,
num_iteration = NULL, num_iteration = NULL,
rawscore = FALSE, rawscore = FALSE,
predleaf = FALSE, predleaf = FALSE,
header = FALSE, header = FALSE,
reshape = FALSE) { reshape = FALSE) {
if(!lgb.is.Booster(object)){ if (!lgb.is.Booster(object)) {
stop("predict.lgb.Booster: should input lgb.Booster object") stop("predict.lgb.Booster: object should be an ", sQuote("lgb.Booster"))
} }
object$predict(data, num_iteration, rawscore, predleaf, header, reshape) object$predict(data, num_iteration, rawscore, predleaf, header, reshape)
} }
...@@ -377,6 +344,7 @@ predict.lgb.Booster <- function(object, ...@@ -377,6 +344,7 @@ predict.lgb.Booster <- function(object,
#' #'
#' @return booster #' @return booster
#' @examples #' @examples
#' \dontrun{
#' library(lightgbm) #' library(lightgbm)
#' data(agaricus.train, package='lightgbm') #' data(agaricus.train, package='lightgbm')
#' train <- agaricus.train #' train <- agaricus.train
...@@ -389,13 +357,12 @@ predict.lgb.Booster <- function(object, ...@@ -389,13 +357,12 @@ predict.lgb.Booster <- function(object,
#' model <- lgb.train(params, dtrain, 100, valids, min_data=1, learning_rate=1, early_stopping_rounds=10) #' model <- lgb.train(params, dtrain, 100, valids, min_data=1, learning_rate=1, early_stopping_rounds=10)
#' lgb.save(model, "model.txt") #' lgb.save(model, "model.txt")
#' load_booster <- lgb.load("model.txt") #' load_booster <- lgb.load("model.txt")
#' }
#' @rdname lgb.load #' @rdname lgb.load
#' @export #' @export
lgb.load <- function(filename){ lgb.load <- function(filename){
if(!is.character(filename)){ if (!is.character(filename)) { stop("lgb.load: filename should be character") }
stop("lgb.load: filename should be character") Booster$new(modelfile = filename)
}
Booster$new(modelfile=filename)
} }
#' Save LightGBM model #' Save LightGBM model
...@@ -408,6 +375,7 @@ lgb.load <- function(filename){ ...@@ -408,6 +375,7 @@ lgb.load <- function(filename){
#' #'
#' @return booster #' @return booster
#' @examples #' @examples
#' \dontrun{
#' library(lightgbm) #' library(lightgbm)
#' data(agaricus.train, package='lightgbm') #' data(agaricus.train, package='lightgbm')
#' train <- agaricus.train #' train <- agaricus.train
...@@ -419,15 +387,12 @@ lgb.load <- function(filename){ ...@@ -419,15 +387,12 @@ lgb.load <- function(filename){
#' valids <- list(test=dtest) #' valids <- list(test=dtest)
#' model <- lgb.train(params, dtrain, 100, valids, min_data=1, learning_rate=1, early_stopping_rounds=10) #' model <- lgb.train(params, dtrain, 100, valids, min_data=1, learning_rate=1, early_stopping_rounds=10)
#' lgb.save(model, "model.txt") #' lgb.save(model, "model.txt")
#' }
#' @rdname lgb.save #' @rdname lgb.save
#' @export #' @export
lgb.save <- function(booster, filename, num_iteration=NULL){ lgb.save <- function(booster, filename, num_iteration = NULL){
if(!lgb.is.Booster(booster)){ if (!lgb.is.Booster(booster)) { stop("lgb.save: booster should be an ", sQuote("lgb.Booster")) }
stop("lgb.save: should input lgb.Booster object") if (!is.character(filename)) { stop("lgb.save: filename should be a character") }
}
if(!is.character(filename)){
stop("lgb.save: filename should be character")
}
booster$save_model(filename, num_iteration) booster$save_model(filename, num_iteration)
} }
...@@ -440,6 +405,7 @@ lgb.save <- function(booster, filename, num_iteration=NULL){ ...@@ -440,6 +405,7 @@ lgb.save <- function(booster, filename, num_iteration=NULL){
#' #'
#' @return json format of model #' @return json format of model
#' @examples #' @examples
#' \dontrun{
#' library(lightgbm) #' library(lightgbm)
#' data(agaricus.train, package='lightgbm') #' data(agaricus.train, package='lightgbm')
#' train <- agaricus.train #' train <- agaricus.train
...@@ -451,12 +417,11 @@ lgb.save <- function(booster, filename, num_iteration=NULL){ ...@@ -451,12 +417,11 @@ lgb.save <- function(booster, filename, num_iteration=NULL){
#' valids <- list(test=dtest) #' valids <- list(test=dtest)
#' model <- lgb.train(params, dtrain, 100, valids, min_data=1, learning_rate=1, early_stopping_rounds=10) #' model <- lgb.train(params, dtrain, 100, valids, min_data=1, learning_rate=1, early_stopping_rounds=10)
#' json_model <- lgb.dump(model) #' json_model <- lgb.dump(model)
#' }
#' @rdname lgb.dump #' @rdname lgb.dump
#' @export #' @export
lgb.dump <- function(booster, num_iteration=NULL){ lgb.dump <- function(booster, num_iteration = NULL){
if(!lgb.is.Booster(booster)){ if (!lgb.is.Booster(booster)) { stop("lgb.save: booster should be an ", sQuote("lgb.Booster")) }
stop("lgb.dump: should input lgb.Booster object")
}
booster$dump_model(num_iteration) booster$dump_model(num_iteration)
} }
...@@ -469,32 +434,30 @@ lgb.dump <- function(booster, num_iteration=NULL){ ...@@ -469,32 +434,30 @@ lgb.dump <- function(booster, num_iteration=NULL){
#' @param iters iterations, NULL will return all #' @param iters iterations, NULL will return all
#' @param is_err TRUE will return evaluation error instead #' @param is_err TRUE will return evaluation error instead
#' @return vector of evaluation result #' @return vector of evaluation result
#'
#' @rdname lgb.get.eval.result #' @rdname lgb.get.eval.result
#' @export #' @export
lgb.get.eval.result <- function(booster, data_name, eval_name, iters=NULL, is_err=FALSE){ lgb.get.eval.result <- function(booster, data_name, eval_name, iters = NULL, is_err = FALSE) {
if(!lgb.is.Booster(booster)){ if (!lgb.is.Booster(booster)) {
stop("lgb.get.eval.result: only can use booster to get eval result") stop("lgb.get.eval.result: Can only use ", sQuote("lgb.Booster"), " to get eval result")
} }
if(!is.character(data_name) | !is.character(eval_name)){ if (!is.character(data_name) || !is.character(eval_name)) {
stop("lgb.get.eval.result: data_name and eval_name should be character") stop("lgb.get.eval.result: data_name and eval_name should be characters")
} }
if(is.null(booster$record_evals[[data_name]])){ if (is.null(booster$record_evals[[data_name]])) {
stop("lgb.get.eval.result: wrong data name") stop("lgb.get.eval.result: wrong data name")
} }
if(is.null(booster$record_evals[[data_name]][[eval_name]])){ if (is.null(booster$record_evals[[data_name]][[eval_name]])) {
stop("lgb.get.eval.result: wrong eval name") stop("lgb.get.eval.result: wrong eval name")
} }
result <- booster$record_evals[[data_name]][[eval_name]]$eval result <- booster$record_evals[[data_name]][[eval_name]]$eval
if(is_err){ if (is_err) {
result <- booster$record_evals[[data_name]][[eval_name]]$eval_err result <- booster$record_evals[[data_name]][[eval_name]]$eval_err
} }
if(is.null(iters)){ if (is.null(iters)) {
return(as.numeric(result)) return(as.numeric(result))
} }
iters <- as.integer(iters) iters <- as.integer(iters)
delta <- booster$record_evals$start_iter - 1 delta <- booster$record_evals$start_iter - 1
iters <- iters - delta iters <- iters - delta
return(as.numeric(result[iters])) as.numeric(result[iters])
} }
This diff is collapsed.
Predictor <- R6Class( Predictor <- R6Class(
"lgb.Predictor", "lgb.Predictor",
cloneable=FALSE, cloneable = FALSE,
public = list( public = list(
finalize = function() { finalize = function() {
if(private$need_free_handle & !lgb.is.null.handle(private$handle)){ if (private$need_free_handle && !lgb.is.null.handle(private$handle)) {
print("free booster handle") cat("free booster handle\n")
lgb.call("LGBM_BoosterFree_R", ret=NULL, private$handle) lgb.call("LGBM_BoosterFree_R", ret = NULL, private$handle)
private$handle <- NULL private$handle <- NULL
} }
}, },
initialize = function(modelfile) { initialize = function(modelfile) {
handle <- lgb.new.handle() handle <- lgb.new.handle()
if(typeof(modelfile) == "character") { if (is.character(modelfile)) {
handle <- lgb.call("LGBM_BoosterCreateFromModelfile_R", ret=handle, lgb.c_str(modelfile)) handle <- lgb.call("LGBM_BoosterCreateFromModelfile_R", ret = handle, lgb.c_str(modelfile))
private$need_free_handle = TRUE private$need_free_handle <- TRUE
} else if (class(modelfile) == "lgb.Booster.handle") { } else if (is(modelfile, "lgb.Booster.handle")) {
handle <- modelfile handle <- modelfile
private$need_free_handle = FALSE private$need_free_handle <- FALSE
} else { } else {
stop("lgb.Predictor: modelfile must be either character filename, or lgb.Booster.handle") stop("lgb.Predictor: modelfile must be either a character filename or an lgb.Booster.handle")
} }
class(handle) <- "lgb.Booster.handle" class(handle) <- "lgb.Booster.handle"
private$handle <- handle private$handle <- handle
}, },
current_iter = function() { current_iter = function() {
cur_iter <- as.integer(0) cur_iter <- 0L
return(lgb.call("LGBM_BoosterGetCurrentIteration_R", ret=cur_iter, private$handle)) lgb.call("LGBM_BoosterGetCurrentIteration_R", ret = cur_iter, private$handle)
}, },
predict = function(data, predict = function(data, num_iteration = NULL, rawscore = FALSE,
num_iteration = NULL, rawscore = FALSE, predleaf = FALSE, header = FALSE, predleaf = FALSE, header = FALSE, reshape = FALSE) {
reshape = FALSE) {
if (is.null(num_iteration)) {
num_iteration <- -1
}
if (is.null(num_iteration)) { num_iteration <- -1 }
num_row <- 0 num_row <- 0
if (typeof(data) == "character") { if (is.character(data)) {
tmp_filename <- tempfile(pattern = "lightgbm_") tmp_filename <- tempfile(pattern = "lightgbm_")
lgb.call("LGBM_BoosterPredictForFile_R", ret=NULL, private$handle, data, as.integer(header), on.exit(unlink(tmp_filename), add = TRUE)
lgb.call("LGBM_BoosterPredictForFile_R", ret = NULL, private$handle, data,
as.integer(header),
as.integer(rawscore), as.integer(rawscore),
as.integer(predleaf), as.integer(predleaf),
as.integer(num_iteration), as.integer(num_iteration),
lgb.c_str(tmp_filename)) lgb.c_str(tmp_filename))
preds <- read.delim(tmp_filename, header=FALSE, seq="\t") preds <- read.delim(tmp_filename, header = FALSE, seq = "\t")
num_row <- nrow(preds) num_row <- nrow(preds)
preds <- as.vector(t(preds)) preds <- as.vector(t(preds))
# delete temp file
if(file.exists(tmp_filename)) { file.remove(tmp_filename) }
} else { } else {
num_row <- nrow(data) num_row <- nrow(data)
npred <- as.integer(0) npred <- as.integer(0)
npred <- lgb.call("LGBM_BoosterCalcNumPredict_R", ret=npred, npred <- lgb.call("LGBM_BoosterCalcNumPredict_R", ret = npred,
private$handle, private$handle,
as.integer(num_row), as.integer(num_row),
as.integer(rawscore), as.integer(rawscore),
...@@ -61,7 +56,7 @@ Predictor <- R6Class( ...@@ -61,7 +56,7 @@ Predictor <- R6Class(
# allocte space for prediction # allocte space for prediction
preds <- rep(0.0, npred) preds <- rep(0.0, npred)
if (is.matrix(data)) { if (is.matrix(data)) {
preds <- lgb.call("LGBM_BoosterPredictForMat_R", ret=preds, preds <- lgb.call("LGBM_BoosterPredictForMat_R", ret = preds,
private$handle, private$handle,
data, data,
as.integer(nrow(data)), as.integer(nrow(data)),
...@@ -69,8 +64,8 @@ Predictor <- R6Class( ...@@ -69,8 +64,8 @@ Predictor <- R6Class(
as.integer(rawscore), as.integer(rawscore),
as.integer(predleaf), as.integer(predleaf),
as.integer(num_iteration)) as.integer(num_iteration))
} else if (class(data) == "dgCMatrix") { } else if (is(data, "dgCMatrix")) {
preds <- lgb.call("LGBM_BoosterPredictForCSC_R", ret=preds, preds <- lgb.call("LGBM_BoosterPredictForCSC_R", ret = preds,
private$handle, private$handle,
data@p, data@p,
data@i, data@i,
...@@ -82,23 +77,17 @@ Predictor <- R6Class( ...@@ -82,23 +77,17 @@ Predictor <- R6Class(
as.integer(predleaf), as.integer(predleaf),
as.integer(num_iteration)) as.integer(num_iteration))
} else { } else {
stop(paste("predict: does not support to predict from ", stop("predict: cannot predict on data of class ", sQuote(class(data)))
typeof(data)))
} }
} }
if (length(preds) %% num_row != 0) { if (length(preds) %% num_row != 0) {
stop("predict: prediction length ", length(preds)," is not multiple of nrows(data) ", num_row) stop("predict: prediction length ", sQuote(length(preds))," is not a multiple of nrows(data): ", sQuote(num_row))
} }
npred_per_case <- length(preds) / num_row npred_per_case <- length(preds) / num_row
if (reshape && npred_per_case > 1) { if (reshape && npred_per_case > 1) { preds <- matrix(preds, ncol = npred_per_case) }
preds <- matrix(preds, ncol = npred_per_case) preds
}
return(preds)
} }
), ),
private = list( private = list( handle = NULL, need_free_handle = FALSE )
handle = NULL,
need_free_handle = FALSE
)
) )
CVBooster <- R6Class( CVBooster <- R6Class(
"lgb.CVBooster", "lgb.CVBooster",
cloneable=FALSE, cloneable = FALSE,
public = list( public = list(
best_iter = -1, best_iter = -1,
record_evals = list(), record_evals = list(),
boosters=list(), boosters = list(),
initialize=function(x){ initialize = function(x) {
self$boosters <- x self$boosters <- x
}, },
reset_parameter=function(new_paramas){ reset_parameter = function(new_params) {
for(x in boosters){ for (x in boosters) { x$reset_parameter(new_params) }
x$reset_parameter(new_paramas) self
}
return(self)
} }
) )
) )
...@@ -55,25 +53,26 @@ CVBooster <- R6Class( ...@@ -55,25 +53,26 @@ CVBooster <- R6Class(
#' @param ... other parameters, see parameters.md for more informations #' @param ... other parameters, see parameters.md for more informations
#' @return a trained booster model \code{lgb.Booster}. #' @return a trained booster model \code{lgb.Booster}.
#' @examples #' @examples
#' \dontrun{
#' library(lightgbm) #' library(lightgbm)
#' data(agaricus.train, package='lightgbm') #' data(agaricus.train, package='lightgbm')
#' train <- agaricus.train #' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label=train$label) #' dtrain <- lgb.Dataset(train$data, label=train$label)
#' params <- list(objective="regression", metric="l2") #' params <- list(objective="regression", metric="l2")
#' model <- lgb.cv(params, dtrain, 10, nfold=5, min_data=1, learning_rate=1, early_stopping_rounds=10) #' model <- lgb.cv(params, dtrain, 10, nfold=5, min_data=1, learning_rate=1, early_stopping_rounds=10)
#' #' }
#' @rdname lgb.train #' @rdname lgb.train
#' @export #' @export
lgb.cv <- function(params=list(), data, nrounds=10, nfold=3, lgb.cv <- function(params=list(), data, nrounds = 10, nfold = 3,
label = NULL, weight = NULL, label = NULL, weight = NULL,
obj=NULL, eval=NULL, obj = NULL, eval = NULL,
verbose=1, eval_freq=1L, showsd = TRUE, verbose = 1, eval_freq = 1L, showsd = TRUE,
stratified = TRUE, folds = NULL, stratified = TRUE, folds = NULL,
init_model=NULL, init_model = NULL,
colnames=NULL, colnames= NULL,
categorical_feature=NULL, categorical_feature = NULL,
early_stopping_rounds=NULL, early_stopping_rounds = NULL,
callbacks=list(), ...) { callbacks = list(), ...) {
addiction_params <- list(...) addiction_params <- list(...)
params <- append(params, addiction_params) params <- append(params, addiction_params)
params$verbose <- verbose params$verbose <- verbose
...@@ -81,130 +80,113 @@ lgb.cv <- function(params=list(), data, nrounds=10, nfold=3, ...@@ -81,130 +80,113 @@ lgb.cv <- function(params=list(), data, nrounds=10, nfold=3,
params <- lgb.check.eval(params, eval) params <- lgb.check.eval(params, eval)
fobj <- NULL fobj <- NULL
feval <- NULL feval <- NULL
if(typeof(params$objective) == "closure"){ if (is.function(params$objective)) {
fobj <- params$objective fobj <- params$objective
params$objective <- "NONE" params$objective <- "NONE"
} }
if (typeof(eval) == "closure"){ if (is.function(eval)) { feval <- eval }
feval <- eval
}
lgb.check.params(params) lgb.check.params(params)
predictor <- NULL predictor <- NULL
if(is.character(init_model)){ if (is.character(init_model)) {
predictor <- Predictor$new(init_model) predictor <- Predictor$new(init_model)
} else if(lgb.is.Booster(init_model)) { } else if (lgb.is.Booster(init_model)) {
predictor <- init_model$to_predictor() predictor <- init_model$to_predictor()
} }
begin_iteration <- 1 begin_iteration <- 1
if(!is.null(predictor)){ if (!is.null(predictor)) {
begin_iteration <- predictor$current_iter() + 1 begin_iteration <- predictor$current_iter() + 1
} }
end_iteration <- begin_iteration + nrounds - 1 end_iteration <- begin_iteration + nrounds - 1
# check dataset if (!lgb.is.Dataset(data)) {
if(!lgb.is.Dataset(data)){ if (is.null(label)) { stop("Labels must be provided for lgb.cv") }
if(is.null(label)){ data <- lgb.Dataset(data, label = label)
stop("Labels must be provided for lgb.cv")
}
data <- lgb.Dataset(data, label=label)
} }
if(!is.null(weight)) data$set_info("weight", weight) if (!is.null(weight)) { data$set_info("weight", weight) }
data$update_params(params) data$update_params(params)
data$.__enclos_env__$private$set_predictor(predictor) data$.__enclos_env__$private$set_predictor(predictor)
if(!is.null(colnames)){ if (!is.null(colnames)) { data$set_colnames(colnames) }
data$set_colnames(colnames)
}
data$set_categorical_feature(categorical_feature) data$set_categorical_feature(categorical_feature)
data$construct() data$construct()
# CV folds if (!is.null(folds)) {
if(!is.null(folds)) { if (!is.list(folds) || length(folds) < 2)
if(class(folds) != "list" || length(folds) < 2) stop(sQuote("folds"), " must be a list with 2 or more elements that are vectors of indices for each CV-fold")
stop("'folds' must be a list with 2 or more elements that are vectors of indices for each CV-fold")
nfold <- length(folds) nfold <- length(folds)
} else { } else {
if (nfold <= 1) if (nfold <= 1) { stop(sQuote("nfold"), " must be > 1") }
stop("'nfold' must be > 1")
folds <- generate.cv.folds(nfold, nrow(data), stratified, getinfo(data, 'label'), params) folds <- generate.cv.folds(nfold, nrow(data), stratified, getinfo(data, 'label'), params)
} }
# process callbacks if (eval_freq > 0) {
if(eval_freq > 0){
callbacks <- add.cb(callbacks, cb.print.evaluation(eval_freq)) callbacks <- add.cb(callbacks, cb.print.evaluation(eval_freq))
} }
if (verbose > 0) { if (verbose > 0) { callbacks <- add.cb(callbacks, cb.record.evaluation()) }
callbacks <- add.cb(callbacks, cb.record.evaluation())
}
# Early stopping callback
if (!is.null(early_stopping_rounds)) { if (!is.null(early_stopping_rounds)) {
if(early_stopping_rounds > 0){ if (early_stopping_rounds > 0) {
callbacks <- add.cb(callbacks, cb.early.stop(early_stopping_rounds, verbose=verbose)) callbacks <- add.cb(callbacks, cb.early.stop(early_stopping_rounds, verbose = verbose))
} }
} }
cb <- categorize.callbacks(callbacks) cb <- categorize.callbacks(callbacks)
# construct booster # construct booster
bst_folds <- lapply(seq_along(folds), function(k) {
bst_folds <- lapply(1:length(folds), function(k) {
dtest <- slice(data, folds[[k]]) dtest <- slice(data, folds[[k]])
dtrain <- slice(data, unlist(folds[-k])) dtrain <- slice(data, unlist(folds[-k]))
booster <- Booster$new(params, dtrain) booster <- Booster$new(params, dtrain)
booster$add_valid(dtest, "valid") booster$add_valid(dtest, "valid")
list(booster=booster) list(booster = booster)
}) })
cv_booster <- CVBooster$new(bst_folds) cv_booster <- CVBooster$new(bst_folds)
# callback env # callback env
env <- CB_ENV$new() env <- CB_ENV$new()
env$model <- cv_booster env$model <- cv_booster
env$begin_iteration <- begin_iteration env$begin_iteration <- begin_iteration
env$end_iteration <- end_iteration env$end_iteration <- end_iteration
#start training #start training
for(i in begin_iteration:end_iteration){ for (i in seq(from = begin_iteration, to = end_iteration)) {
env$iteration <- i env$iteration <- i
env$eval_list <- list() env$eval_list <- list()
for (f in cb$pre_iter) f(env) for (f in cb$pre_iter) { f(env) }
# update one iter # update one iter
msg <- lapply(cv_booster$boosters, function(fd) { msg <- lapply(cv_booster$boosters, function(fd) {
fd$booster$update(fobj=fobj) fd$booster$update(fobj = fobj)
fd$booster$eval_valid(feval=feval) fd$booster$eval_valid(feval = feval)
}) })
merged_msg <- lgb.merge.cv.result(msg) merged_msg <- lgb.merge.cv.result(msg)
env$eval_list <- merged_msg$eval_list env$eval_list <- merged_msg$eval_list
if(showsd) env$eval_err_list <- merged_msg$eval_err_list if(showsd) { env$eval_err_list <- merged_msg$eval_err_list }
for (f in cb$post_iter) { f(env) }
for (f in cb$post_iter) f(env)
# met early stopping # met early stopping
if(env$met_early_stop) break if (env$met_early_stop) break
} }
return(cv_booster) cv_booster
} }
# Generates random (stratified if needed) CV folds # Generates random (stratified if needed) CV folds
generate.cv.folds <- function(nfold, nrows, stratified, label, params) { generate.cv.folds <- function(nfold, nrows, stratified, label, params) {
# cannot do it for rank # cannot do it for rank
if (exists('objective', where=params) && if (exists('objective', where = params) &&
is.character(params$objective) && is.character(params$objective) &&
params$objective == 'lambdarank') { params$objective == 'lambdarank') {
stop("\n\tAutomatic generation of CV-folds is not implemented for lambdarank!\n", stop("\n\tAutomatic generation of CV-folds is not implemented for lambdarank!\n",
"\tConsider providing pre-computed CV-folds through the 'folds=' parameter.\n") "\tConsider providing pre-computed CV-folds through the 'folds=' parameter.\n")
} }
# shuffle # shuffle
rnd_idx <- sample(1:nrows) rnd_idx <- sample(seq_len(nrows))
if (stratified && if (isTRUE(stratified) &&
length(label) == length(rnd_idx)) { length(label) == length(rnd_idx)) {
y <- label[rnd_idx] y <- label[rnd_idx]
y <- factor(y) y <- factor(y)
...@@ -213,20 +195,19 @@ generate.cv.folds <- function(nfold, nrows, stratified, label, params) { ...@@ -213,20 +195,19 @@ generate.cv.folds <- function(nfold, nrows, stratified, label, params) {
# make simple non-stratified folds # make simple non-stratified folds
kstep <- length(rnd_idx) %/% nfold kstep <- length(rnd_idx) %/% nfold
folds <- list() folds <- list()
for (i in 1:(nfold - 1)) { for (i in seq_len(nfold - 1)) {
folds[[i]] <- rnd_idx[1:kstep] folds[[i]] <- rnd_idx[seq_len(kstep)]
rnd_idx <- rnd_idx[-(1:kstep)] rnd_idx <- rnd_idx[-(seq_len(kstep))]
} }
folds[[nfold]] <- rnd_idx folds[[nfold]] <- rnd_idx
} }
return(folds) folds
} }
# Creates CV folds stratified by the values of y. # Creates CV folds stratified by the values of y.
# It was borrowed from caret::lgb.stratified.folds and simplified # It was borrowed from caret::lgb.stratified.folds and simplified
# by always returning an unnamed list of fold indices. # by always returning an unnamed list of fold indices.
lgb.stratified.folds <- function(y, k = 10) lgb.stratified.folds <- function(y, k = 10) {
{
if (is.numeric(y)) { if (is.numeric(y)) {
## Group the numeric data based on their magnitudes ## Group the numeric data based on their magnitudes
## and sample within those groups. ## and sample within those groups.
...@@ -236,11 +217,10 @@ lgb.stratified.folds <- function(y, k = 10) ...@@ -236,11 +217,10 @@ lgb.stratified.folds <- function(y, k = 10)
## groups. The number of groups will depend on the ## groups. The number of groups will depend on the
## ratio of the number of folds to the sample size. ## ratio of the number of folds to the sample size.
## At most, we will use quantiles. If the sample ## At most, we will use quantiles. If the sample
## is too small, we just do regular unstratified ## is too small, we just do regular unstratified CV
## CV
cuts <- floor(length(y) / k) cuts <- floor(length(y) / k)
if (cuts < 2) cuts <- 2 if (cuts < 2) { cuts <- 2 }
if (cuts > 5) cuts <- 5 if (cuts > 5) { cuts <- 5 }
y <- cut(y, y <- cut(y,
unique(stats::quantile(y, probs = seq(0, 1, length = cuts))), unique(stats::quantile(y, probs = seq(0, 1, length = cuts))),
include.lowest = TRUE) include.lowest = TRUE)
...@@ -256,48 +236,41 @@ lgb.stratified.folds <- function(y, k = 10) ...@@ -256,48 +236,41 @@ lgb.stratified.folds <- function(y, k = 10)
## For each class, balance the fold allocation as far ## For each class, balance the fold allocation as far
## as possible, then resample the remainder. ## as possible, then resample the remainder.
## The final assignment of folds is also randomized. ## The final assignment of folds is also randomized.
for (i in 1:length(numInClass)) { for (i in seq_along(numInClass)) {
## create a vector of integers from 1:k as many times as possible without ## create a vector of integers from 1:k as many times as possible without
## going over the number of samples in the class. Note that if the number ## going over the number of samples in the class. Note that if the number
## of samples in a class is less than k, nothing is producd here. ## of samples in a class is less than k, nothing is producd here.
seqVector <- rep(1:k, numInClass[i] %/% k) seqVector <- rep(seq_len(k), numInClass[i] %/% k)
## add enough random integers to get length(seqVector) == numInClass[i] ## add enough random integers to get length(seqVector) == numInClass[i]
if (numInClass[i] %% k > 0) seqVector <- c(seqVector, sample(1:k, numInClass[i] %% k)) if (numInClass[i] %% k > 0) {
seqVector <- c(seqVector, sample(seq_len(k), numInClass[i] %% k))
}
## shuffle the integers for fold assignment and assign to this classes's data ## shuffle the integers for fold assignment and assign to this classes's data
foldVector[which(y == dimnames(numInClass)$y[i])] <- sample(seqVector) foldVector[y == dimnames(numInClass)$y[i]] <- sample(seqVector)
} }
} else { } else {
foldVector <- seq(along = y) foldVector <- seq(along = y)
} }
out <- split(seq(along = y), foldVector) out <- split(seq(along = y), foldVector)
names(out) <- NULL `names<-`(out, NULL)
out
} }
lgb.merge.cv.result <- function(msg, showsd=TRUE){ lgb.merge.cv.result <- function(msg, showsd = TRUE){
if(length(msg) == 0){ if (length(msg) == 0) { stop("lgb.cv: size of cv result error") }
stop("lgb.cv: size of cv result error")
}
eval_len <- length(msg[[1]]) eval_len <- length(msg[[1]])
if(eval_len == 0){ if (eval_len == 0) { stop("lgb.cv: should provide at least one metric for CV") }
stop("lgb.cv: should provide at least metric for CV") eval_result <- lapply(seq_len(eval_len), function(j) {
} as.numeric(lapply(seq_along(msg), function(i) { msg[[i]][[j]]$value }))
eval_result <- lapply(1:eval_len, function(j) {
as.numeric(lapply(1:length(msg), function(i){
msg[[i]][[j]]$value
}))
}) })
ret_eval <- msg[[1]] ret_eval <- msg[[1]]
for(j in 1:eval_len){ for (j in seq_len(eval_len)) { ret_eval[[j]]$value <- mean(eval_result[[j]]) }
ret_eval[[j]]$value <- mean(eval_result[[j]])
}
ret_eval_err <- NULL ret_eval_err <- NULL
if(showsd){ if (showsd) {
for(j in 1:eval_len){ for (j in seq_len(eval_len)) {
ret_eval_err <- c( ret_eval_err, sqrt( mean(eval_result[[j]]^2) - mean(eval_result[[j]])^2 )) ret_eval_err <- c( ret_eval_err, sqrt( mean(eval_result[[j]]^2) - mean(eval_result[[j]])^2 ))
} }
ret_eval_err <- as.list(ret_eval_err) ret_eval_err <- as.list(ret_eval_err)
} }
return(list(eval_list=ret_eval, eval_err_list=ret_eval_err)) list(eval_list = ret_eval, eval_err_list = ret_eval_err)
} }
#' Main training logic for LightGBM #' Main training logic for LightGBM
#' #'
#' Main training logic for LightGBM
#'
#' @param params List of parameters #' @param params List of parameters
#' @param data a \code{lgb.Dataset} object, used for training #' @param data a \code{lgb.Dataset} object, used for training
#' @param nrounds number of training rounds #' @param nrounds number of training rounds
#' @param valids a list of \code{lgb.Dataset} object, used for validation #' @param valids a list of \code{lgb.Dataset} objects, used for validation
#' @param obj objective function, can be character or custom objective function #' @param obj objective function, can be character or custom objective function
#' @param eval evaluation function, can be (list of) character or custom eval function #' @param eval evaluation function, can be (a list of) character or custom eval function
#' @param verbose verbosity for output #' @param verbose verbosity for output
#' if verbose > 0 , also will record iteration message to booster$record_evals #' if \code{verbose > 0}, also will record iteration message to \code{booster$record_evals}
#' @param eval_freq evalutaion output frequence #' @param eval_freq evalutaion output frequency
#' @param init_model path of model file of \code{lgb.Booster} object, will continue train from this model #' @param init_model path of model file of \code{lgb.Booster} object, will continue training from this model
#' @param colnames feature names, if not null, will use this to overwrite the names in dataset #' @param colnames feature names, if not null, will use this to overwrite the names in dataset
#' @param categorical_feature list of str or int #' @param categorical_feature list of str or int
#' type int represents index, #' type int represents index,
...@@ -27,6 +25,7 @@ ...@@ -27,6 +25,7 @@
#' @param ... other parameters, see parameters.md for more informations #' @param ... other parameters, see parameters.md for more informations
#' @return a trained booster model \code{lgb.Booster}. #' @return a trained booster model \code{lgb.Booster}.
#' @examples #' @examples
#' \dontrun{
#' library(lightgbm) #' library(lightgbm)
#' data(agaricus.train, package='lightgbm') #' data(agaricus.train, package='lightgbm')
#' train <- agaricus.train #' train <- agaricus.train
...@@ -37,72 +36,71 @@ ...@@ -37,72 +36,71 @@
#' params <- list(objective="regression", metric="l2") #' params <- list(objective="regression", metric="l2")
#' valids <- list(test=dtest) #' valids <- list(test=dtest)
#' model <- lgb.train(params, dtrain, 100, valids, min_data=1, learning_rate=1, early_stopping_rounds=10) #' model <- lgb.train(params, dtrain, 100, valids, min_data=1, learning_rate=1, early_stopping_rounds=10)
#' #' }
#' @rdname lgb.train #' @rdname lgb.train
#' @export #' @export
lgb.train <- function(params=list(), data, nrounds=10, lgb.train <- function(params = list(), data, nrounds = 10,
valids=list(), valids = list(),
obj=NULL, eval=NULL, obj = NULL,
verbose=1, eval_freq=1L, eval = NULL,
init_model=NULL, verbose = 1,
colnames=NULL, eval_freq = 1L,
categorical_feature=NULL, init_model = NULL,
early_stopping_rounds=NULL, colnames = NULL,
callbacks=list(), ...) { categorical_feature = NULL,
addiction_params <- list(...) early_stopping_rounds = NULL,
params <- append(params, addiction_params) callbacks = list(), ...) {
additional_params <- list(...)
params <- append(params, additional_params)
params$verbose <- verbose params$verbose <- verbose
params <- lgb.check.obj(params, obj) params <- lgb.check.obj(params, obj)
params <- lgb.check.eval(params, eval) params <- lgb.check.eval(params, eval)
fobj <- NULL fobj <- NULL
feval <- NULL feval <- NULL
if(typeof(params$objective) == "closure"){ if (is.function(params$objective)) {
fobj <- params$objective fobj <- params$objective
params$objective <- "NONE" params$objective <- "NONE"
} }
if (typeof(eval) == "closure"){ if (is.function(eval)) { feval <- eval }
feval <- eval
}
lgb.check.params(params) lgb.check.params(params)
predictor <- NULL predictor <- NULL
if(is.character(init_model)){ if (is.character(init_model)) {
predictor <- Predictor$new(init_model) predictor <- Predictor$new(init_model)
} else if(lgb.is.Booster(init_model)) { } else if (lgb.is.Booster(init_model)) {
predictor <- init_model$to_predictor() predictor <- init_model$to_predictor()
} }
begin_iteration <- 1 begin_iteration <- 1
if(!is.null(predictor)){ if (!is.null(predictor)) {
begin_iteration <- predictor$current_iter() + 1 begin_iteration <- predictor$current_iter() + 1
} }
end_iteration <- begin_iteration + nrounds - 1 end_iteration <- begin_iteration + nrounds - 1
# check dataset # check dataset
if(!lgb.is.Dataset(data)){ if (!lgb.is.Dataset(data)) {
stop("lgb.train: data only accepts lgb.Dataset object") stop("lgb.train: data only accepts lgb.Dataset object")
} }
if (length(valids) > 0) { if (length(valids) > 0) {
if (typeof(valids) != "list" || if (!is.list(valids) || !all(sapply(valids, lgb.is.Dataset))) {
!all(sapply(valids, lgb.is.Dataset))) stop("lgb.train: valids must be a list of lgb.Dataset elements")
stop("valids must be a list of lgb.Dataset elements") }
evnames <- names(valids) evnames <- names(valids)
if (is.null(evnames) || any(evnames == "")) if (is.null(evnames) || !all(nzchar(evnames))) {
stop("each element of the valids must have a name tag") stop("lgb.train: each element of the valids must have a name tag")
}
} }
data$update_params(params) data$update_params(params)
data$.__enclos_env__$private$set_predictor(predictor) data$.__enclos_env__$private$set_predictor(predictor)
if(!is.null(colnames)){ if (!is.null(colnames)) { data$set_colnames(colnames) }
data$set_colnames(colnames)
}
data$set_categorical_feature(categorical_feature) data$set_categorical_feature(categorical_feature)
data$construct() data$construct()
vaild_contain_train <- FALSE vaild_contain_train <- FALSE
train_data_name <- "train" train_data_name <- "train"
reduced_valid_sets <- list() reduced_valid_sets <- list()
if(length(valids) > 0){ if (length(valids) > 0) {
for (key in names(valids)) { for (key in names(valids)) {
valid_data <- valids[[key]] valid_data <- valids[[key]]
if(identical(data, valid_data)){ if (identical(data, valid_data)) {
vaild_contain_train <- TRUE vaild_contain_train <- TRUE
train_data_name <- key train_data_name <- key
next next
...@@ -113,7 +111,7 @@ lgb.train <- function(params=list(), data, nrounds=10, ...@@ -113,7 +111,7 @@ lgb.train <- function(params=list(), data, nrounds=10,
} }
} }
# process callbacks # process callbacks
if(eval_freq > 0){ if (eval_freq > 0) {
callbacks <- add.cb(callbacks, cb.print.evaluation(eval_freq)) callbacks <- add.cb(callbacks, cb.print.evaluation(eval_freq))
} }
...@@ -123,55 +121,47 @@ lgb.train <- function(params=list(), data, nrounds=10, ...@@ -123,55 +121,47 @@ lgb.train <- function(params=list(), data, nrounds=10,
# Early stopping callback # Early stopping callback
if (!is.null(early_stopping_rounds)) { if (!is.null(early_stopping_rounds)) {
if(early_stopping_rounds > 0){ if (early_stopping_rounds > 0) {
callbacks <- add.cb(callbacks, cb.early.stop(early_stopping_rounds, verbose=verbose)) callbacks <- add.cb(callbacks, cb.early.stop(early_stopping_rounds, verbose = verbose))
} }
} }
cb <- categorize.callbacks(callbacks) cb <- categorize.callbacks(callbacks)
# construct booster # construct booster
booster <- Booster$new(params=params, train_set=data) booster <- Booster$new(params = params, train_set = data)
if(vaild_contain_train){ if (vaild_contain_train) { booster$set_train_data_name(train_data_name) }
booster$set_train_data_name(train_data_name)
}
for (key in names(reduced_valid_sets)) { for (key in names(reduced_valid_sets)) {
booster$add_valid(reduced_valid_sets[[key]], key) booster$add_valid(reduced_valid_sets[[key]], key)
} }
# callback env # callback env
env <- CB_ENV$new() env <- CB_ENV$new()
env$model <- booster env$model <- booster
env$begin_iteration <- begin_iteration env$begin_iteration <- begin_iteration
env$end_iteration <- end_iteration env$end_iteration <- end_iteration
#start training #start training
for(i in begin_iteration:end_iteration){ for (i in seq(from = begin_iteration, to = end_iteration)) {
env$iteration <- i env$iteration <- i
env$eval_list <- list() env$eval_list <- list()
for (f in cb$pre_iter) f(env) for (f in cb$pre_iter) { f(env) }
# update one iter # update one iter
booster$update(fobj=fobj) booster$update(fobj = fobj)
# collect eval result # collect eval result
eval_list <- list() eval_list <- list()
if(length(valids) > 0){ if (length(valids) > 0) {
if(vaild_contain_train){ if (vaild_contain_train) {
eval_list <- append(eval_list, booster$eval_train(feval=feval)) eval_list <- append(eval_list, booster$eval_train(feval = feval))
} }
eval_list <- append(eval_list, booster$eval_valid(feval=feval)) eval_list <- append(eval_list, booster$eval_valid(feval = feval))
} }
env$eval_list <- eval_list env$eval_list <- eval_list
for (f in cb$post_iter) { f(env) }
for (f in cb$post_iter) f(env)
# met early stopping # met early stopping
if(env$met_early_stop) break if (env$met_early_stop) break
} }
return(booster) booster
} }
# Simple interface for training an lightgbm model. #' Simple interface for training an lightgbm model.
# Its documentation is combined with lgb.train. #' Its documentation is combined with lgb.train.
# #'
#' @rdname lgb.train #' @rdname lgb.train
#' @export #' @export
lightgbm <- function(data, label = NULL, weight = NULL, lightgbm <- function(data, label = NULL, weight = NULL,
params = list(), nrounds=10, params = list(), nrounds = 10,
verbose = 1, eval_freq = 1L, verbose = 1, eval_freq = 1L,
early_stopping_rounds = NULL, early_stopping_rounds = NULL,
save_name = "lightgbm.model", save_name = "lightgbm.model",
init_model = NULL, callbacks = list(), ...) { init_model = NULL, callbacks = list(), ...) {
dtrain <- data dtrain <- data
if(!lgb.is.Dataset(dtrain)) { if (!lgb.is.Dataset(dtrain)) {
dtrain <- lgb.Dataset(data, label=label, weight=weight) dtrain <- lgb.Dataset(data, label = label, weight = weight)
} }
valids <- list() valids <- list()
if (verbose > 0) if (verbose > 0) { valids$train = dtrain }
valids$train = dtrain
bst <- lgb.train(params, dtrain, nrounds, valids, verbose = verbose, eval_freq=eval_freq, bst <- lgb.train(params, dtrain, nrounds, valids, verbose = verbose, eval_freq = eval_freq,
early_stopping_rounds = early_stopping_rounds, early_stopping_rounds = early_stopping_rounds,
init_model = init_model, callbacks = callbacks, ...) init_model = init_model, callbacks = callbacks, ...)
bst$save_model(save_name) bst$save_model(save_name)
return(bst) bst
} }
#' Training part from Mushroom Data Set #' Training part from Mushroom Data Set
......
lgb.new.handle <- function() { lgb.is.Booster <- function(x) { lgb.check.r6.class(x, "lgb.Booster") }
# use 64bit data to store address
return(0.0) lgb.is.Dataset <- function(x) { lgb.check.r6.class(x, "lgb.Dataset") }
}
lgb.is.null.handle <- function(x) { # use 64bit data to store address
if (is.null(x)) { lgb.new.handle <- function() { 0.0 }
return(TRUE)
} lgb.is.null.handle <- function(x) { is.null(x) || x == 0 }
if (x == 0) {
return(TRUE)
}
return(FALSE)
}
lgb.encode.char <- function(arr, len) { lgb.encode.char <- function(arr, len) {
if (typeof(arr) != "raw") { if (!is.raw(arr)) {
stop("lgb.encode.char: only can encode from raw type") stop("lgb.encode.char: Can only encode from raw type")
} }
return(rawToChar(arr[1:len])) rawToChar(arr[seq_len(len)])
} }
lgb.call <- function(fun_name, ret, ...) { lgb.call <- function(fun_name, ret, ...) {
call_state <- as.integer(0) call_state <- 0L
if (!is.null(ret)) { if (!is.null(ret)) {
call_state <- call_state <- .Call(fun_name, ..., ret, call_state, PACKAGE = "lightgbm")
.Call(fun_name, ..., ret, call_state , PACKAGE = "lightgbm")
} else { } else {
call_state <- .Call(fun_name, ..., call_state , PACKAGE = "lightgbm") call_state <- .Call(fun_name, ..., call_state, PACKAGE = "lightgbm")
} }
if (call_state != as.integer(0)) { if (call_state != 0L) {
buf_len <- as.integer(200) buf_len <- 200L
act_len <- as.integer(0) act_len <- 0L
err_msg <- raw(buf_len) err_msg <- raw(buf_len)
err_msg <- err_msg <- .Call("LGBM_GetLastError_R", buf_len, act_len, err_msg, PACKAGE = "lightgbm")
.Call("LGBM_GetLastError_R", buf_len, act_len, err_msg, PACKAGE = "lightgbm")
if (act_len > buf_len) { if (act_len > buf_len) {
buf_len <- act_len buf_len <- act_len
err_msg <- raw(buf_len) err_msg <- raw(buf_len)
err_msg <- err_msg <- .Call("LGBM_GetLastError_R",
.Call("LGBM_GetLastError_R",
buf_len, buf_len,
act_len, act_len,
err_msg, err_msg,
...@@ -45,13 +37,13 @@ lgb.call <- function(fun_name, ret, ...) { ...@@ -45,13 +37,13 @@ lgb.call <- function(fun_name, ret, ...) {
} }
stop(paste0("api error: ", lgb.encode.char(err_msg, act_len))) stop(paste0("api error: ", lgb.encode.char(err_msg, act_len)))
} }
return(ret) ret
} }
lgb.call.return.str <- function(fun_name, ...) { lgb.call.return.str <- function(fun_name, ...) {
buf_len <- as.integer(1024 * 1024) buf_len <- as.integer(1024 * 1024)
act_len <- as.integer(0) act_len <- 0L
buf <- raw(buf_len) buf <- raw(buf_len)
buf <- lgb.call(fun_name, ret = buf, ..., buf_len, act_len) buf <- lgb.call(fun_name, ret = buf, ..., buf_len, act_len)
if (act_len > buf_len) { if (act_len > buf_len) {
...@@ -59,12 +51,11 @@ lgb.call.return.str <- function(fun_name, ...) { ...@@ -59,12 +51,11 @@ lgb.call.return.str <- function(fun_name, ...) {
buf <- raw(buf_len) buf <- raw(buf_len)
buf <- lgb.call(fun_name, ret = buf, ..., buf_len, act_len) buf <- lgb.call(fun_name, ret = buf, ..., buf_len, act_len)
} }
return(lgb.encode.char(buf, act_len)) lgb.encode.char(buf, act_len)
} }
lgb.params2str <- function(params, ...) { lgb.params2str <- function(params, ...) {
if (typeof(params) != "list") if (!is.list(params)) { stop("params must be a list") }
stop("params must be a list")
names(params) <- gsub("\\.", "_", names(params)) names(params) <- gsub("\\.", "_", names(params))
# merge parameters from the params and the dots-expansion # merge parameters from the params and the dots-expansion
dot_params <- list(...) dot_params <- list(...)
...@@ -72,29 +63,29 @@ lgb.params2str <- function(params, ...) { ...@@ -72,29 +63,29 @@ lgb.params2str <- function(params, ...) {
if (length(intersect(names(params), if (length(intersect(names(params),
names(dot_params))) > 0) names(dot_params))) > 0)
stop( stop(
"Same parameters in 'params' and in the call are not allowed. Please check your 'params' list." "Same parameters in ", sQuote("params"), " and in the call are not allowed. Please check your ", sQuote("params"), " list"
) )
params <- c(params, dot_params) params <- c(params, dot_params)
ret <- list() ret <- list()
for (key in names(params)) { for (key in names(params)) {
# join multi value first # join multi value first
val <- paste0(params[[key]], collapse = ",") val <- paste0(params[[key]], collapse = ",")
if(nchar(val) <= 0) next if (nchar(val) <= 0) next
# join key value # join key value
pair <- paste0(c(key, val), collapse = "=") pair <- paste0(c(key, val), collapse = "=")
ret <- c(ret, pair) ret <- c(ret, pair)
} }
if (length(ret) == 0) { if (length(ret) == 0) {
return(lgb.c_str("")) lgb.c_str("")
} else{ } else {
return(lgb.c_str(paste0(ret, collapse = " "))) lgb.c_str(paste0(ret, collapse = " "))
} }
} }
lgb.c_str <- function(x) { lgb.c_str <- function(x) {
ret <- charToRaw(as.character(x)) ret <- charToRaw(as.character(x))
ret <- c(ret, as.raw(0)) ret <- c(ret, as.raw(0))
return(ret) ret
} }
lgb.check.r6.class <- function(object, name) { lgb.check.r6.class <- function(object, name) {
...@@ -104,54 +95,47 @@ lgb.check.r6.class <- function(object, name) { ...@@ -104,54 +95,47 @@ lgb.check.r6.class <- function(object, name) {
if (!(name %in% class(object))) { if (!(name %in% class(object))) {
return(FALSE) return(FALSE)
} }
return(TRUE) TRUE
} }
lgb.check.params <- function(params){ lgb.check.params <- function(params) {
# To-do # To-do
return(params) params
} }
lgb.check.obj <- function(params, obj) { lgb.check.obj <- function(params, obj) {
if(!is.null(obj)){ OBJECTIVES <- c("regression", "binary", "multiclass", "lambdarank")
params$objective <- obj if (!is.null(obj)) { params$objective <- obj }
} if (is.character(params$objective)) {
if(is.character(params$objective)){ if (!(params$objective %in% OBJECTIVES)) {
if(!(params$objective %in% c("regression", "binary", "multiclass", "lambdarank"))){ stop("lgb.check.obj: objective name error should be one of (", paste0(OBJECTIVES, collapse = ", "), ")")
stop("lgb.check.obj: objective name error should be (regression, binary, multiclass, lambdarank)")
} }
} else if(typeof(params$objective) != "closure"){ } else if (!is.function(params$objective)) {
stop("lgb.check.obj: objective should be character or function") stop("lgb.check.obj: objective should be a character or a function")
} }
return(params) params
} }
lgb.check.eval <- function(params, eval) { lgb.check.eval <- function(params, eval) {
if(is.null(params$metric)){ if (is.null(params$metric)) { params$metric <- list() }
params$metric <- list() if (!is.null(eval)) {
}
if(!is.null(eval)){
# append metric # append metric
if(is.character(eval) || is.list(eval)){ if (is.character(eval) || is.list(eval)) {
params$metric <- append(params$metric, eval) params$metric <- append(params$metric, eval)
} }
} }
if (typeof(eval) != "closure"){ if (!is.function(eval)) {
if(is.null(params$metric) | length(params$metric) == 0) { if (length(params$metric) == 0) {
# add default metric # add default metric
if(is.character(params$objective)){ params$metric <- switch(
if(params$objective == "regression"){ params$objective,
params$metric <- "l2" regression = "l2",
} else if(params$objective == "binary"){ binary = "binary_logloss",
params$metric <- "binary_logloss" multiclass = "multi_logloss",
} else if(params$objective == "multiclass"){ lambdarank = "ndcg",
params$metric <- "multi_logloss" stop("lgb.check.eval: No default metric available for objective ", sQuote(params$objective))
} else if(params$objective == "lambdarank"){ )
params$metric <- "ndcg"
}
}
} }
} }
return(params) params
} }
...@@ -4,5 +4,3 @@ LightGBM R examples ...@@ -4,5 +4,3 @@ LightGBM R examples
* [Boosting from existing prediction](boost_from_prediction.R) * [Boosting from existing prediction](boost_from_prediction.R)
* [Early Stopping](early_stopping.R) * [Early Stopping](early_stopping.R)
* [Cross Validation](cross_validation.R) * [Cross Validation](cross_validation.R)
...@@ -90,5 +90,3 @@ label = getinfo(dtest, "label") ...@@ -90,5 +90,3 @@ label = getinfo(dtest, "label")
pred <- predict(bst, test$data) pred <- predict(bst, test$data)
err <- as.numeric(sum(as.integer(pred > 0.5) != label))/length(label) err <- as.numeric(sum(as.integer(pred > 0.5) != label))/length(label)
print(paste("test-error=", err)) print(paste("test-error=", err))
...@@ -42,4 +42,3 @@ evalerror <- function(preds, dtrain) { ...@@ -42,4 +42,3 @@ evalerror <- function(preds, dtrain) {
# train with customized objective # train with customized objective
lgb.cv(params = param, data = dtrain, nrounds = nround, obj=logregobj, eval=evalerror, nfold = 5) lgb.cv(params = param, data = dtrain, nrounds = nround, obj=logregobj, eval=evalerror, nfold = 5)
...@@ -36,4 +36,3 @@ print ('start training with early Stopping setting') ...@@ -36,4 +36,3 @@ print ('start training with early Stopping setting')
bst <- lgb.train(param, dtrain, num_round, valids, bst <- lgb.train(param, dtrain, num_round, valids,
objective = logregobj, eval = evalerror, objective = logregobj, eval = evalerror,
early_stopping_round = 3) early_stopping_round = 3)
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
% Please edit documentation in R/lgb.Dataset.R % Please edit documentation in R/lgb.Dataset.R
\name{dim.lgb.Dataset} \name{dim.lgb.Dataset}
\alias{dim.lgb.Dataset} \alias{dim.lgb.Dataset}
\title{Dimensions of lgb.Dataset} \title{Dimensions of an lgb.Dataset}
\usage{ \usage{
\method{dim}{lgb.Dataset}(x, ...) \method{dim}{lgb.Dataset}(x, ...)
} }
...@@ -15,23 +15,21 @@ ...@@ -15,23 +15,21 @@
a vector of numbers of rows and of columns a vector of numbers of rows and of columns
} }
\description{ \description{
Dimensions of lgb.Dataset Returns a vector of numbers of rows and of columns in an \code{lgb.Dataset}.
} }
\details{ \details{
Returns a vector of numbers of rows and of columns in an \code{lgb.Dataset}.
Note: since \code{nrow} and \code{ncol} internally use \code{dim}, they can also Note: since \code{nrow} and \code{ncol} internally use \code{dim}, they can also
be directly used with an \code{lgb.Dataset} object. be directly used with an \code{lgb.Dataset} object.
} }
\examples{ \examples{
data(agaricus.train, package='lightgbm') dontrun{
train <- agaricus.train data(agaricus.train, package='lightgbm')
dtrain <- lgb.Dataset(train$data, label=train$label) train <- agaricus.train
dtrain <- lgb.Dataset(train$data, label=train$label)
stopifnot(nrow(dtrain) == nrow(train$data))
stopifnot(ncol(dtrain) == ncol(train$data))
stopifnot(all(dim(dtrain) == dim(train$data)))
stopifnot(nrow(dtrain) == nrow(train$data))
stopifnot(ncol(dtrain) == ncol(train$data))
stopifnot(all(dim(dtrain) == dim(train$data)))
}
} }
...@@ -16,25 +16,23 @@ ...@@ -16,25 +16,23 @@
and the second one is column names} and the second one is column names}
} }
\description{ \description{
Handling of column names of \code{lgb.Dataset} Only column names are supported for \code{lgb.Dataset}, thus setting of
row names would have no effect and returned row names would be NULL.
} }
\details{ \details{
Only column names are supported for \code{lgb.Dataset}, thus setting of
row names would have no effect and returnten row names would be NULL.
Generic \code{dimnames} methods are used by \code{colnames}. Generic \code{dimnames} methods are used by \code{colnames}.
Since row names are irrelevant, it is recommended to use \code{colnames} directly. Since row names are irrelevant, it is recommended to use \code{colnames} directly.
} }
\examples{ \examples{
data(agaricus.train, package='lightgbm') dontrun{
train <- agaricus.train data(agaricus.train, package='lightgbm')
dtrain <- lgb.Dataset(train$data, label=train$label) train <- agaricus.train
lgb.Dataset.construct(dtrain) dtrain <- lgb.Dataset(train$data, label=train$label)
dimnames(dtrain) lgb.Dataset.construct(dtrain)
colnames(dtrain) dimnames(dtrain)
colnames(dtrain) <- make.names(1:ncol(train$data)) colnames(dtrain)
print(dtrain, verbose=TRUE) colnames(dtrain) <- make.names(1:ncol(train$data))
print(dtrain, verbose=TRUE)
}
} }
...@@ -33,14 +33,16 @@ The \code{name} field can be one of the following: ...@@ -33,14 +33,16 @@ The \code{name} field can be one of the following:
} }
} }
\examples{ \examples{
data(agaricus.train, package='lightgbm') \dontrun{
train <- agaricus.train data(agaricus.train, package='lightgbm')
dtrain <- lgb.Dataset(train$data, label=train$label) train <- agaricus.train
lgb.Dataset.construct(dtrain) dtrain <- lgb.Dataset(train$data, label=train$label)
labels <- getinfo(dtrain, 'label') lgb.Dataset.construct(dtrain)
setinfo(dtrain, 'label', 1-labels) labels <- getinfo(dtrain, 'label')
setinfo(dtrain, 'label', 1-labels)
labels2 <- getinfo(dtrain, 'label') labels2 <- getinfo(dtrain, 'label')
stopifnot(all(labels2 == 1-labels)) stopifnot(all(labels2 == 1-labels))
}
} }
...@@ -28,18 +28,17 @@ lgb.Dataset(data, params = list(), reference = NULL, colnames = NULL, ...@@ -28,18 +28,17 @@ lgb.Dataset(data, params = list(), reference = NULL, colnames = NULL,
constructed dataset constructed dataset
} }
\description{ \description{
Contruct lgb.Dataset object
}
\details{
Contruct lgb.Dataset object from dense matrix, sparse matrix Contruct lgb.Dataset object from dense matrix, sparse matrix
or local file (that was created previously by saving an \code{lgb.Dataset}). or local file (that was created previously by saving an \code{lgb.Dataset}).
} }
\examples{ \examples{
data(agaricus.train, package='lightgbm') \dontrun{
train <- agaricus.train data(agaricus.train, package='lightgbm')
dtrain <- lgb.Dataset(train$data, label=train$label) train <- agaricus.train
lgb.Dataset.save(dtrain, 'lgb.Dataset.data') dtrain <- lgb.Dataset(train$data, label=train$label)
dtrain <- lgb.Dataset('lgb.Dataset.data') lgb.Dataset.save(dtrain, 'lgb.Dataset.data')
lgb.Dataset.construct(dtrain) dtrain <- lgb.Dataset('lgb.Dataset.data')
lgb.Dataset.construct(dtrain)
}
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment