Commit 2035c54b authored by Kirill Sevastyanenko's avatar Kirill Sevastyanenko Committed by Guolin Ke
Browse files

Stylistic changes r package (#184)

* src & callbacks

* lgb.Booster and utils

* cv

* wip lgb.Dataset

* lgb.Dataset

* lgb.Predictor

* lgb.train

* typos

* add flags to facilitate macosx compilation

* fix basic_string template error with clang

* most unfortunate mode of development

* fixup tests

* last test

* roxygen

* roxygen v5.x.x
parent c2ba086c
...@@ -363,5 +363,8 @@ ENV/ ...@@ -363,5 +363,8 @@ ENV/
# Rope project settings # Rope project settings
.ropeproject .ropeproject
# R testing artefact
lightgbm.model
# macOS # macOS
.DS_Store .DS_Store
CB_ENV <- R6Class( CB_ENV <- R6Class(
"lgb.cb_env", "lgb.cb_env",
cloneable=FALSE, cloneable = FALSE,
public = list( public = list(
model=NULL, model = NULL,
iteration=NULL, iteration = NULL,
begin_iteration=NULL, begin_iteration = NULL,
end_iteration=NULL, end_iteration = NULL,
eval_list=list(), eval_list = list(),
eval_err_list=list(), eval_err_list = list(),
best_iter=-1, best_iter = -1,
met_early_stop=FALSE met_early_stop = FALSE
) )
) )
cb.reset.parameters <- function(new_params) { cb.reset.parameters <- function(new_params) {
if (typeof(new_params) != "list") if (!is.list(new_params)) { stop(sQuote("new_params"), " must be a list") }
stop("'new_params' must be a list") pnames <- gsub("\\.", "_", names(new_params))
pnames <- gsub("\\.", "_", names(new_params))
nrounds <- NULL nrounds <- NULL
# run some checks in the begining # run some checks in the begining
init <- function(env) { init <- function(env) {
nrounds <<- env$end_iteration - env$begin_iteration + 1 nrounds <<- env$end_iteration - env$begin_iteration + 1
if (is.null(env$model)) if (is.null(env$model)) { stop("Env should have a ", sQuote("model")) }
stop("Env should has 'model'")
# Some parameters are not allowed to be changed, # Some parameters are not allowed to be changed,
# since changing them would simply wreck some chaos # since changing them would simply wreck some chaos
not_allowed <- pnames %in% not_allowed <- c("num_class", "metric", "boosting_type")
c('num_class', 'metric', 'boosting_type') if (any(pnames %in% not_allowed)) {
if (any(not_allowed)) stop("Parameters ", paste0(pnames[pnames %in% not_allowed], collapse = ", "), " cannot be changed during boosting")
stop('Parameters ', paste(pnames[not_allowed]), " cannot be changed during boosting.") }
for (n in pnames) { for (n in pnames) {
p <- new_params[[n]] p <- new_params[[n]]
if (is.function(p)) { if (is.function(p)) {
if (length(formals(p)) != 2) if (length(formals(p)) != 2)
stop("Parameter '", n, "' is a function but not of two arguments") stop("Parameter ", sQuote(n), " is a function but not of two arguments")
} else if (is.numeric(p) || is.character(p)) { } else if (is.numeric(p) || is.character(p)) {
if (length(p) != nrounds) if (length(p) != nrounds)
stop("Length of '", n, "' has to be equal to 'nrounds'") stop("Length of ", sQuote(n), " has to be equal to length of ", sQuote("nrounds"))
} else { } else {
stop("Parameter '", n, "' is not a function or a vector") stop("Parameter ", sQuote(n), " is not a function or a vector")
} }
} }
} }
callback <- function(env) { callback <- function(env) {
if (is.null(nrounds)) if (is.null(nrounds)) { init(env) }
init(env)
i <- env$iteration - env$begin_iteration i <- env$iteration - env$begin_iteration
pars <- lapply(new_params, function(p) { pars <- lapply(new_params, function(p) {
if (is.function(p)) if (is.function(p)) { return(p(i, nrounds)) }
return(p(i, nrounds))
p[i] p[i]
}) })
# to-do check pars # to-do check pars
if (!is.null(env$model)) { if (!is.null(env$model)) { env$model$reset_parameter(pars) }
env$model$reset_parameter(pars)
}
} }
attr(callback, 'call') <- match.call()
attr(callback, 'call') <- match.call()
attr(callback, 'is_pre_iteration') <- TRUE attr(callback, 'is_pre_iteration') <- TRUE
attr(callback, 'name') <- 'cb.reset.parameters' attr(callback, 'name') <- 'cb.reset.parameters'
return(callback) callback
} }
# Format the evaluation metric string # Format the evaluation metric string
format.eval.string <- function(eval_res, eval_err=NULL) { format.eval.string <- function(eval_res, eval_err = NULL) {
if (is.null(eval_res)) if (is.null(eval_res) || length(eval_res) == 0) { stop('no evaluation results') }
stop('no evaluation results')
if (length(eval_res) == 0)
stop('no evaluation results')
if (!is.null(eval_err)) { if (!is.null(eval_err)) {
res <- sprintf('%s\'s %s:%g+%g', eval_res$data_name, eval_res$name, eval_res$value, eval_err) sprintf('%s\'s %s:%g+%g', eval_res$data_name, eval_res$name, eval_res$value, eval_err)
} else { } else {
res <- sprintf('%s\'s %s:%g', eval_res$data_name, eval_res$name, eval_res$value) sprintf('%s\'s %s:%g', eval_res$data_name, eval_res$name, eval_res$value)
} }
return(res)
} }
merge.eval.string <- function(env){ merge.eval.string <- function(env) {
if(length(env$eval_list) <= 0){ if (length(env$eval_list) <= 0) { return("") }
return("") msg <- list(sprintf('[%d]:', env$iteration))
}
msg <- list(sprintf('[%d]:',env$iteration))
is_eval_err <- FALSE is_eval_err <- FALSE
if(length(env$eval_err_list) > 0){ if (length(env$eval_err_list) > 0) { is_eval_err <- TRUE }
is_eval_err <- TRUE for (j in seq_along(env$eval_list)) {
}
for(j in 1:length(env$eval_list)) {
eval_err <- NULL eval_err <- NULL
if(is_eval_err){ if (is_eval_err) { eval_err <- env$eval_err_list[[j]] }
eval_err <- env$eval_err_list[[j]] msg <- c(msg, format.eval.string(env$eval_list[[j]], eval_err))
}
msg <- c(msg, format.eval.string(env$eval_list[[j]],eval_err))
} }
return(paste0(msg, collapse='\t')) paste0(msg, collapse='\t')
} }
cb.print.evaluation <- function(period=1){ cb.print.evaluation <- function(period = 1){
callback <- function(env){ callback <- function(env) {
if(period > 0){ if (period > 0) {
i <- env$iteration i <- env$iteration
if( (i - 1) %% period == 0 if ( (i - 1) %% period == 0
| i == env$begin_iteration | i == env$begin_iteration
| i == env$end_iteration ){ | i == env$end_iteration ) {
cat(merge.eval.string(env), "\n") cat(merge.eval.string(env), "\n")
} }
} }
} }
attr(callback, 'call') <- match.call() attr(callback, 'call') <- match.call()
attr(callback, 'name') <- 'cb.print.evaluation' attr(callback, 'name') <- 'cb.print.evaluation'
return(callback) callback
} }
cb.record.evaluation <- function() { cb.record.evaluation <- function() {
callback <- function(env){ callback <- function(env) {
if(length(env$eval_list) <= 0) return() if (length(env$eval_list) <= 0) { return() }
is_eval_err <- FALSE is_eval_err <- FALSE
if(length(env$eval_err_list) > 0){ if (length(env$eval_err_list) > 0) { is_eval_err <- TRUE }
is_eval_err <- TRUE if (length(env$model$record_evals) == 0) {
} for (j in seq_along(env$eval_list)) {
if(length(env$model$record_evals) == 0){
for(j in 1:length(env$eval_list)) {
data_name <- env$eval_list[[j]]$data_name data_name <- env$eval_list[[j]]$data_name
name <- env$eval_list[[j]]$name name <- env$eval_list[[j]]$name
env$model$record_evals$start_iter <- env$begin_iteration env$model$record_evals$start_iter <- env$begin_iteration
if(is.null(env$model$record_evals[[data_name]])){ if (is.null(env$model$record_evals[[data_name]])) {
env$model$record_evals[[data_name]] <- list() env$model$record_evals[[data_name]] <- list()
} }
env$model$record_evals[[data_name]][[name]] <- list() env$model$record_evals[[data_name]][[name]] <- list()
env$model$record_evals[[data_name]][[name]]$eval <- list() env$model$record_evals[[data_name]][[name]]$eval <- list()
env$model$record_evals[[data_name]][[name]]$eval_err <- list() env$model$record_evals[[data_name]][[name]]$eval_err <- list()
} }
} }
for(j in 1:length(env$eval_list)) { for (j in seq_along(env$eval_list)) {
eval_res <- env$eval_list[[j]] eval_res <- env$eval_list[[j]]
eval_err <- NULL eval_err <- NULL
if(is_eval_err){ if (is_eval_err) { eval_err <- env$eval_err_list[[j]] }
eval_err <- env$eval_err_list[[j]]
}
data_name <- eval_res$data_name data_name <- eval_res$data_name
name <- eval_res$name name <- eval_res$name
env$model$record_evals[[data_name]][[name]]$eval <- c(env$model$record_evals[[data_name]][[name]]$eval, eval_res$value) env$model$record_evals[[data_name]][[name]]$eval <- c(env$model$record_evals[[data_name]][[name]]$eval, eval_res$value)
env$model$record_evals[[data_name]][[name]]$eval_err <- c(env$model$record_evals[[data_name]][[name]]$eval_err, eval_err) env$model$record_evals[[data_name]][[name]]$eval_err <- c(env$model$record_evals[[data_name]][[name]]$eval_err, eval_err)
} }
} }
attr(callback, 'call') <- match.call() attr(callback, 'call') <- match.call()
attr(callback, 'name') <- 'cb.record.evaluation' attr(callback, 'name') <- 'cb.record.evaluation'
return(callback) callback
} }
cb.early.stop <- function(stopping_rounds, verbose=TRUE) { cb.early.stop <- function(stopping_rounds, verbose = TRUE) {
# state variables # state variables
factor_to_bigger_better <- NULL factor_to_bigger_better <- NULL
best_iter <- NULL best_iter <- NULL
best_score <- NULL best_score <- NULL
best_msg <- NULL best_msg <- NULL
eval_len <- NULL eval_len <- NULL
init <- function(env) { init <- function(env) {
eval_len <<- length(env$eval_list) eval_len <<- length(env$eval_list)
if (eval_len == 0) if (eval_len == 0) {
stop("For early stopping, valids must have at least one element") stop("For early stopping, valids must have at least one element")
}
if (verbose)
cat("Will train until hasn't improved in ", if (isTRUE(verbose)) {
stopping_rounds, " rounds.\n\n", sep = '') cat("Will train until there is no improvement in ", stopping_rounds, " rounds.\n\n", sep = '')
}
factor_to_bigger_better <<- rep(1.0, eval_len) factor_to_bigger_better <<- rep(1.0, eval_len)
best_iter <<- rep(-1, eval_len) best_iter <<- rep(-1, eval_len)
best_score <<- rep(-Inf, eval_len) best_score <<- rep(-Inf, eval_len)
best_msg <<- list() best_msg <<- list()
for(i in 1:eval_len){ for (i in seq_len(eval_len)) {
best_msg <<- c(best_msg, "") best_msg <<- c(best_msg, "")
if(!env$eval_list[[i]]$higher_better){ if (!env$eval_list[[i]]$higher_better) {
factor_to_bigger_better[i] <<- -1.0 factor_to_bigger_better[i] <<- -1.0
} }
} }
} }
callback <- function(env, finalize = FALSE) { callback <- function(env, finalize = FALSE) {
if (is.null(eval_len)) if (is.null(eval_len)) { init(env) }
init(env)
cur_iter <- env$iteration cur_iter <- env$iteration
for(i in 1:eval_len){ for (i in seq_len(eval_len)) {
score <- env$eval_list[[i]]$value * factor_to_bigger_better[i] score <- env$eval_list[[i]]$value * factor_to_bigger_better[i]
if(score > best_score[i]){ if (score > best_score[i]) {
best_score[i] <<- score best_score[i] <<- score
best_iter[i] <<- cur_iter best_iter[i] <<- cur_iter
if(verbose){ if (verbose) {
best_msg[[i]] <<- as.character(merge.eval.string(env)) best_msg[[i]] <<- as.character(merge.eval.string(env))
} }
} else { } else {
if(cur_iter - best_iter[i] >= stopping_rounds){ if (cur_iter - best_iter[i] >= stopping_rounds) {
if(!is.null(env$model)){ if (!is.null(env$model)) { env$model$best_iter <- best_iter[i] }
env$model$best_iter <- best_iter[i] if (isTRUE(verbose)) {
cat("Early stopping, best iteration is:", "\n")
cat(best_msg[[i]], "\n")
} }
if(verbose){ env$best_iter <- best_iter[i]
cat('Early stopping, best iteration is:',"\n")
cat(best_msg[[i]],"\n")
}
env$best_iter <- best_iter[i]
env$met_early_stop <- TRUE env$met_early_stop <- TRUE
} }
} }
...@@ -212,13 +189,11 @@ cb.early.stop <- function(stopping_rounds, verbose=TRUE) { ...@@ -212,13 +189,11 @@ cb.early.stop <- function(stopping_rounds, verbose=TRUE) {
} }
attr(callback, 'call') <- match.call() attr(callback, 'call') <- match.call()
attr(callback, 'name') <- 'cb.early.stop' attr(callback, 'name') <- 'cb.early.stop'
return(callback) callback
} }
# Extract callback names from the list of callbacks # Extract callback names from the list of callbacks
callback.names <- function(cb_list) { callback.names <- function(cb_list) { unlist(lapply(cb_list, attr, "name")) }
unlist(lapply(cb_list, function(x) attr(x, 'name')))
}
add.cb <- function(cb_list, cb) { add.cb <- function(cb_list, cb) {
cb_list <- c(cb_list, cb) cb_list <- c(cb_list, cb)
...@@ -226,7 +201,7 @@ add.cb <- function(cb_list, cb) { ...@@ -226,7 +201,7 @@ add.cb <- function(cb_list, cb) {
if ('cb.early.stop' %in% names(cb_list)) { if ('cb.early.stop' %in% names(cb_list)) {
cb_list <- c(cb_list, cb_list['cb.early.stop']) cb_list <- c(cb_list, cb_list['cb.early.stop'])
# this removes only the first one # this removes only the first one
cb_list['cb.early.stop'] <- NULL cb_list['cb.early.stop'] <- NULL
} }
cb_list cb_list
} }
...@@ -235,7 +210,7 @@ categorize.callbacks <- function(cb_list) { ...@@ -235,7 +210,7 @@ categorize.callbacks <- function(cb_list) {
list( list(
pre_iter = Filter(function(x) { pre_iter = Filter(function(x) {
pre <- attr(x, 'is_pre_iteration') pre <- attr(x, 'is_pre_iteration')
!is.null(pre) && pre !is.null(pre) && pre
}, cb_list), }, cb_list),
post_iter = Filter(function(x) { post_iter = Filter(function(x) {
pre <- attr(x, 'is_pre_iteration') pre <- attr(x, 'is_pre_iteration')
......
Booster <- R6Class( Booster <- R6Class(
"lgb.Booster", "lgb.Booster",
cloneable=FALSE, cloneable = FALSE,
public = list( public = list(
best_iter = -1, best_iter = -1,
record_evals = list(), record_evals = list(),
finalize = function() { finalize = function() {
if(!lgb.is.null.handle(private$handle)){ if (!lgb.is.null.handle(private$handle)) {
print("free booster handle") cat("freeing booster handle\n")
lgb.call("LGBM_BoosterFree_R", ret=NULL, private$handle) lgb.call("LGBM_BoosterFree_R", ret = NULL, private$handle)
private$handle <- NULL private$handle <- NULL
} }
}, },
initialize = function(params = list(), initialize = function(params = list(),
train_set = NULL, train_set = NULL,
modelfile = NULL, modelfile = NULL,
...) { ...) {
params <- append(params, list(...)) params <- append(params, list(...))
params_str <- lgb.params2str(params) params_str <- lgb.params2str(params)
handle <- lgb.new.handle() handle <- lgb.new.handle()
if (!is.null(train_set)) { if (!is.null(train_set)) {
if (!lgb.check.r6.class(train_set, "lgb.Dataset")) { if (!lgb.check.r6.class(train_set, "lgb.Dataset")) {
stop("lgb.Booster: Only can use lgb.Dataset as training data") stop("lgb.Booster: Can only use lgb.Dataset as training data")
} }
handle <- handle <-
lgb.call("LGBM_BoosterCreate_R", ret=handle, train_set$.__enclos_env__$private$get_handle(), params_str) lgb.call("LGBM_BoosterCreate_R", ret = handle, train_set$.__enclos_env__$private$get_handle(), params_str)
private$train_set <- train_set
private$num_dataset <- 1 private$train_set <- train_set
private$num_dataset <- 1
private$init_predictor <- train_set$.__enclos_env__$private$predictor private$init_predictor <- train_set$.__enclos_env__$private$predictor
if (!is.null(private$init_predictor)) { if (!is.null(private$init_predictor)) {
lgb.call("LGBM_BoosterMerge_R", ret=NULL, lgb.call("LGBM_BoosterMerge_R", ret = NULL,
handle, handle,
private$init_predictor$.__enclos_env__$private$handle) private$init_predictor$.__enclos_env__$private$handle)
} }
private$is_predicted_cur_iter <- private$is_predicted_cur_iter <- c(private$is_predicted_cur_iter, FALSE)
c(private$is_predicted_cur_iter, FALSE)
} else if (!is.null(modelfile)) { } else if (!is.null(modelfile)) {
if (!is.character(modelfile)) { if (!is.character(modelfile)) {
stop("lgb.Booster: Only can use string as model file path") stop("lgb.Booster: Can only use a string as model file path")
} }
handle <- handle <-
lgb.call("LGBM_BoosterCreateFromModelfile_R", lgb.call("LGBM_BoosterCreateFromModelfile_R",
ret=handle, ret = handle,
lgb.c_str(modelfile)) lgb.c_str(modelfile))
} else { } else {
stop( stop(
"lgb.Booster: Need at least one training dataset or model file to create booster instance" "lgb.Booster: Need at least either training dataset or model file to create booster instance"
) )
} }
class(handle) <- "lgb.Booster.handle" class(handle) <- "lgb.Booster.handle"
private$handle <- handle private$handle <- handle
private$num_class <- as.integer(1) private$num_class <- as.integer(1)
private$num_class <- private$num_class <-
lgb.call("LGBM_BoosterGetNumClasses_R", ret=private$num_class, private$handle) lgb.call("LGBM_BoosterGetNumClasses_R", ret = private$num_class, private$handle)
}, },
set_train_data_name = function(name) { set_train_data_name = function(name) {
private$name_train_set <- name private$name_train_set <- name
return(self) self
}, },
add_valid = function(data, name) { add_valid = function(data, name) {
if (!lgb.check.r6.class(data, "lgb.Dataset")) { if (!lgb.check.r6.class(data, "lgb.Dataset")) {
stop("lgb.Booster.add_valid: Only can use lgb.Dataset as validation data") stop("lgb.Booster.add_valid: Can only use lgb.Dataset as validation data")
} }
if (!identical(data$.__enclos_env__$private$predictor, private$init_predictor)) { if (!identical(data$.__enclos_env__$private$predictor, private$init_predictor)) {
stop( stop(
"lgb.Booster.add_valid: Add validation data failed, you should use same predictor for these data" "lgb.Booster.add_valid: Failed to add validation data; you should use the same predictor for these data"
) )
} }
if(!is.character(name)){ if (!is.character(name)) {
stop("only can use character as data name") stop("lgb.Booster.add_valid: Can only use characters as data name")
} }
lgb.call("LGBM_BoosterAddValidData_R", ret=NULL, private$handle, data$.__enclos_env__$private$get_handle()) lgb.call("LGBM_BoosterAddValidData_R", ret = NULL, private$handle, data$.__enclos_env__$private$get_handle())
private$valid_sets <- c(private$valid_sets, data) private$valid_sets <- c(private$valid_sets, data)
private$name_valid_sets <- c(private$name_valid_sets, name) private$name_valid_sets <- c(private$name_valid_sets, name)
private$num_dataset <- private$num_dataset + 1 private$num_dataset <- private$num_dataset + 1
private$is_predicted_cur_iter <- private$is_predicted_cur_iter <-
c(private$is_predicted_cur_iter, FALSE) c(private$is_predicted_cur_iter, FALSE)
return(self)
self
}, },
reset_parameter = function(params, ...) { reset_parameter = function(params, ...) {
params <- append(params, list(...)) params <- append(params, list(...))
params_str <- algb.params2str(params) params_str <- algb.params2str(params)
lgb.call("LGBM_BoosterResetParameter_R", ret=NULL, lgb.call("LGBM_BoosterResetParameter_R", ret = NULL,
private$handle, private$handle,
params_str) params_str)
return(self) self
}, },
update = function(train_set = NULL, fobj = NULL) { update = function(train_set = NULL, fobj = NULL) {
if (!is.null(train_set)) { if (!is.null(train_set)) {
...@@ -92,57 +94,51 @@ Booster <- R6Class( ...@@ -92,57 +94,51 @@ Booster <- R6Class(
} }
if (!identical(train_set$predictor, private$init_predictor)) { if (!identical(train_set$predictor, private$init_predictor)) {
stop( stop(
"lgb.Booster.update: Change train_set failed, you should use same predictor for these data" "lgb.Booster.update: Change train_set failed, you should use the same predictor for these data"
) )
} }
lgb.call("LGBM_BoosterResetTrainingData_R", ret=NULL, lgb.call("LGBM_BoosterResetTrainingData_R", ret = NULL,
private$handle, private$handle,
train_set$.__enclos_env__$private$get_handle()) train_set$.__enclos_env__$private$get_handle())
private$train_set = train_set private$train_set = train_set
} }
if (is.null(fobj)) { if (is.null(fobj)) {
ret <- ret <- lgb.call("LGBM_BoosterUpdateOneIter_R", ret = NULL, private$handle)
lgb.call("LGBM_BoosterUpdateOneIter_R", ret=NULL, private$handle)
} else { } else {
if (typeof(fobj) != 'closure') { if (!is.function(fobj)) { stop("lgb.Booster.update: fobj should be a function") }
stop("lgb.Booster.update: fobj should be a function")
}
gpair <- fobj(private$inner_predict(1), private$train_set) gpair <- fobj(private$inner_predict(1), private$train_set)
ret <- ret <- lgb.call(
lgb.call( "LGBM_BoosterUpdateOneIterCustom_R", ret = NULL,
"LGBM_BoosterUpdateOneIterCustom_R", ret=NULL,
private$handle, private$handle,
gpair$grad, gpair$grad,
gpair$hess, gpair$hess,
length(gpair$grad) length(gpair$grad)
) )
} }
for (i in 1:length(private$is_predicted_cur_iter)) { for (i in seq_along(private$is_predicted_cur_iter)) {
private$is_predicted_cur_iter[[i]] <- FALSE private$is_predicted_cur_iter[[i]] <- FALSE
} }
return(ret) ret
}, },
rollback_one_iter = function() { rollback_one_iter = function() {
lgb.call("LGBM_BoosterRollbackOneIter_R", ret=NULL, private$handle) lgb.call("LGBM_BoosterRollbackOneIter_R", ret = NULL, private$handle)
for (i in 1:length(private$is_predicted_cur_iter)) { for (i in seq_along(private$is_predicted_cur_iter)) {
private$is_predicted_cur_iter[[i]] <- FALSE private$is_predicted_cur_iter[[i]] <- FALSE
} }
return(self) self
}, },
current_iter = function() { current_iter = function() {
cur_iter <- as.integer(0) cur_iter <- as.integer(0)
return(lgb.call("LGBM_BoosterGetCurrentIteration_R", ret=cur_iter, private$handle)) lgb.call("LGBM_BoosterGetCurrentIteration_R", ret = cur_iter, private$handle)
}, },
eval = function(data, name, feval = NULL) { eval = function(data, name, feval = NULL) {
if (!lgb.check.r6.class(data, "lgb.Dataset")) { if (!lgb.check.r6.class(data, "lgb.Dataset")) {
stop("lgb.Booster.eval: only can use lgb.Dataset to eval") stop("lgb.Booster.eval: Can only use lgb.Dataset to eval")
} }
data_idx <- 0 data_idx <- 0
if (identical(data, private$train_set)) { if (identical(data, private$train_set)) { data_idx <- 1 } else {
data_idx <- 1 if (length(private$valid_sets) > 0) {
} else { for (i in seq_along(private$valid_sets)) {
if(length(private$valid_sets) > 0){
for (i in 1:length(private$valid_sets)) {
if (identical(data, private$valid_sets[[i]])) { if (identical(data, private$valid_sets[[i]])) {
data_idx <- i + 1 data_idx <- i + 1
break break
...@@ -154,24 +150,21 @@ Booster <- R6Class( ...@@ -154,24 +150,21 @@ Booster <- R6Class(
self$add_valid(data, name) self$add_valid(data, name)
data_idx <- private$num_dataset data_idx <- private$num_dataset
} }
return(private$inner_eval(name, data_idx, feval)) private$inner_eval(name, data_idx, feval)
}, },
eval_train = function(feval = NULL) { eval_train = function(feval = NULL) {
return(private$inner_eval(private$name_train_set, 1, feval)) private$inner_eval(private$name_train_set, 1, feval)
}, },
eval_valid = function(feval = NULL) { eval_valid = function(feval = NULL) {
ret = list() ret = list()
if(length(private$valid_sets) <= 0) return(ret) if (length(private$valid_sets) <= 0) { return(ret) }
for (i in 1:length(private$valid_sets)) { for (i in seq_along(private$valid_sets)) {
ret <- ret <- append(ret, private$inner_eval(private$name_valid_sets[[i]], i + 1, feval))
append(ret, private$inner_eval(private$name_valid_sets[[i]], i + 1, feval))
} }
return(ret) ret
}, },
save_model = function(filename, num_iteration = NULL) { save_model = function(filename, num_iteration = NULL) {
if (is.null(num_iteration)) { if (is.null(num_iteration)) { num_iteration <- self$best_iter }
num_iteration <- self$best_iter
}
lgb.call( lgb.call(
"LGBM_BoosterSaveModel_R", "LGBM_BoosterSaveModel_R",
ret = NULL, ret = NULL,
...@@ -179,97 +172,82 @@ Booster <- R6Class( ...@@ -179,97 +172,82 @@ Booster <- R6Class(
as.integer(num_iteration), as.integer(num_iteration),
lgb.c_str(filename) lgb.c_str(filename)
) )
return(self) self
}, },
dump_model = function(num_iteration = NULL) { dump_model = function(num_iteration = NULL) {
if (is.null(num_iteration)) { if (is.null(num_iteration)) { num_iteration <- self$best_iter }
num_iteration <- self$best_iter lgb.call.return.str(
} "LGBM_BoosterDumpModel_R",
return( private$handle,
lgb.call.return.str( as.integer(num_iteration)
"LGBM_BoosterDumpModel_R",
private$handle,
as.integer(num_iteration)
)
) )
}, },
predict = function(data, predict = function(data,
num_iteration = NULL, num_iteration = NULL,
rawscore = FALSE, rawscore = FALSE,
predleaf = FALSE, predleaf = FALSE,
header = FALSE, header = FALSE,
reshape = FALSE) { reshape = FALSE) {
if (is.null(num_iteration)) { if (is.null(num_iteration)) { num_iteration <- self$best_iter }
num_iteration <- self$best_iter
}
predictor <- Predictor$new(private$handle) predictor <- Predictor$new(private$handle)
return(predictor$predict(data, num_iteration, rawscore, predleaf, header, reshape)) predictor$predict(data, num_iteration, rawscore, predleaf, header, reshape)
}, },
to_predictor = function() { to_predictor = function() { Predictor$new(private$handle) }
Predictor$new(private$handle)
}
), ),
private = list( private = list(
handle = NULL, handle = NULL,
train_set = NULL, train_set = NULL,
name_train_set = "training", name_train_set = "training",
valid_sets = list(), valid_sets = list(),
name_valid_sets = list(), name_valid_sets = list(),
predict_buffer = list(), predict_buffer = list(),
is_predicted_cur_iter = list(), is_predicted_cur_iter = list(),
num_class = 1, num_class = 1,
num_dataset = 0, num_dataset = 0,
init_predictor = NULL, init_predictor = NULL,
eval_names = NULL, eval_names = NULL,
higher_better_inner_eval = NULL, higher_better_inner_eval = NULL,
inner_predict = function(idx) { inner_predict = function(idx) {
data_name <- private$name_train_set data_name <- private$name_train_set
if(idx > 1){ if (idx > 1) { data_name <- private$name_valid_sets[[idx - 1]] }
data_name <- private$name_valid_sets[[idx - 1]]
}
if (idx > private$num_dataset) { if (idx > private$num_dataset) {
stop("data_idx should not be greater than num_dataset") stop("data_idx should not be greater than num_dataset")
} }
if (is.null(private$predict_buffer[[data_name]])) { if (is.null(private$predict_buffer[[data_name]])) {
npred <- as.integer(0) npred <- as.integer(0)
npred <- npred <- lgb.call("LGBM_BoosterGetNumPredict_R",
lgb.call("LGBM_BoosterGetNumPredict_R",
ret = npred, ret = npred,
private$handle, private$handle,
as.integer(idx - 1)) as.integer(idx - 1))
private$predict_buffer[[data_name]] <- rep(0.0, npred) private$predict_buffer[[data_name]] <- rep(0.0, npred)
} }
if (!private$is_predicted_cur_iter[[idx]]) { if (!private$is_predicted_cur_iter[[idx]]) {
private$predict_buffer[[data_name]] <- private$predict_buffer[[data_name]] <- lgb.call(
lgb.call(
"LGBM_BoosterGetPredict_R", "LGBM_BoosterGetPredict_R",
ret=private$predict_buffer[[data_name]], ret = private$predict_buffer[[data_name]],
private$handle, private$handle,
as.integer(idx - 1) as.integer(idx - 1)
) )
private$is_predicted_cur_iter[[idx]] <- TRUE private$is_predicted_cur_iter[[idx]] <- TRUE
} }
return(private$predict_buffer[[data_name]]) private$predict_buffer[[data_name]]
}, },
get_eval_info = function() { get_eval_info = function() {
if (is.null(private$eval_names)) { if (is.null(private$eval_names)) {
names <- names <- lgb.call.return.str("LGBM_BoosterGetEvalNames_R", private$handle)
lgb.call.return.str("LGBM_BoosterGetEvalNames_R", private$handle) if (nchar(names) > 0) {
if(nchar(names) > 0){
names <- strsplit(names, "\t")[[1]] names <- strsplit(names, "\t")[[1]]
private$eval_names <- names private$eval_names <- names
private$higher_better_inner_eval <- private$higher_better_inner_eval <- rep(FALSE, length(names))
rep(FALSE, length(names)) for (i in seq_along(names)) {
for (i in 1:length(names)) {
if (startsWith(names[i], "auc") | if (startsWith(names[i], "auc") |
startsWith(names[i], "ndcg")) { startsWith(names[i], "ndcg")) {
private$higher_better_inner_eval[i] <- TRUE private$higher_better_inner_eval[i] <- TRUE
} }
} }
} }
} }
return(private$eval_names) private$eval_names
}, },
inner_eval = function(data_name, data_idx, feval = NULL) { inner_eval = function(data_name, data_idx, feval = NULL) {
if (data_idx > private$num_dataset) { if (data_idx > private$num_dataset) {
...@@ -279,189 +257,176 @@ Booster <- R6Class( ...@@ -279,189 +257,176 @@ Booster <- R6Class(
ret <- list() ret <- list()
if (length(private$eval_names) > 0) { if (length(private$eval_names) > 0) {
tmp_vals <- rep(0.0, length(private$eval_names)) tmp_vals <- rep(0.0, length(private$eval_names))
tmp_vals <- tmp_vals <- lgb.call("LGBM_BoosterGetEval_R", ret = tmp_vals,
lgb.call("LGBM_BoosterGetEval_R", ret=tmp_vals,
private$handle, private$handle,
as.integer(data_idx - 1)) as.integer(data_idx - 1))
for (i in 1:length(private$eval_names)) { for (i in seq_along(private$eval_names)) {
res <- list() res <- list()
res$data_name <- data_name res$data_name <- data_name
res$name <- private$eval_names[i] res$name <- private$eval_names[i]
res$value <- tmp_vals[i] res$value <- tmp_vals[i]
res$higher_better <- private$higher_better_inner_eval[i] res$higher_better <- private$higher_better_inner_eval[i]
ret <- append(ret, list(res)) ret <- append(ret, list(res))
} }
} }
if (!is.null(feval)) { if (!is.null(feval)) {
if (typeof(feval) != 'closure') { if (!is.function(feval)) {
stop("lgb.Booster.eval: feval should be a function") stop("lgb.Booster.eval: feval should be a function")
} }
data <- private$train_set data <- private$train_set
if (data_idx > 1) { if (data_idx > 1) { data <- private$valid_sets[[data_idx - 1]] }
data <- private$valid_sets[[data_idx - 1]] res <- feval(private$inner_predict(data_idx), data)
}
res <- feval(private$inner_predict(data_idx), data)
res$data_name <- data_name res$data_name <- data_name
ret <- append(ret, list(res)) ret <- append(ret, list(res))
} }
return(ret) ret
} }
) )
) )
# internal helper method
lgb.is.Booster <- function(x){
if(lgb.check.r6.class(x, "lgb.Booster")){
return(TRUE)
} else{
return(FALSE)
}
}
#' Predict method for LightGBM model #' Predict method for LightGBM model
#' #'
#' Predicted values based on class \code{lgb.Booster} #' Predicted values based on class \code{lgb.Booster}
#' #'
#' @param object Object of class \code{lgb.Booster} #' @param object Object of class \code{lgb.Booster}
#' @param data a \code{matrix} object, a \code{dgCMatrix} object or a character representing a filename #' @param data a \code{matrix} object, a \code{dgCMatrix} object or a character representing a filename
#' @param num_iteration number of iteration want to predict with, NULL or <= 0 means use best iteration #' @param num_iteration number of iteration want to predict with, NULL or <= 0 means use best iteration
#' @param rawscore whether the prediction should be returned in the for of original untransformed #' @param rawscore whether the prediction should be returned in the for of original untransformed
#' sum of predictions from boosting iterations' results. E.g., setting \code{rawscore=TRUE} for #' sum of predictions from boosting iterations' results. E.g., setting \code{rawscore=TRUE} for
#' logistic regression would result in predictions for log-odds instead of probabilities. #' logistic regression would result in predictions for log-odds instead of probabilities.
#' @param predleaf whether predict leaf index instead. #' @param predleaf whether predict leaf index instead.
#' @param header only used for prediction for text file. True if text file has header #' @param header only used for prediction for text file. True if text file has header
#' @param reshape whether to reshape the vector of predictions to a matrix form when there are several #' @param reshape whether to reshape the vector of predictions to a matrix form when there are several
#' prediction outputs per case. #' prediction outputs per case.
#' @return #' @return
#' For regression or binary classification, it returns a vector of length \code{nrows(data)}. #' For regression or binary classification, it returns a vector of length \code{nrows(data)}.
#' For multiclass classification, either a \code{num_class * nrows(data)} vector or #' For multiclass classification, either a \code{num_class * nrows(data)} vector or
#' a \code{(nrows(data), num_class)} dimension matrix is returned, depending on #' a \code{(nrows(data), num_class)} dimension matrix is returned, depending on
#' the \code{reshape} value. #' the \code{reshape} value.
#' #'
#' When \code{predleaf = TRUE}, the output is a matrix object with the #' When \code{predleaf = TRUE}, the output is a matrix object with the
#' number of columns corresponding to the number of trees. #' number of columns corresponding to the number of trees.
#' @examples #' @examples
#' library(lightgbm) #' \dontrun{
#' data(agaricus.train, package='lightgbm') #' library(lightgbm)
#' train <- agaricus.train #' data(agaricus.train, package='lightgbm')
#' dtrain <- lgb.Dataset(train$data, label=train$label) #' train <- agaricus.train
#' data(agaricus.test, package='lightgbm') #' dtrain <- lgb.Dataset(train$data, label=train$label)
#' test <- agaricus.test #' data(agaricus.test, package='lightgbm')
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label=test$label) #' test <- agaricus.test
#' params <- list(objective="regression", metric="l2") #' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label=test$label)
#' valids <- list(test=dtest) #' params <- list(objective="regression", metric="l2")
#' model <- lgb.train(params, dtrain, 100, valids, min_data=1, learning_rate=1, early_stopping_rounds=10) #' valids <- list(test=dtest)
#' preds <- predict(model, test$data) #' model <- lgb.train(params, dtrain, 100, valids, min_data=1, learning_rate=1, early_stopping_rounds=10)
#' #' preds <- predict(model, test$data)
#' }
#' @rdname predict.lgb.Booster #' @rdname predict.lgb.Booster
#' @export #' @export
predict.lgb.Booster <- function(object, predict.lgb.Booster <- function(object, data,
data,
num_iteration = NULL, num_iteration = NULL,
rawscore = FALSE, rawscore = FALSE,
predleaf = FALSE, predleaf = FALSE,
header = FALSE, header = FALSE,
reshape = FALSE) { reshape = FALSE) {
if(!lgb.is.Booster(object)){ if (!lgb.is.Booster(object)) {
stop("predict.lgb.Booster: should input lgb.Booster object") stop("predict.lgb.Booster: object should be an ", sQuote("lgb.Booster"))
} }
object$predict(data, num_iteration, rawscore, predleaf, header, reshape) object$predict(data, num_iteration, rawscore, predleaf, header, reshape)
} }
#' Load LightGBM model #' Load LightGBM model
#' #'
#' Load LightGBM model from saved model file #' Load LightGBM model from saved model file
#' #'
#' @param filename path of model file #' @param filename path of model file
#' #'
#' @return booster #' @return booster
#' @examples #' @examples
#' library(lightgbm) #' \dontrun{
#' data(agaricus.train, package='lightgbm') #' library(lightgbm)
#' train <- agaricus.train #' data(agaricus.train, package='lightgbm')
#' dtrain <- lgb.Dataset(train$data, label=train$label) #' train <- agaricus.train
#' data(agaricus.test, package='lightgbm') #' dtrain <- lgb.Dataset(train$data, label=train$label)
#' test <- agaricus.test #' data(agaricus.test, package='lightgbm')
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label=test$label) #' test <- agaricus.test
#' params <- list(objective="regression", metric="l2") #' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label=test$label)
#' valids <- list(test=dtest) #' params <- list(objective="regression", metric="l2")
#' model <- lgb.train(params, dtrain, 100, valids, min_data=1, learning_rate=1, early_stopping_rounds=10) #' valids <- list(test=dtest)
#' lgb.save(model, "model.txt") #' model <- lgb.train(params, dtrain, 100, valids, min_data=1, learning_rate=1, early_stopping_rounds=10)
#' load_booster <- lgb.load("model.txt") #' lgb.save(model, "model.txt")
#' @rdname lgb.load #' load_booster <- lgb.load("model.txt")
#' }
#' @rdname lgb.load
#' @export #' @export
lgb.load <- function(filename){ lgb.load <- function(filename){
if(!is.character(filename)){ if (!is.character(filename)) { stop("lgb.load: filename should be character") }
stop("lgb.load: filename should be character") Booster$new(modelfile = filename)
}
Booster$new(modelfile=filename)
} }
#' Save LightGBM model #' Save LightGBM model
#' #'
#' Save LightGBM model #' Save LightGBM model
#' #'
#' @param booster Object of class \code{lgb.Booster} #' @param booster Object of class \code{lgb.Booster}
#' @param filename saved filename #' @param filename saved filename
#' @param num_iteration number of iteration want to predict with, NULL or <= 0 means use best iteration #' @param num_iteration number of iteration want to predict with, NULL or <= 0 means use best iteration
#' #'
#' @return booster #' @return booster
#' @examples #' @examples
#' library(lightgbm) #' \dontrun{
#' data(agaricus.train, package='lightgbm') #' library(lightgbm)
#' train <- agaricus.train #' data(agaricus.train, package='lightgbm')
#' dtrain <- lgb.Dataset(train$data, label=train$label) #' train <- agaricus.train
#' data(agaricus.test, package='lightgbm') #' dtrain <- lgb.Dataset(train$data, label=train$label)
#' test <- agaricus.test #' data(agaricus.test, package='lightgbm')
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label=test$label) #' test <- agaricus.test
#' params <- list(objective="regression", metric="l2") #' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label=test$label)
#' valids <- list(test=dtest) #' params <- list(objective="regression", metric="l2")
#' model <- lgb.train(params, dtrain, 100, valids, min_data=1, learning_rate=1, early_stopping_rounds=10) #' valids <- list(test=dtest)
#' lgb.save(model, "model.txt") #' model <- lgb.train(params, dtrain, 100, valids, min_data=1, learning_rate=1, early_stopping_rounds=10)
#' @rdname lgb.save #' lgb.save(model, "model.txt")
#' }
#' @rdname lgb.save
#' @export #' @export
lgb.save <- function(booster, filename, num_iteration=NULL){ lgb.save <- function(booster, filename, num_iteration = NULL){
if(!lgb.is.Booster(booster)){ if (!lgb.is.Booster(booster)) { stop("lgb.save: booster should be an ", sQuote("lgb.Booster")) }
stop("lgb.save: should input lgb.Booster object") if (!is.character(filename)) { stop("lgb.save: filename should be a character") }
}
if(!is.character(filename)){
stop("lgb.save: filename should be character")
}
booster$save_model(filename, num_iteration) booster$save_model(filename, num_iteration)
} }
#' Dump LightGBM model to json #' Dump LightGBM model to json
#' #'
#' Dump LightGBM model to json #' Dump LightGBM model to json
#' #'
#' @param booster Object of class \code{lgb.Booster} #' @param booster Object of class \code{lgb.Booster}
#' @param num_iteration number of iteration want to predict with, NULL or <= 0 means use best iteration #' @param num_iteration number of iteration want to predict with, NULL or <= 0 means use best iteration
#' #'
#' @return json format of model #' @return json format of model
#' @examples #' @examples
#' library(lightgbm) #' \dontrun{
#' data(agaricus.train, package='lightgbm') #' library(lightgbm)
#' train <- agaricus.train #' data(agaricus.train, package='lightgbm')
#' dtrain <- lgb.Dataset(train$data, label=train$label) #' train <- agaricus.train
#' data(agaricus.test, package='lightgbm') #' dtrain <- lgb.Dataset(train$data, label=train$label)
#' test <- agaricus.test #' data(agaricus.test, package='lightgbm')
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label=test$label) #' test <- agaricus.test
#' params <- list(objective="regression", metric="l2") #' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label=test$label)
#' valids <- list(test=dtest) #' params <- list(objective="regression", metric="l2")
#' model <- lgb.train(params, dtrain, 100, valids, min_data=1, learning_rate=1, early_stopping_rounds=10) #' valids <- list(test=dtest)
#' json_model <- lgb.dump(model) #' model <- lgb.train(params, dtrain, 100, valids, min_data=1, learning_rate=1, early_stopping_rounds=10)
#' @rdname lgb.dump #' json_model <- lgb.dump(model)
#' }
#' @rdname lgb.dump
#' @export #' @export
lgb.dump <- function(booster, num_iteration=NULL){ lgb.dump <- function(booster, num_iteration = NULL){
if(!lgb.is.Booster(booster)){ if (!lgb.is.Booster(booster)) { stop("lgb.save: booster should be an ", sQuote("lgb.Booster")) }
stop("lgb.dump: should input lgb.Booster object")
}
booster$dump_model(num_iteration) booster$dump_model(num_iteration)
} }
#' Get record evaluation result from booster #' Get record evaluation result from booster
#' #'
#' Get record evaluation result from booster #' Get record evaluation result from booster
#' @param booster Object of class \code{lgb.Booster} #' @param booster Object of class \code{lgb.Booster}
#' @param data_name name of dataset #' @param data_name name of dataset
...@@ -469,32 +434,30 @@ lgb.dump <- function(booster, num_iteration=NULL){ ...@@ -469,32 +434,30 @@ lgb.dump <- function(booster, num_iteration=NULL){
#' @param iters iterations, NULL will return all #' @param iters iterations, NULL will return all
#' @param is_err TRUE will return evaluation error instead #' @param is_err TRUE will return evaluation error instead
#' @return vector of evaluation result #' @return vector of evaluation result
#'
#' @rdname lgb.get.eval.result #' @rdname lgb.get.eval.result
#' @export #' @export
lgb.get.eval.result <- function(booster, data_name, eval_name, iters=NULL, is_err=FALSE){ lgb.get.eval.result <- function(booster, data_name, eval_name, iters = NULL, is_err = FALSE) {
if(!lgb.is.Booster(booster)){ if (!lgb.is.Booster(booster)) {
stop("lgb.get.eval.result: only can use booster to get eval result") stop("lgb.get.eval.result: Can only use ", sQuote("lgb.Booster"), " to get eval result")
} }
if(!is.character(data_name) | !is.character(eval_name)){ if (!is.character(data_name) || !is.character(eval_name)) {
stop("lgb.get.eval.result: data_name and eval_name should be character") stop("lgb.get.eval.result: data_name and eval_name should be characters")
} }
if(is.null(booster$record_evals[[data_name]])){ if (is.null(booster$record_evals[[data_name]])) {
stop("lgb.get.eval.result: wrong data name") stop("lgb.get.eval.result: wrong data name")
} }
if(is.null(booster$record_evals[[data_name]][[eval_name]])){ if (is.null(booster$record_evals[[data_name]][[eval_name]])) {
stop("lgb.get.eval.result: wrong eval name") stop("lgb.get.eval.result: wrong eval name")
} }
result <- booster$record_evals[[data_name]][[eval_name]]$eval result <- booster$record_evals[[data_name]][[eval_name]]$eval
if(is_err){ if (is_err) {
result <- booster$record_evals[[data_name]][[eval_name]]$eval_err result <- booster$record_evals[[data_name]][[eval_name]]$eval_err
} }
if(is.null(iters)){ if (is.null(iters)) {
return(as.numeric(result)) return(as.numeric(result))
} }
iters <- as.integer(iters) iters <- as.integer(iters)
delta <- booster$record_evals$start_iter - 1 delta <- booster$record_evals$start_iter - 1
iters <- iters - delta iters <- iters - delta
return(as.numeric(result[iters])) as.numeric(result[iters])
} }
Dataset <- R6Class( Dataset <- R6Class(
"lgb.Dataset", "lgb.Dataset",
cloneable=FALSE, cloneable = FALSE,
public = list( public = list(
finalize = function() { finalize = function() {
if (!lgb.is.null.handle(private$handle)) { if (!lgb.is.null.handle(private$handle)) {
print("free dataset handle") cat("free dataset handle\n")
lgb.call("LGBM_DatasetFree_R", ret = NULL, private$handle) lgb.call("LGBM_DatasetFree_R", ret = NULL, private$handle)
private$handle <- NULL private$handle <- NULL
} }
}, },
initialize = function(data, initialize = function(data,
params = list(), params = list(),
reference = NULL, reference = NULL,
colnames = NULL, colnames = NULL,
categorical_feature = NULL, categorical_feature = NULL,
predictor = NULL, predictor = NULL,
free_raw_data = TRUE, free_raw_data = TRUE,
used_indices = NULL, used_indices = NULL,
info = list(), info = list(),
...) { ...) {
addiction_params <- list(...) additional_params <- list(...)
for (key in names(addiction_params)) { INFO_KEYS <- c('label', 'weight', 'init_score', 'group')
if (key %in% c('label', 'weight', 'init_score', 'group')) { for (key in names(additional_params)) {
info[[key]] <- addiction_params[[key]] if (key %in% INFO_KEYS) {
info[[key]] <- additional_params[[key]]
} else { } else {
params[[key]] <- addiction_params[[key]] params[[key]] <- additional_params[[key]]
} }
} }
if (!is.null(reference)) { if (!is.null(reference)) {
if (!lgb.check.r6.class(reference, "lgb.Dataset")) { if (!lgb.check.r6.class(reference, "lgb.Dataset")) {
stop("lgb.Dataset: Only can use lgb.Dataset as reference") stop("lgb.Dataset: Can only use ", sQuote("lgb.Dataset"), " as reference")
} }
} }
if (!is.null(predictor)) { if (!is.null(predictor)) {
if (!lgb.check.r6.class(predictor, "lgb.Predictor")) { if (!lgb.check.r6.class(predictor, "lgb.Predictor")) {
stop("lgb.Dataset: Only can use lgb.Predictor as predictor") stop("lgb.Dataset: Only can use ", sQuote("lgb.Predictor"), " as predictor")
} }
} }
private$raw_data <- data private$raw_data <- data
private$params <- params private$params <- params
private$reference <- reference private$reference <- reference
private$colnames <- colnames private$colnames <- colnames
private$categorical_feature <- categorical_feature private$categorical_feature <- categorical_feature
private$predictor <- predictor private$predictor <- predictor
private$free_raw_data <- free_raw_data private$free_raw_data <- free_raw_data
private$used_indices <- used_indices private$used_indices <- used_indices
private$info <- info private$info <- info
}, },
create_valid = function(data, info = list(), ...) { create_valid = function(data, info = list(), ...) {
ret <- Dataset$new( ret <- Dataset$new(
...@@ -61,7 +62,7 @@ Dataset <- R6Class( ...@@ -61,7 +62,7 @@ Dataset <- R6Class(
info, info,
... ...
) )
return(ret) ret
}, },
construct = function() { construct = function() {
if (!lgb.is.null.handle(private$handle)) { if (!lgb.is.null.handle(private$handle)) {
...@@ -69,29 +70,28 @@ Dataset <- R6Class( ...@@ -69,29 +70,28 @@ Dataset <- R6Class(
} }
# Get feature names # Get feature names
cnames <- NULL cnames <- NULL
if (is.matrix(private$raw_data) | if (is.matrix(private$raw_data) || is(private$raw_data, "dgCMatrix")) {
class(private$raw_data) == "dgCMatrix") {
cnames <- colnames(private$raw_data) cnames <- colnames(private$raw_data)
} }
# set feature names if not exist # set feature names if not exist
if (is.null(private$colnames) & !is.null(cnames)) { if (is.null(private$colnames) && !is.null(cnames)) {
private$colnames <- as.character(cnames) private$colnames <- as.character(cnames)
} }
# Get categorical feature index # Get categorical feature index
if (!is.null(private$categorical_feature)) { if (!is.null(private$categorical_feature)) {
fname_dict <- list() fname_dict <- list()
if (!is.null(private$colnames)) { if (!is.null(private$colnames)) {
fname_dict <- fname_dict <- `names<-`(
as.list(setNames(0:(length( list((seq_along(private$colnames) - 1)),
private$colnames private$colnames
) - 1), private$colnames)) )
} }
cate_indices <- list() cate_indices <- list()
for (key in private$categorical_feature) { for (key in private$categorical_feature) {
if (is.character(key)) { if (is.character(key)) {
idx <- fname_dict[[key]] idx <- fname_dict[[key]]
if (is.null(idx)) { if (is.null(idx)) {
stop(paste("lgb.self.get.handle: cannot find feature name ", key)) stop("lgb.self.get.handle: cannot find feature name ", sQuote(key))
} }
cate_indices <- c(cate_indices, idx) cate_indices <- c(cate_indices, idx)
} else { } else {
...@@ -104,10 +104,10 @@ Dataset <- R6Class( ...@@ -104,10 +104,10 @@ Dataset <- R6Class(
} }
# Check has header or not # Check has header or not
has_header <- FALSE has_header <- FALSE
if (!is.null(private$params$has_header) | if (!is.null(private$params$has_header) ||
!is.null(private$params$header)) { !is.null(private$params$header)) {
if (tolower(as.character(private$params$has_header)) == "true" if (tolower(as.character(private$params$has_header)) == "true"
| ||
tolower(as.character(private$params$header)) == "true") { tolower(as.character(private$params$header)) == "true") {
has_header <- TRUE has_header <- TRUE
} }
...@@ -122,9 +122,8 @@ Dataset <- R6Class( ...@@ -122,9 +122,8 @@ Dataset <- R6Class(
handle <- lgb.new.handle() handle <- lgb.new.handle()
# not subset # not subset
if (is.null(private$used_indices)) { if (is.null(private$used_indices)) {
if (typeof(private$raw_data) == "character") { if (is.character(private$raw_data)) {
handle <- handle <- lgb.call(
lgb.call(
"LGBM_DatasetCreateFromFile_R", "LGBM_DatasetCreateFromFile_R",
ret = handle, ret = handle,
lgb.c_str(private$raw_data), lgb.c_str(private$raw_data),
...@@ -132,8 +131,7 @@ Dataset <- R6Class( ...@@ -132,8 +131,7 @@ Dataset <- R6Class(
ref_handle ref_handle
) )
} else if (is.matrix(private$raw_data)) { } else if (is.matrix(private$raw_data)) {
handle <- handle <- lgb.call(
lgb.call(
"LGBM_DatasetCreateFromMat_R", "LGBM_DatasetCreateFromMat_R",
ret = handle, ret = handle,
private$raw_data, private$raw_data,
...@@ -142,7 +140,7 @@ Dataset <- R6Class( ...@@ -142,7 +140,7 @@ Dataset <- R6Class(
params_str, params_str,
ref_handle ref_handle
) )
} else if (class(private$raw_data) == "dgCMatrix") { } else if (is(private$raw_data, "dgCMatrix")) {
handle <- lgb.call( handle <- lgb.call(
"LGBM_DatasetCreateFromCSC_R", "LGBM_DatasetCreateFromCSC_R",
ret = handle, ret = handle,
...@@ -156,18 +154,16 @@ Dataset <- R6Class( ...@@ -156,18 +154,16 @@ Dataset <- R6Class(
ref_handle ref_handle
) )
} else { } else {
stop(paste( stop(
"lgb.Dataset.construct: does not support to construct from ", "lgb.Dataset.construct: does not support constructing from ", sQuote(class(private$raw_data))
typeof(private$raw_data) )
))
} }
} else { } else {
# construct subset # construct subset
if (is.null(private$reference)) { if (is.null(private$reference)) {
stop("lgb.Dataset.construct: reference cannot be NULL if construct subset") stop("lgb.Dataset.construct: reference cannot be NULL for constructing data subset")
} }
handle <- handle <- lgb.call(
lgb.call(
"LGBM_DatasetGetSubset_R", "LGBM_DatasetGetSubset_R",
ret = handle, ret = handle,
ref_handle, ref_handle,
...@@ -179,27 +175,20 @@ Dataset <- R6Class( ...@@ -179,27 +175,20 @@ Dataset <- R6Class(
class(handle) <- "lgb.Dataset.handle" class(handle) <- "lgb.Dataset.handle"
private$handle <- handle private$handle <- handle
# set feature names # set feature names
if (!is.null(private$colnames)) { if (!is.null(private$colnames)) { self$set_colnames(private$colnames) }
self$set_colnames(private$colnames)
}
# load init score # load init score
if (!is.null(private$predictor) & if (!is.null(private$predictor) &&
is.null(private$used_indices)) { is.null(private$used_indices)) {
init_score <- init_score <- private$predictor$predict(private$raw_data, rawscore = TRUE, reshape = TRUE)
private$predictor$predict(private$raw_data, # do not need to transpose, for is col_marjor
rawscore = TRUE,
reshape = TRUE)
# not need to transpose, for is col_marjor
init_score <- as.vector(init_score) init_score <- as.vector(init_score)
private$info$init_score <- init_score private$info$init_score <- init_score
} }
if (private$free_raw_data & !is.character(private$raw_data)) { if (isTRUE(private$free_raw_data)) { private$raw_data <- NULL }
private$raw_data <- NULL
}
if (length(private$info) > 0) { if (length(private$info) > 0) {
# set infos # set infos
for (i in 1:length(private$info)) { for (i in seq_along(private$info)) {
p <- private$info[i] p <- private$info[i]
self$setinfo(names(p), p[[1]]) self$setinfo(names(p), p[[1]])
} }
...@@ -207,45 +196,42 @@ Dataset <- R6Class( ...@@ -207,45 +196,42 @@ Dataset <- R6Class(
if (is.null(self$getinfo("label"))) { if (is.null(self$getinfo("label"))) {
stop("lgb.Dataset.construct: label should be set") stop("lgb.Dataset.construct: label should be set")
} }
return(self) self
}, },
dim = function() { dim = function() {
if (!lgb.is.null.handle(private$handle)) { if (!lgb.is.null.handle(private$handle)) {
num_row <- as.integer(0) num_row <- as.integer(0)
num_col <- as.integer(0) num_col <- as.integer(0)
return(c( c(
lgb.call("LGBM_DatasetGetNumData_R", ret = num_row, private$handle), lgb.call("LGBM_DatasetGetNumData_R", ret = num_row, private$handle),
lgb.call("LGBM_DatasetGetNumFeature_R", ret = num_col, private$handle) lgb.call("LGBM_DatasetGetNumFeature_R", ret = num_col, private$handle)
)) )
} else if (is.matrix(private$raw_data) | } else if (is.matrix(private$raw_data) || is(private$raw_data, "dgCMatrix")) {
class(private$raw_data) == "dgCMatrix") { dim(private$raw_data)
return(dim(private$raw_data))
} else { } else {
stop( stop(
"dim: cannot get Dimensions before dataset constructed, please call lgb.Dataset.construct explicit" "dim: cannot get dimensions before dataset has been constructed, please call lgb.Dataset.construct explicitly"
) )
} }
}, },
get_colnames = function() { get_colnames = function() {
if (!lgb.is.null.handle(private$handle)) { if (!lgb.is.null.handle(private$handle)) {
cnames <- lgb.call.return.str("LGBM_DatasetGetFeatureNames_R", cnames <- lgb.call.return.str("LGBM_DatasetGetFeatureNames_R", private$handle)
private$handle) private$colnames <- as.character(base::strsplit(cnames, "\t")[[1]])
private$colnames <- as.character(strsplit(cnames, "\t")[[1]]) private$colnames
return(private$colnames) } else if (is.matrix(private$raw_data) || is(private$raw_data, "dgCMatrix")) {
} else if (is.matrix(private$raw_data) | colnames(private$raw_data)
class(private$raw_data) == "dgCMatrix") {
return(colnames(private$raw_data))
} else { } else {
stop( stop(
"colnames: cannot get colnames before dataset constructed, please call lgb.Dataset.construct explicit" "dim: cannot get dimensions before dataset has been constructed, please call lgb.Dataset.construct explicitly"
) )
} }
}, },
set_colnames = function(colnames) { set_colnames = function(colnames) {
if(is.null(colnames)) return(self) if (is.null(colnames)) { return(self) }
colnames <- as.character(colnames) colnames <- as.character(colnames)
if(length(colnames) == 0) return(self) if (length(colnames) == 0) { return(self) }
private$colnames <- colnames private$colnames <- colnames
if (!lgb.is.null.handle(private$handle)) { if (!lgb.is.null.handle(private$handle)) {
merged_name <- paste0(as.list(private$colnames), collapse = "\t") merged_name <- paste0(as.list(private$colnames), collapse = "\t")
...@@ -254,58 +240,47 @@ Dataset <- R6Class( ...@@ -254,58 +240,47 @@ Dataset <- R6Class(
private$handle, private$handle,
lgb.c_str(merged_name)) lgb.c_str(merged_name))
} }
return(self) self
}, },
getinfo = function(name) { getinfo = function(name) {
if (typeof(name) != "character" || INFONAMES <- c("label", "weight", "init_score", "group")
length(name) != 1 || if (!is.character(name) ||
!name %in% c('label', 'weight', 'init_score', 'group')) { length(name) != 1 ||
!name %in% INFONAMES) {
stop( stop(
"getinfo: name must one of the following\n", "getinfo: name must one of the following: ", paste0(sQuote(INFONAMES), collapse = ", ")
" 'label', 'weight', 'init_score', 'group'"
) )
} }
if (is.null(private$info[[name]]) & if (is.null(private$info[[name]]) && !lgb.is.null.handle(private$handle)) {
!lgb.is.null.handle(private$handle)) {
info_len <- as.integer(0) info_len <- as.integer(0)
info_len <- info_len <- lgb.call("LGBM_DatasetGetFieldSize_R",
lgb.call("LGBM_DatasetGetFieldSize_R", ret = info_len,
ret = info_len, private$handle,
private$handle, lgb.c_str(name))
lgb.c_str(name))
if (info_len > 0) { if (info_len > 0) {
ret <- NULL ret <- NULL
if (name == "group") { ret <- if (name == "group") { integer(info_len) } else { rep(0.0, info_len) }
ret <- integer(info_len) ret <- lgb.call("LGBM_DatasetGetField_R",
} else { ret = ret,
ret <- rep(0.0, info_len) private$handle,
} lgb.c_str(name))
ret <-
lgb.call("LGBM_DatasetGetField_R",
ret = ret,
private$handle,
lgb.c_str(name))
private$info[[name]] <- ret private$info[[name]] <- ret
} }
} }
return(private$info[[name]]) private$info[[name]]
}, },
setinfo = function(name, info) { setinfo = function(name, info) {
if (typeof(name) != "character" || INFONAMES <- c("label", "weight", "init_score", "group")
length(name) != 1 || if (!is.character(name) ||
!name %in% c('label', 'weight', 'init_score', 'group')) { length(name) != 1 ||
!name %in% INFONAMES) {
stop( stop(
"setinfo: name must one of the following\n", "setinfo: name must one of the following: ", paste0(sQuote(INFONAMES), collapse = ", ")
" 'label', 'weight', 'init_score', 'group'"
) )
} }
if (name == "group") { info <- if (name == "group") { as.integer(info) } else { as.numeric(info) }
info <- as.integer(info)
} else {
info <- as.numeric(info)
}
private$info[[name]] <- info private$info[[name]] <- info
if (!lgb.is.null.handle(private$handle) & !is.null(info)) { if (!lgb.is.null.handle(private$handle) && !is.null(info)) {
if (length(info) > 0) { if (length(info) > 0) {
lgb.call( lgb.call(
"LGBM_DatasetSetField_R", "LGBM_DatasetSetField_R",
...@@ -317,10 +292,10 @@ Dataset <- R6Class( ...@@ -317,10 +292,10 @@ Dataset <- R6Class(
) )
} }
} }
return(self) self
}, },
slice = function(idxset, ...) { slice = function(idxset, ...) {
ret <- Dataset$new( Dataset$new(
NULL, NULL,
private$params, private$params,
self, self,
...@@ -332,46 +307,42 @@ Dataset <- R6Class( ...@@ -332,46 +307,42 @@ Dataset <- R6Class(
NULL, NULL,
... ...
) )
return(ret)
}, },
update_params = function(params){ update_params = function(params) {
private$params <- modifyList(private$params, params) private$params <- modifyList(private$params, params)
self
}, },
set_categorical_feature = function(categorical_feature) { set_categorical_feature = function(categorical_feature) {
if (identical(private$categorical_feature, categorical_feature)) { if (identical(private$categorical_feature, categorical_feature)) { return(self) }
return(self)
}
if (is.null(private$raw_data)) { if (is.null(private$raw_data)) {
stop( stop(
"set_categorical_feature: cannot set categorical feature after free raw data, "set_categorical_feature: cannot set categorical feature after freeing raw data,
please set free_raw_data=FALSE when construct lgb.Dataset" please set ", sQuote("free_raw_data = FALSE"), " when you construct lgb.Dataset"
) )
} }
private$categorical_feature <- categorical_feature private$categorical_feature <- categorical_feature
self$finalize() self$finalize()
return(self) self
}, },
set_reference = function(reference) { set_reference = function(reference) {
self$set_categorical_feature(reference$.__enclos_env__$private$categorical_feature) self$set_categorical_feature(reference$.__enclos_env__$private$categorical_feature)
self$set_colnames(reference$get_colnames()) self$set_colnames(reference$get_colnames())
private$set_predictor(reference$.__enclos_env__$private$predictor) private$set_predictor(reference$.__enclos_env__$private$predictor)
if (identical(private$reference, reference)) { if (identical(private$reference, reference)) { return(self) }
return(self)
}
if (is.null(private$raw_data)) { if (is.null(private$raw_data)) {
stop( stop(
"set_reference: cannot set reference after free raw data, "set_reference: cannot set reference after freeing raw data,
please set free_raw_data=FALSE when construct lgb.Dataset" please set ", sQuote("free_raw_data = FALSE"), " when you construct lgb.Dataset"
) )
} }
if (!is.null(reference)) { if (!is.null(reference)) {
if (!lgb.check.r6.class(reference, "lgb.Dataset")) { if (!lgb.check.r6.class(reference, "lgb.Dataset")) {
stop("set_reference: Only can use lgb.Dataset as reference") stop("set_reference: Can only use lgb.Dataset as a reference")
} }
} }
private$reference <- reference private$reference <- reference
self$finalize() self$finalize()
return(self) self
}, },
save_binary = function(fname) { save_binary = function(fname) {
self$construct() self$construct()
...@@ -379,50 +350,44 @@ Dataset <- R6Class( ...@@ -379,50 +350,44 @@ Dataset <- R6Class(
ret = NULL, ret = NULL,
private$handle, private$handle,
lgb.c_str(fname)) lgb.c_str(fname))
return(self) self
} }
), ),
private = list( private = list(
handle = NULL, handle = NULL,
raw_data = NULL, raw_data = NULL,
params = list(), params = list(),
reference = NULL, reference = NULL,
colnames = NULL, colnames = NULL,
categorical_feature = NULL, categorical_feature = NULL,
predictor = NULL, predictor = NULL,
free_raw_data = TRUE, free_raw_data = TRUE,
used_indices = NULL, used_indices = NULL,
info = NULL, info = NULL,
get_handle = function() { get_handle = function() {
if (lgb.is.null.handle(private$handle)) { if (lgb.is.null.handle(private$handle)) { self$construct() }
self$construct() private$handle
}
return(private$handle)
}, },
set_predictor = function(predictor) { set_predictor = function(predictor) {
if (identical(private$predictor, predictor)) { if (identical(private$predictor, predictor)) { return(self) }
return(self)
}
if (is.null(private$raw_data)) { if (is.null(private$raw_data)) {
stop( stop(
"set_predictor: cannot set predictor after free raw data, "set_predictor: cannot set predictor after free raw data,
please set free_raw_data=FALSE when construct lgb.Dataset" please set ", sQuote("free_raw_data = FALSE"), " when you construct lgb.Dataset"
) )
} }
if (!is.null(predictor)) { if (!is.null(predictor)) {
if (!lgb.check.r6.class(predictor, "lgb.Predictor")) { if (!lgb.check.r6.class(predictor, "lgb.Predictor")) {
stop("set_predictor: Only can use lgb.Predictor as predictor") stop("set_predictor: Can only use lgb.Predictor as predictor")
} }
} }
private$predictor <- predictor private$predictor <- predictor
self$finalize() self$finalize()
return(self) self
} }
) )
) )
#' Contruct lgb.Dataset object
#'
#' Contruct lgb.Dataset object #' Contruct lgb.Dataset object
#' #'
#' Contruct lgb.Dataset object from dense matrix, sparse matrix #' Contruct lgb.Dataset object from dense matrix, sparse matrix
...@@ -438,20 +403,22 @@ Dataset <- R6Class( ...@@ -438,20 +403,22 @@ Dataset <- R6Class(
#' @param ... other information to pass to \code{info} or parameters pass to \code{params} #' @param ... other information to pass to \code{info} or parameters pass to \code{params}
#' @return constructed dataset #' @return constructed dataset
#' @examples #' @examples
#' data(agaricus.train, package='lightgbm') #' \dontrun{
#' train <- agaricus.train #' data(agaricus.train, package='lightgbm')
#' dtrain <- lgb.Dataset(train$data, label=train$label) #' train <- agaricus.train
#' lgb.Dataset.save(dtrain, 'lgb.Dataset.data') #' dtrain <- lgb.Dataset(train$data, label=train$label)
#' dtrain <- lgb.Dataset('lgb.Dataset.data') #' lgb.Dataset.save(dtrain, 'lgb.Dataset.data')
#' lgb.Dataset.construct(dtrain) #' dtrain <- lgb.Dataset('lgb.Dataset.data')
#' lgb.Dataset.construct(dtrain)
#' }
#' @export #' @export
lgb.Dataset <- function(data, lgb.Dataset <- function(data,
params = list(), params = list(),
reference = NULL, reference = NULL,
colnames = NULL, colnames = NULL,
categorical_feature = NULL, categorical_feature = NULL,
free_raw_data = TRUE, free_raw_data = TRUE,
info = list(), info = list(),
...) { ...) {
Dataset$new( Dataset$new(
data, data,
...@@ -467,18 +434,10 @@ lgb.Dataset <- function(data, ...@@ -467,18 +434,10 @@ lgb.Dataset <- function(data,
) )
} }
# internal helper method
lgb.is.Dataset <- function(x){
if(lgb.check.r6.class(x, "lgb.Dataset")){
return(TRUE)
} else{
return(FALSE)
}
}
#' Contruct a validation data #' Contruct validation data
#' #'
#' Contruct a validation data according to training data #' Contruct validation data according to training data
#' #'
#' @param dataset \code{lgb.Dataset} object, training data #' @param dataset \code{lgb.Dataset} object, training data
#' @param data a \code{matrix} object, a \code{dgCMatrix} object or a character representing a filename #' @param data a \code{matrix} object, a \code{dgCMatrix} object or a character representing a filename
...@@ -486,42 +445,41 @@ lgb.is.Dataset <- function(x){ ...@@ -486,42 +445,41 @@ lgb.is.Dataset <- function(x){
#' @param ... other information to pass to \code{info}. #' @param ... other information to pass to \code{info}.
#' @return constructed dataset #' @return constructed dataset
#' @examples #' @examples
#' data(agaricus.train, package='lightgbm') #' \dontrun{
#' train <- agaricus.train #' data(agaricus.train, package='lightgbm')
#' dtrain <- lgb.Dataset(train$data, label=train$label) #' train <- agaricus.train
#' data(agaricus.test, package='lightgbm') #' dtrain <- lgb.Dataset(train$data, label=train$label)
#' test <- agaricus.test #' data(agaricus.test, package='lightgbm')
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label=test$label) #' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label=test$label)
#' }
#' @export #' @export
lgb.Dataset.create.valid <- lgb.Dataset.create.valid <- function(dataset, data, info = list(), ...) {
function(dataset, data, info = list(), ...) { if (!lgb.is.Dataset(dataset)) {
if(!lgb.is.Dataset(dataset)) { stop("lgb.Dataset.create.valid: input data should be an lgb.Dataset object")
stop("lgb.Dataset.create.valid: input data should be lgb.Dataset object")
}
return(dataset$create_valid(data, info, ...))
} }
dataset$create_valid(data, info, ...)
}
#' Construct Dataset explicit #' Construct Dataset explicitly
#'
#' Construct Dataset explicit
#' #'
#' @param dataset Object of class \code{lgb.Dataset} #' @param dataset Object of class \code{lgb.Dataset}
#' @examples #' @examples
#' data(agaricus.train, package='lightgbm') #' \dontrun{
#' train <- agaricus.train #' data(agaricus.train, package='lightgbm')
#' dtrain <- lgb.Dataset(train$data, label=train$label) #' train <- agaricus.train
#' lgb.Dataset.construct(dtrain) #' dtrain <- lgb.Dataset(train$data, label=train$label)
#' lgb.Dataset.construct(dtrain)
#' }
#' @export #' @export
lgb.Dataset.construct <- function(dataset) { lgb.Dataset.construct <- function(dataset) {
if(!lgb.is.Dataset(dataset)) { if (!lgb.is.Dataset(dataset)) {
stop("lgb.Dataset.construct: input data should be lgb.Dataset object") stop("lgb.Dataset.construct: input data should be an lgb.Dataset object")
} }
return(dataset$construct()) dataset$construct()
} }
#' Dimensions of lgb.Dataset #' Dimensions of an lgb.Dataset
#'
#' Dimensions of lgb.Dataset
#' #'
#' Returns a vector of numbers of rows and of columns in an \code{lgb.Dataset}. #' Returns a vector of numbers of rows and of columns in an \code{lgb.Dataset}.
#' @param x Object of class \code{lgb.Dataset} #' @param x Object of class \code{lgb.Dataset}
...@@ -533,29 +491,28 @@ lgb.Dataset.construct <- function(dataset) { ...@@ -533,29 +491,28 @@ lgb.Dataset.construct <- function(dataset) {
#' be directly used with an \code{lgb.Dataset} object. #' be directly used with an \code{lgb.Dataset} object.
#' #'
#' @examples #' @examples
#' data(agaricus.train, package='lightgbm') #' dontrun{
#' train <- agaricus.train #' data(agaricus.train, package='lightgbm')
#' dtrain <- lgb.Dataset(train$data, label=train$label) #' train <- agaricus.train
#' #' dtrain <- lgb.Dataset(train$data, label=train$label)
#' stopifnot(nrow(dtrain) == nrow(train$data)) #'
#' stopifnot(ncol(dtrain) == ncol(train$data)) #' stopifnot(nrow(dtrain) == nrow(train$data))
#' stopifnot(all(dim(dtrain) == dim(train$data))) #' stopifnot(ncol(dtrain) == ncol(train$data))
#' #' stopifnot(all(dim(dtrain) == dim(train$data)))
#' }
#' @rdname dim #' @rdname dim
#' @export #' @export
dim.lgb.Dataset <- function(x, ...) { dim.lgb.Dataset <- function(x, ...) {
if(!lgb.is.Dataset(x)) { if (!lgb.is.Dataset(x)) {
stop("dim.lgb.Dataset: input data should be lgb.Dataset object") stop("dim.lgb.Dataset: input data should be an lgb.Dataset object")
} }
return(x$dim()) x$dim()
} }
#' Handling of column names of \code{lgb.Dataset}
#'
#' Handling of column names of \code{lgb.Dataset} #' Handling of column names of \code{lgb.Dataset}
#' #'
#' Only column names are supported for \code{lgb.Dataset}, thus setting of #' Only column names are supported for \code{lgb.Dataset}, thus setting of
#' row names would have no effect and returnten row names would be NULL. #' row names would have no effect and returned row names would be NULL.
#' #'
#' @param x object of class \code{lgb.Dataset} #' @param x object of class \code{lgb.Dataset}
#' @param value a list of two elements: the first one is ignored #' @param value a list of two elements: the first one is ignored
...@@ -566,48 +523,47 @@ dim.lgb.Dataset <- function(x, ...) { ...@@ -566,48 +523,47 @@ dim.lgb.Dataset <- function(x, ...) {
#' Since row names are irrelevant, it is recommended to use \code{colnames} directly. #' Since row names are irrelevant, it is recommended to use \code{colnames} directly.
#' #'
#' @examples #' @examples
#' data(agaricus.train, package='lightgbm') #' dontrun{
#' train <- agaricus.train #' data(agaricus.train, package='lightgbm')
#' dtrain <- lgb.Dataset(train$data, label=train$label) #' train <- agaricus.train
#' lgb.Dataset.construct(dtrain) #' dtrain <- lgb.Dataset(train$data, label=train$label)
#' dimnames(dtrain) #' lgb.Dataset.construct(dtrain)
#' colnames(dtrain) #' dimnames(dtrain)
#' colnames(dtrain) <- make.names(1:ncol(train$data)) #' colnames(dtrain)
#' print(dtrain, verbose=TRUE) #' colnames(dtrain) <- make.names(1:ncol(train$data))
#' #' print(dtrain, verbose=TRUE)
#' }
#' @rdname dimnames.lgb.Dataset #' @rdname dimnames.lgb.Dataset
#' @export #' @export
dimnames.lgb.Dataset <- function(x) { dimnames.lgb.Dataset <- function(x) {
if(!lgb.is.Dataset(x)) { if (!lgb.is.Dataset(x)) {
stop("dimnames.lgb.Dataset: input data should be lgb.Dataset object") stop("dimnames.lgb.Dataset: input data should be an lgb.Dataset object")
} }
return(list(NULL, x$get_colnames())) list(NULL, x$get_colnames())
} }
#' @rdname dimnames.lgb.Dataset #' @rdname dimnames.lgb.Dataset
#' @export #' @export
`dimnames<-.lgb.Dataset` <- function(x, value) { `dimnames<-.lgb.Dataset` <- function(x, value) {
if (!is.list(value) || length(value) != 2L) if (!is.list(value) || length(value) != 2L)
stop("invalid 'dimnames' given: must be a list of two elements") stop("invalid ", sQuote("value"), " given: must be a list of two elements")
if (!is.null(value[[1L]])) if (!is.null(value[[1L]])) { stop("lgb.Dataset does not have rownames") }
stop("lgb.Dataset does not have rownames")
if (is.null(value[[2]])) { if (is.null(value[[2]])) {
x$set_colnames(NULL) x$set_colnames(NULL)
return(x) return(x)
} }
if (ncol(x) != length(value[[2]])) if (ncol(x) != length(value[[2]]))
stop("can't assign ", stop("can't assign ",
length(value[[2]]), sQuote(length(value[[2]])),
" colnames to a ", " colnames to an lgb.Dataset with ",
ncol(x), sQuote(ncol(x)), " columns")
" column lgb.Dataset")
x$set_colnames(value[[2]]) x$set_colnames(value[[2]])
return(x) x
} }
#' Slice an dataset #' Slice a dataset
#' #'
#' Get a new Dataset containing the specified rows of #' Get a new \code{lgb.Dataset} containing the specified rows of
#' orginal lgb.Dataset object #' orginal lgb.Dataset object
#' #'
#' @param dataset Object of class "lgb.Dataset" #' @param dataset Object of class "lgb.Dataset"
...@@ -616,29 +572,27 @@ dimnames.lgb.Dataset <- function(x) { ...@@ -616,29 +572,27 @@ dimnames.lgb.Dataset <- function(x) {
#' @return constructed sub dataset #' @return constructed sub dataset
#' #'
#' @examples #' @examples
#' data(agaricus.train, package='lightgbm') #' \dontrun{
#' train <- agaricus.train #' data(agaricus.train, package='lightgbm')
#' dtrain <- lgb.Dataset(train$data, label=train$label) #' train <- agaricus.train
#' #' dtrain <- lgb.Dataset(train$data, label=train$label)
#' dsub <- slice(dtrain, 1:42)
#' labels1 <- getinfo(dsub, 'label')
#' #'
#' dsub <- slice(dtrain, 1:42)
#' labels1 <- getinfo(dsub, 'label')
#' }
#' @export #' @export
slice <- function(dataset, ...) slice <- function(dataset, ...) { UseMethod("slice") }
UseMethod("slice")
#' @rdname slice #' @rdname slice
#' @export #' @export
slice.lgb.Dataset <- function(dataset, idxset, ...) { slice.lgb.Dataset <- function(dataset, idxset, ...) {
if(!lgb.is.Dataset(dataset)) { if (!lgb.is.Dataset(dataset)) {
stop("slice.lgb.Dataset: input data should be lgb.Dataset object") stop("slice.lgb.Dataset: input dataset should be an lgb.Dataset object")
} }
return(dataset$slice(idxset, ...)) dataset$slice(idxset, ...)
} }
#' Get information of an lgb.Dataset object
#'
#' Get information of an lgb.Dataset object #' Get information of an lgb.Dataset object
#' #'
#' @param dataset Object of class \code{lgb.Dataset} #' @param dataset Object of class \code{lgb.Dataset}
...@@ -657,30 +611,29 @@ slice.lgb.Dataset <- function(dataset, idxset, ...) { ...@@ -657,30 +611,29 @@ slice.lgb.Dataset <- function(dataset, idxset, ...) {
#' } #' }
#' #'
#' @examples #' @examples
#' data(agaricus.train, package='lightgbm') #' \dontrun{
#' train <- agaricus.train #' data(agaricus.train, package='lightgbm')
#' dtrain <- lgb.Dataset(train$data, label=train$label) #' train <- agaricus.train
#' lgb.Dataset.construct(dtrain) #' dtrain <- lgb.Dataset(train$data, label=train$label)
#' labels <- getinfo(dtrain, 'label') #' lgb.Dataset.construct(dtrain)
#' setinfo(dtrain, 'label', 1-labels) #' labels <- getinfo(dtrain, 'label')
#' #' setinfo(dtrain, 'label', 1-labels)
#' labels2 <- getinfo(dtrain, 'label') #'
#' stopifnot(all(labels2 == 1-labels)) #' labels2 <- getinfo(dtrain, 'label')
#' stopifnot(all(labels2 == 1-labels))
#' }
#' @export #' @export
getinfo <- function(dataset, ...) getinfo <- function(dataset, ...) { UseMethod("getinfo") }
UseMethod("getinfo")
#' @rdname getinfo #' @rdname getinfo
#' @export #' @export
getinfo.lgb.Dataset <- function(dataset, name, ...) { getinfo.lgb.Dataset <- function(dataset, name, ...) {
if(!lgb.is.Dataset(dataset)) { if (!lgb.is.Dataset(dataset)) {
stop("getinfo.lgb.Dataset: input data should be lgb.Dataset object") stop("getinfo.lgb.Dataset: input dataset should be an lgb.Dataset object")
} }
return(dataset$getinfo(name)) dataset$getinfo(name)
} }
#' Set information of an lgb.Dataset object
#'
#' Set information of an lgb.Dataset object #' Set information of an lgb.Dataset object
#' #'
#' @param dataset Object of class "lgb.Dataset" #' @param dataset Object of class "lgb.Dataset"
...@@ -700,96 +653,97 @@ getinfo.lgb.Dataset <- function(dataset, name, ...) { ...@@ -700,96 +653,97 @@ getinfo.lgb.Dataset <- function(dataset, name, ...) {
#' } #' }
#' #'
#' @examples #' @examples
#' data(agaricus.train, package='lightgbm') #' \dontrun{
#' train <- agaricus.train #' data(agaricus.train, package='lightgbm')
#' dtrain <- lgb.Dataset(train$data, label=train$label) #' train <- agaricus.train
#' lgb.Dataset.construct(dtrain) #' dtrain <- lgb.Dataset(train$data, label=train$label)
#' labels <- getinfo(dtrain, 'label') #' lgb.Dataset.construct(dtrain)
#' setinfo(dtrain, 'label', 1-labels) #' labels <- getinfo(dtrain, 'label')
#' labels2 <- getinfo(dtrain, 'label') #' setinfo(dtrain, 'label', 1-labels)
#' stopifnot(all.equal(labels2, 1-labels)) #' labels2 <- getinfo(dtrain, 'label')
#' stopifnot(all.equal(labels2, 1-labels))
#' }
#' @export #' @export
setinfo <- function(dataset, ...) setinfo <- function(dataset, ...) { UseMethod("setinfo") }
UseMethod("setinfo")
#' @rdname setinfo #' @rdname setinfo
#' @export #' @export
setinfo.lgb.Dataset <- function(dataset, name, info, ...) { setinfo.lgb.Dataset <- function(dataset, name, info, ...) {
if(!lgb.is.Dataset(dataset)) { if (!lgb.is.Dataset(dataset)) {
stop("setinfo.lgb.Dataset: input data should be lgb.Dataset object") stop("setinfo.lgb.Dataset: input dataset should be an lgb.Dataset object")
} }
return(dataset$setinfo(name, info)) dataset$setinfo(name, info)
} }
#' set categorical feature of \code{lgb.Dataset} #' Set categorical feature of \code{lgb.Dataset}
#'
#' set categorical feature of \code{lgb.Dataset}
#' #'
#' @param dataset object of class \code{lgb.Dataset} #' @param dataset object of class \code{lgb.Dataset}
#' @param categorical_feature categorical features #' @param categorical_feature categorical features
#' @return passed dataset #' @return passed dataset
#' @examples #' @examples
#' data(agaricus.train, package='lightgbm') #' \dontrun{
#' train <- agaricus.train #' data(agaricus.train, package='lightgbm')
#' dtrain <- lgb.Dataset(train$data, label=train$label) #' train <- agaricus.train
#' lgb.Dataset.save(dtrain, 'lgb.Dataset.data') #' dtrain <- lgb.Dataset(train$data, label=train$label)
#' dtrain <- lgb.Dataset('lgb.Dataset.data') #' lgb.Dataset.save(dtrain, 'lgb.Dataset.data')
#' lgb.Dataset.set.categorical(dtrain, 1:2) #' dtrain <- lgb.Dataset('lgb.Dataset.data')
#' lgb.Dataset.set.categorical(dtrain, 1:2)
#' }
#' @rdname lgb.Dataset.set.categorical #' @rdname lgb.Dataset.set.categorical
#' @export #' @export
lgb.Dataset.set.categorical <- lgb.Dataset.set.categorical <- function(dataset, categorical_feature) {
function(dataset, categorical_feature) { if (!lgb.is.Dataset(dataset)) {
if(!lgb.is.Dataset(dataset)) { stop("lgb.Dataset.set.categorical: input dataset should be an lgb.Dataset object")
stop("lgb.Dataset.set.categorical: input data should be lgb.Dataset object")
}
return(dataset$set_categorical_feature(categorical_feature))
} }
dataset$set_categorical_feature(categorical_feature)
}
#' set reference of \code{lgb.Dataset} #' Set reference of \code{lgb.Dataset}
#' #'
#' set reference of \code{lgb.Dataset}. #' If you want to use validation data, you should set reference to training data
#' If you want to use validation data, you should set its reference to training data
#' #'
#' @param dataset object of class \code{lgb.Dataset} #' @param dataset object of class \code{lgb.Dataset}
#' @param reference object of class \code{lgb.Dataset} #' @param reference object of class \code{lgb.Dataset}
#' @return passed dataset #' @return passed dataset
#' @examples #' @examples
#' data(agaricus.train, package='lightgbm') #' \dontrun{
#' train <- agaricus.train #' data(agaricus.train, package='lightgbm')
#' dtrain <- lgb.Dataset(train$data, label=train$label) #' train <- agaricus.train
#' data(agaricus.test, package='lightgbm') #' dtrain <- lgb.Dataset(train$data, label=train$label)
#' test <- agaricus.test #' data(agaricus.test, package='lightgbm')
#' dtest <- lgb.Dataset(test$data, test=train$label) #' test <- agaricus.test
#' lgb.Dataset.set.reference(dtest, dtrain) #' dtest <- lgb.Dataset(test$data, test=train$label)
#' lgb.Dataset.set.reference(dtest, dtrain)
#' }
#' @rdname lgb.Dataset.set.reference #' @rdname lgb.Dataset.set.reference
#' @export #' @export
lgb.Dataset.set.reference <- function(dataset, reference) { lgb.Dataset.set.reference <- function(dataset, reference) {
if(!lgb.is.Dataset(dataset)) { if (!lgb.is.Dataset(dataset)) {
stop("lgb.Dataset.set.reference: input data should be lgb.Dataset object") stop("lgb.Dataset.set.reference: input dataset should be an lgb.Dataset object")
} }
return(dataset$set_reference(reference)) dataset$set_reference(reference)
} }
#' save \code{lgb.Dataset} to binary file #' Save \code{lgb.Dataset} to a binary file
#' #'
#' save \code{lgb.Dataset} to binary file
#'
#' @param dataset object of class \code{lgb.Dataset} #' @param dataset object of class \code{lgb.Dataset}
#' @param fname object filename of output file #' @param fname object filename of output file
#' @return passed dataset #' @return passed dataset
#' @examples #' @examples
#' data(agaricus.train, package='lightgbm') #' \dontrun{
#' train <- agaricus.train #' data(agaricus.train, package='lightgbm')
#' dtrain <- lgb.Dataset(train$data, label=train$label) #' train <- agaricus.train
#' lgb.Dataset.save(dtrain, "data.bin") #' dtrain <- lgb.Dataset(train$data, label=train$label)
#' lgb.Dataset.save(dtrain, "data.bin")
#' }
#' @rdname lgb.Dataset.save #' @rdname lgb.Dataset.save
#' @export #' @export
lgb.Dataset.save <- function(dataset, fname) { lgb.Dataset.save <- function(dataset, fname) {
if(!lgb.is.Dataset(dataset)) { if (!lgb.is.Dataset(dataset)) {
stop("lgb.Dataset.set: input data should be lgb.Dataset object") stop("lgb.Dataset.set: input dataset should be an lgb.Dataset object")
} }
if(!is.character(fname)) { if (!is.character(fname)) {
stop("lgb.Dataset.set: filename should be character type") stop("lgb.Dataset.set: fname should be a character or a file connection")
} }
return(dataset$save_binary(fname)) dataset$save_binary(fname)
} }
Predictor <- R6Class( Predictor <- R6Class(
"lgb.Predictor", "lgb.Predictor",
cloneable=FALSE, cloneable = FALSE,
public = list( public = list(
finalize = function() { finalize = function() {
if(private$need_free_handle & !lgb.is.null.handle(private$handle)){ if (private$need_free_handle && !lgb.is.null.handle(private$handle)) {
print("free booster handle") cat("free booster handle\n")
lgb.call("LGBM_BoosterFree_R", ret=NULL, private$handle) lgb.call("LGBM_BoosterFree_R", ret = NULL, private$handle)
private$handle <- NULL private$handle <- NULL
} }
}, },
initialize = function(modelfile) { initialize = function(modelfile) {
handle <- lgb.new.handle() handle <- lgb.new.handle()
if(typeof(modelfile) == "character") { if (is.character(modelfile)) {
handle <- lgb.call("LGBM_BoosterCreateFromModelfile_R", ret=handle, lgb.c_str(modelfile)) handle <- lgb.call("LGBM_BoosterCreateFromModelfile_R", ret = handle, lgb.c_str(modelfile))
private$need_free_handle = TRUE private$need_free_handle <- TRUE
} else if (class(modelfile) == "lgb.Booster.handle") { } else if (is(modelfile, "lgb.Booster.handle")) {
handle <- modelfile handle <- modelfile
private$need_free_handle = FALSE private$need_free_handle <- FALSE
} else { } else {
stop("lgb.Predictor: modelfile must be either character filename, or lgb.Booster.handle") stop("lgb.Predictor: modelfile must be either a character filename or an lgb.Booster.handle")
} }
class(handle) <- "lgb.Booster.handle" class(handle) <- "lgb.Booster.handle"
private$handle <- handle private$handle <- handle
}, },
current_iter = function() { current_iter = function() {
cur_iter <- as.integer(0) cur_iter <- 0L
return(lgb.call("LGBM_BoosterGetCurrentIteration_R", ret=cur_iter, private$handle)) lgb.call("LGBM_BoosterGetCurrentIteration_R", ret = cur_iter, private$handle)
}, },
predict = function(data, predict = function(data, num_iteration = NULL, rawscore = FALSE,
num_iteration = NULL, rawscore = FALSE, predleaf = FALSE, header = FALSE, predleaf = FALSE, header = FALSE, reshape = FALSE) {
reshape = FALSE) {
if (is.null(num_iteration)) {
num_iteration <- -1
}
if (is.null(num_iteration)) { num_iteration <- -1 }
num_row <- 0 num_row <- 0
if (typeof(data) == "character") { if (is.character(data)) {
tmp_filename <- tempfile(pattern = "lightgbm_") tmp_filename <- tempfile(pattern = "lightgbm_")
lgb.call("LGBM_BoosterPredictForFile_R", ret=NULL, private$handle, data, as.integer(header), on.exit(unlink(tmp_filename), add = TRUE)
lgb.call("LGBM_BoosterPredictForFile_R", ret = NULL, private$handle, data,
as.integer(header),
as.integer(rawscore), as.integer(rawscore),
as.integer(predleaf), as.integer(predleaf),
as.integer(num_iteration), as.integer(num_iteration),
lgb.c_str(tmp_filename)) lgb.c_str(tmp_filename))
preds <- read.delim(tmp_filename, header=FALSE, seq="\t") preds <- read.delim(tmp_filename, header = FALSE, seq = "\t")
num_row <- nrow(preds) num_row <- nrow(preds)
preds <- as.vector(t(preds)) preds <- as.vector(t(preds))
# delete temp file
if(file.exists(tmp_filename)) { file.remove(tmp_filename) }
} else { } else {
num_row <- nrow(data) num_row <- nrow(data)
npred <- as.integer(0) npred <- as.integer(0)
npred <- lgb.call("LGBM_BoosterCalcNumPredict_R", ret=npred, npred <- lgb.call("LGBM_BoosterCalcNumPredict_R", ret = npred,
private$handle, private$handle,
as.integer(num_row), as.integer(num_row),
as.integer(rawscore), as.integer(rawscore),
...@@ -61,19 +56,19 @@ Predictor <- R6Class( ...@@ -61,19 +56,19 @@ Predictor <- R6Class(
# allocte space for prediction # allocte space for prediction
preds <- rep(0.0, npred) preds <- rep(0.0, npred)
if (is.matrix(data)) { if (is.matrix(data)) {
preds <- lgb.call("LGBM_BoosterPredictForMat_R", ret=preds, preds <- lgb.call("LGBM_BoosterPredictForMat_R", ret = preds,
private$handle, private$handle,
data, data,
as.integer(nrow(data)), as.integer(nrow(data)),
as.integer(ncol(data)), as.integer(ncol(data)),
as.integer(rawscore), as.integer(rawscore),
as.integer(predleaf), as.integer(predleaf),
as.integer(num_iteration)) as.integer(num_iteration))
} else if (class(data) == "dgCMatrix") { } else if (is(data, "dgCMatrix")) {
preds <- lgb.call("LGBM_BoosterPredictForCSC_R", ret=preds, preds <- lgb.call("LGBM_BoosterPredictForCSC_R", ret = preds,
private$handle, private$handle,
data@p, data@p,
data@i, data@i,
data@x, data@x,
length(data@p), length(data@p),
length(data@x), length(data@x),
...@@ -82,23 +77,17 @@ Predictor <- R6Class( ...@@ -82,23 +77,17 @@ Predictor <- R6Class(
as.integer(predleaf), as.integer(predleaf),
as.integer(num_iteration)) as.integer(num_iteration))
} else { } else {
stop(paste("predict: does not support to predict from ", stop("predict: cannot predict on data of class ", sQuote(class(data)))
typeof(data))) }
}
} }
if (length(preds) %% num_row != 0) { if (length(preds) %% num_row != 0) {
stop("predict: prediction length ", length(preds)," is not multiple of nrows(data) ", num_row) stop("predict: prediction length ", sQuote(length(preds))," is not a multiple of nrows(data): ", sQuote(num_row))
} }
npred_per_case <- length(preds) / num_row npred_per_case <- length(preds) / num_row
if (reshape && npred_per_case > 1) { if (reshape && npred_per_case > 1) { preds <- matrix(preds, ncol = npred_per_case) }
preds <- matrix(preds, ncol = npred_per_case) preds
}
return(preds)
} }
), ),
private = list( private = list( handle = NULL, need_free_handle = FALSE )
handle = NULL,
need_free_handle = FALSE
)
) )
CVBooster <- R6Class( CVBooster <- R6Class(
"lgb.CVBooster", "lgb.CVBooster",
cloneable=FALSE, cloneable = FALSE,
public = list( public = list(
best_iter = -1, best_iter = -1,
record_evals = list(), record_evals = list(),
boosters=list(), boosters = list(),
initialize=function(x){ initialize = function(x) {
self$boosters <- x self$boosters <- x
}, },
reset_parameter=function(new_paramas){ reset_parameter = function(new_params) {
for(x in boosters){ for (x in boosters) { x$reset_parameter(new_params) }
x$reset_parameter(new_paramas) self
}
return(self)
} }
) )
) )
#' Main CV logic for LightGBM #' Main CV logic for LightGBM
#' #'
#' Main CV logic for LightGBM #' Main CV logic for LightGBM
#' #'
#' @param params List of parameters #' @param params List of parameters
#' @param data a \code{lgb.Dataset} object, used for CV #' @param data a \code{lgb.Dataset} object, used for CV
#' @param nrounds number of CV rounds #' @param nrounds number of CV rounds
#' @param nfold the original dataset is randomly partitioned into \code{nfold} equal size subsamples. #' @param nfold the original dataset is randomly partitioned into \code{nfold} equal size subsamples.
#' @param label vector of response values. Should be provided only when data is an R-matrix. #' @param label vector of response values. Should be provided only when data is an R-matrix.
#' @param weight vector of response values. If not NULL, will set to dataset #' @param weight vector of response values. If not NULL, will set to dataset
#' @param obj objective function, can be character or custom objective function #' @param obj objective function, can be character or custom objective function
#' @param eval evaluation function, can be (list of) character or custom eval function #' @param eval evaluation function, can be (list of) character or custom eval function
#' @param verbose verbosity for output #' @param verbose verbosity for output
#' if verbose > 0 , also will record iteration message to booster$record_evals #' if verbose > 0 , also will record iteration message to booster$record_evals
#' @param eval_freq evalutaion output frequence #' @param eval_freq evalutaion output frequence
#' @param showsd \code{boolean}, whether to show standard deviation of cross validation #' @param showsd \code{boolean}, whether to show standard deviation of cross validation
#' @param stratified a \code{boolean} indicating whether sampling of folds should be stratified #' @param stratified a \code{boolean} indicating whether sampling of folds should be stratified
#' by the values of outcome labels. #' by the values of outcome labels.
#' @param folds \code{list} provides a possibility to use a list of pre-defined CV folds #' @param folds \code{list} provides a possibility to use a list of pre-defined CV folds
#' (each element must be a vector of test fold's indices). When folds are supplied, #' (each element must be a vector of test fold's indices). When folds are supplied,
#' the \code{nfold} and \code{stratified} parameters are ignored. #' the \code{nfold} and \code{stratified} parameters are ignored.
#' @param init_model path of model file of \code{lgb.Booster} object, will continue train from this model #' @param init_model path of model file of \code{lgb.Booster} object, will continue train from this model
#' @param colnames feature names, if not null, will use this to overwrite the names in dataset #' @param colnames feature names, if not null, will use this to overwrite the names in dataset
...@@ -53,180 +51,163 @@ CVBooster <- R6Class( ...@@ -53,180 +51,163 @@ CVBooster <- R6Class(
#' @param callbacks list of callback functions #' @param callbacks list of callback functions
#' List of callback functions that are applied at each iteration. #' List of callback functions that are applied at each iteration.
#' @param ... other parameters, see parameters.md for more informations #' @param ... other parameters, see parameters.md for more informations
#' @return a trained booster model \code{lgb.Booster}. #' @return a trained booster model \code{lgb.Booster}.
#' @examples #' @examples
#' library(lightgbm) #' \dontrun{
#' data(agaricus.train, package='lightgbm') #' library(lightgbm)
#' train <- agaricus.train #' data(agaricus.train, package='lightgbm')
#' dtrain <- lgb.Dataset(train$data, label=train$label) #' train <- agaricus.train
#' params <- list(objective="regression", metric="l2") #' dtrain <- lgb.Dataset(train$data, label=train$label)
#' model <- lgb.cv(params, dtrain, 10, nfold=5, min_data=1, learning_rate=1, early_stopping_rounds=10) #' params <- list(objective="regression", metric="l2")
#' #' model <- lgb.cv(params, dtrain, 10, nfold=5, min_data=1, learning_rate=1, early_stopping_rounds=10)
#' }
#' @rdname lgb.train #' @rdname lgb.train
#' @export #' @export
lgb.cv <- function(params=list(), data, nrounds=10, nfold=3, lgb.cv <- function(params=list(), data, nrounds = 10, nfold = 3,
label = NULL, weight = NULL, label = NULL, weight = NULL,
obj=NULL, eval=NULL, obj = NULL, eval = NULL,
verbose=1, eval_freq=1L, showsd = TRUE, verbose = 1, eval_freq = 1L, showsd = TRUE,
stratified = TRUE, folds = NULL, stratified = TRUE, folds = NULL,
init_model=NULL, init_model = NULL,
colnames=NULL, colnames= NULL,
categorical_feature=NULL, categorical_feature = NULL,
early_stopping_rounds=NULL, early_stopping_rounds = NULL,
callbacks=list(), ...) { callbacks = list(), ...) {
addiction_params <- list(...) addiction_params <- list(...)
params <- append(params, addiction_params) params <- append(params, addiction_params)
params$verbose <- verbose params$verbose <- verbose
params <- lgb.check.obj(params, obj) params <- lgb.check.obj(params, obj)
params <- lgb.check.eval(params, eval) params <- lgb.check.eval(params, eval)
fobj <- NULL fobj <- NULL
feval <- NULL feval <- NULL
if(typeof(params$objective) == "closure"){ if (is.function(params$objective)) {
fobj <- params$objective fobj <- params$objective
params$objective <- "NONE" params$objective <- "NONE"
}
if (typeof(eval) == "closure"){
feval <- eval
} }
if (is.function(eval)) { feval <- eval }
lgb.check.params(params) lgb.check.params(params)
predictor <- NULL predictor <- NULL
if(is.character(init_model)){ if (is.character(init_model)) {
predictor <- Predictor$new(init_model) predictor <- Predictor$new(init_model)
} else if(lgb.is.Booster(init_model)) { } else if (lgb.is.Booster(init_model)) {
predictor <- init_model$to_predictor() predictor <- init_model$to_predictor()
} }
begin_iteration <- 1 begin_iteration <- 1
if(!is.null(predictor)){ if (!is.null(predictor)) {
begin_iteration <- predictor$current_iter() + 1 begin_iteration <- predictor$current_iter() + 1
} }
end_iteration <- begin_iteration + nrounds - 1 end_iteration <- begin_iteration + nrounds - 1
# check dataset if (!lgb.is.Dataset(data)) {
if(!lgb.is.Dataset(data)){ if (is.null(label)) { stop("Labels must be provided for lgb.cv") }
if(is.null(label)){ data <- lgb.Dataset(data, label = label)
stop("Labels must be provided for lgb.cv")
}
data <- lgb.Dataset(data, label=label)
} }
if(!is.null(weight)) data$set_info("weight", weight) if (!is.null(weight)) { data$set_info("weight", weight) }
data$update_params(params) data$update_params(params)
data$.__enclos_env__$private$set_predictor(predictor) data$.__enclos_env__$private$set_predictor(predictor)
if(!is.null(colnames)){ if (!is.null(colnames)) { data$set_colnames(colnames) }
data$set_colnames(colnames)
}
data$set_categorical_feature(categorical_feature) data$set_categorical_feature(categorical_feature)
data$construct() data$construct()
# CV folds if (!is.null(folds)) {
if(!is.null(folds)) { if (!is.list(folds) || length(folds) < 2)
if(class(folds) != "list" || length(folds) < 2) stop(sQuote("folds"), " must be a list with 2 or more elements that are vectors of indices for each CV-fold")
stop("'folds' must be a list with 2 or more elements that are vectors of indices for each CV-fold")
nfold <- length(folds) nfold <- length(folds)
} else { } else {
if (nfold <= 1) if (nfold <= 1) { stop(sQuote("nfold"), " must be > 1") }
stop("'nfold' must be > 1")
folds <- generate.cv.folds(nfold, nrow(data), stratified, getinfo(data, 'label'), params) folds <- generate.cv.folds(nfold, nrow(data), stratified, getinfo(data, 'label'), params)
} }
# process callbacks if (eval_freq > 0) {
if(eval_freq > 0){
callbacks <- add.cb(callbacks, cb.print.evaluation(eval_freq)) callbacks <- add.cb(callbacks, cb.print.evaluation(eval_freq))
} }
if (verbose > 0) { if (verbose > 0) { callbacks <- add.cb(callbacks, cb.record.evaluation()) }
callbacks <- add.cb(callbacks, cb.record.evaluation())
}
# Early stopping callback
if (!is.null(early_stopping_rounds)) { if (!is.null(early_stopping_rounds)) {
if(early_stopping_rounds > 0){ if (early_stopping_rounds > 0) {
callbacks <- add.cb(callbacks, cb.early.stop(early_stopping_rounds, verbose=verbose)) callbacks <- add.cb(callbacks, cb.early.stop(early_stopping_rounds, verbose = verbose))
} }
} }
cb <- categorize.callbacks(callbacks) cb <- categorize.callbacks(callbacks)
# construct booster # construct booster
bst_folds <- lapply(seq_along(folds), function(k) {
bst_folds <- lapply(1:length(folds), function(k) { dtest <- slice(data, folds[[k]])
dtest <- slice(data, folds[[k]]) dtrain <- slice(data, unlist(folds[-k]))
dtrain <- slice(data, unlist(folds[-k]))
booster <- Booster$new(params, dtrain) booster <- Booster$new(params, dtrain)
booster$add_valid(dtest, "valid") booster$add_valid(dtest, "valid")
list(booster=booster) list(booster = booster)
}) })
cv_booster <- CVBooster$new(bst_folds) cv_booster <- CVBooster$new(bst_folds)
# callback env # callback env
env <- CB_ENV$new()
env <- CB_ENV$new() env$model <- cv_booster
env$model <- cv_booster
env$begin_iteration <- begin_iteration env$begin_iteration <- begin_iteration
env$end_iteration <- end_iteration env$end_iteration <- end_iteration
#start training #start training
for(i in begin_iteration:end_iteration){ for (i in seq(from = begin_iteration, to = end_iteration)) {
env$iteration <- i env$iteration <- i
env$eval_list <- list() env$eval_list <- list()
for (f in cb$pre_iter) f(env) for (f in cb$pre_iter) { f(env) }
# update one iter # update one iter
msg <- lapply(cv_booster$boosters, function(fd) { msg <- lapply(cv_booster$boosters, function(fd) {
fd$booster$update(fobj=fobj) fd$booster$update(fobj = fobj)
fd$booster$eval_valid(feval=feval) fd$booster$eval_valid(feval = feval)
}) })
merged_msg <- lgb.merge.cv.result(msg) merged_msg <- lgb.merge.cv.result(msg)
env$eval_list <- merged_msg$eval_list env$eval_list <- merged_msg$eval_list
if(showsd) env$eval_err_list <- merged_msg$eval_err_list if(showsd) { env$eval_err_list <- merged_msg$eval_err_list }
for (f in cb$post_iter) { f(env) }
for (f in cb$post_iter) f(env)
# met early stopping # met early stopping
if(env$met_early_stop) break if (env$met_early_stop) break
} }
return(cv_booster) cv_booster
} }
# Generates random (stratified if needed) CV folds # Generates random (stratified if needed) CV folds
generate.cv.folds <- function(nfold, nrows, stratified, label, params) { generate.cv.folds <- function(nfold, nrows, stratified, label, params) {
# cannot do it for rank # cannot do it for rank
if (exists('objective', where=params) && if (exists('objective', where = params) &&
is.character(params$objective) && is.character(params$objective) &&
params$objective == 'lambdarank') { params$objective == 'lambdarank') {
stop("\n\tAutomatic generation of CV-folds is not implemented for lambdarank!\n", stop("\n\tAutomatic generation of CV-folds is not implemented for lambdarank!\n",
"\tConsider providing pre-computed CV-folds through the 'folds=' parameter.\n") "\tConsider providing pre-computed CV-folds through the 'folds=' parameter.\n")
} }
# shuffle # shuffle
rnd_idx <- sample(1:nrows) rnd_idx <- sample(seq_len(nrows))
if (stratified && if (isTRUE(stratified) &&
length(label) == length(rnd_idx)) { length(label) == length(rnd_idx)) {
y <- label[rnd_idx] y <- label[rnd_idx]
y <- factor(y) y <- factor(y)
folds <- lgb.stratified.folds(y, nfold) folds <- lgb.stratified.folds(y, nfold)
} else { } else {
# make simple non-stratified folds # make simple non-stratified folds
kstep <- length(rnd_idx) %/% nfold kstep <- length(rnd_idx) %/% nfold
folds <- list() folds <- list()
for (i in 1:(nfold - 1)) { for (i in seq_len(nfold - 1)) {
folds[[i]] <- rnd_idx[1:kstep] folds[[i]] <- rnd_idx[seq_len(kstep)]
rnd_idx <- rnd_idx[-(1:kstep)] rnd_idx <- rnd_idx[-(seq_len(kstep))]
} }
folds[[nfold]] <- rnd_idx folds[[nfold]] <- rnd_idx
} }
return(folds) folds
} }
# Creates CV folds stratified by the values of y. # Creates CV folds stratified by the values of y.
# It was borrowed from caret::lgb.stratified.folds and simplified # It was borrowed from caret::lgb.stratified.folds and simplified
# by always returning an unnamed list of fold indices. # by always returning an unnamed list of fold indices.
lgb.stratified.folds <- function(y, k = 10) lgb.stratified.folds <- function(y, k = 10) {
{
if (is.numeric(y)) { if (is.numeric(y)) {
## Group the numeric data based on their magnitudes ## Group the numeric data based on their magnitudes
## and sample within those groups. ## and sample within those groups.
...@@ -236,14 +217,13 @@ lgb.stratified.folds <- function(y, k = 10) ...@@ -236,14 +217,13 @@ lgb.stratified.folds <- function(y, k = 10)
## groups. The number of groups will depend on the ## groups. The number of groups will depend on the
## ratio of the number of folds to the sample size. ## ratio of the number of folds to the sample size.
## At most, we will use quantiles. If the sample ## At most, we will use quantiles. If the sample
## is too small, we just do regular unstratified ## is too small, we just do regular unstratified CV
## CV
cuts <- floor(length(y) / k) cuts <- floor(length(y) / k)
if (cuts < 2) cuts <- 2 if (cuts < 2) { cuts <- 2 }
if (cuts > 5) cuts <- 5 if (cuts > 5) { cuts <- 5 }
y <- cut(y, y <- cut(y,
unique(stats::quantile(y, probs = seq(0, 1, length = cuts))), unique(stats::quantile(y, probs = seq(0, 1, length = cuts))),
include.lowest = TRUE) include.lowest = TRUE)
} }
if (k < length(y)) { if (k < length(y)) {
...@@ -256,48 +236,41 @@ lgb.stratified.folds <- function(y, k = 10) ...@@ -256,48 +236,41 @@ lgb.stratified.folds <- function(y, k = 10)
## For each class, balance the fold allocation as far ## For each class, balance the fold allocation as far
## as possible, then resample the remainder. ## as possible, then resample the remainder.
## The final assignment of folds is also randomized. ## The final assignment of folds is also randomized.
for (i in 1:length(numInClass)) { for (i in seq_along(numInClass)) {
## create a vector of integers from 1:k as many times as possible without ## create a vector of integers from 1:k as many times as possible without
## going over the number of samples in the class. Note that if the number ## going over the number of samples in the class. Note that if the number
## of samples in a class is less than k, nothing is producd here. ## of samples in a class is less than k, nothing is producd here.
seqVector <- rep(1:k, numInClass[i] %/% k) seqVector <- rep(seq_len(k), numInClass[i] %/% k)
## add enough random integers to get length(seqVector) == numInClass[i] ## add enough random integers to get length(seqVector) == numInClass[i]
if (numInClass[i] %% k > 0) seqVector <- c(seqVector, sample(1:k, numInClass[i] %% k)) if (numInClass[i] %% k > 0) {
seqVector <- c(seqVector, sample(seq_len(k), numInClass[i] %% k))
}
## shuffle the integers for fold assignment and assign to this classes's data ## shuffle the integers for fold assignment and assign to this classes's data
foldVector[which(y == dimnames(numInClass)$y[i])] <- sample(seqVector) foldVector[y == dimnames(numInClass)$y[i]] <- sample(seqVector)
} }
} else { } else {
foldVector <- seq(along = y) foldVector <- seq(along = y)
} }
out <- split(seq(along = y), foldVector) out <- split(seq(along = y), foldVector)
names(out) <- NULL `names<-`(out, NULL)
out
} }
lgb.merge.cv.result <- function(msg, showsd=TRUE){ lgb.merge.cv.result <- function(msg, showsd = TRUE){
if(length(msg) == 0){ if (length(msg) == 0) { stop("lgb.cv: size of cv result error") }
stop("lgb.cv: size of cv result error")
}
eval_len <- length(msg[[1]]) eval_len <- length(msg[[1]])
if(eval_len == 0){ if (eval_len == 0) { stop("lgb.cv: should provide at least one metric for CV") }
stop("lgb.cv: should provide at least metric for CV") eval_result <- lapply(seq_len(eval_len), function(j) {
} as.numeric(lapply(seq_along(msg), function(i) { msg[[i]][[j]]$value }))
eval_result <- lapply(1:eval_len, function(j) {
as.numeric(lapply(1:length(msg), function(i){
msg[[i]][[j]]$value
}))
}) })
ret_eval <- msg[[1]] ret_eval <- msg[[1]]
for(j in 1:eval_len){ for (j in seq_len(eval_len)) { ret_eval[[j]]$value <- mean(eval_result[[j]]) }
ret_eval[[j]]$value <- mean(eval_result[[j]])
}
ret_eval_err <- NULL ret_eval_err <- NULL
if(showsd){ if (showsd) {
for(j in 1:eval_len){ for (j in seq_len(eval_len)) {
ret_eval_err <- c( ret_eval_err, sqrt( mean(eval_result[[j]]^2) - mean(eval_result[[j]])^2 )) ret_eval_err <- c( ret_eval_err, sqrt( mean(eval_result[[j]]^2) - mean(eval_result[[j]])^2 ))
} }
ret_eval_err <- as.list(ret_eval_err) ret_eval_err <- as.list(ret_eval_err)
} }
return(list(eval_list=ret_eval, eval_err_list=ret_eval_err)) list(eval_list = ret_eval, eval_err_list = ret_eval_err)
} }
\ No newline at end of file
#' Main training logic for LightGBM #' Main training logic for LightGBM
#' #'
#' Main training logic for LightGBM
#'
#' @param params List of parameters #' @param params List of parameters
#' @param data a \code{lgb.Dataset} object, used for training #' @param data a \code{lgb.Dataset} object, used for training
#' @param nrounds number of training rounds #' @param nrounds number of training rounds
#' @param valids a list of \code{lgb.Dataset} object, used for validation #' @param valids a list of \code{lgb.Dataset} objects, used for validation
#' @param obj objective function, can be character or custom objective function #' @param obj objective function, can be character or custom objective function
#' @param eval evaluation function, can be (list of) character or custom eval function #' @param eval evaluation function, can be (a list of) character or custom eval function
#' @param verbose verbosity for output #' @param verbose verbosity for output
#' if verbose > 0 , also will record iteration message to booster$record_evals #' if \code{verbose > 0}, also will record iteration message to \code{booster$record_evals}
#' @param eval_freq evalutaion output frequence #' @param eval_freq evalutaion output frequency
#' @param init_model path of model file of \code{lgb.Booster} object, will continue train from this model #' @param init_model path of model file of \code{lgb.Booster} object, will continue training from this model
#' @param colnames feature names, if not null, will use this to overwrite the names in dataset #' @param colnames feature names, if not null, will use this to overwrite the names in dataset
#' @param categorical_feature list of str or int #' @param categorical_feature list of str or int
#' type int represents index, #' type int represents index,
...@@ -25,86 +23,86 @@ ...@@ -25,86 +23,86 @@
#' @param callbacks list of callback functions #' @param callbacks list of callback functions
#' List of callback functions that are applied at each iteration. #' List of callback functions that are applied at each iteration.
#' @param ... other parameters, see parameters.md for more informations #' @param ... other parameters, see parameters.md for more informations
#' @return a trained booster model \code{lgb.Booster}. #' @return a trained booster model \code{lgb.Booster}.
#' @examples #' @examples
#' library(lightgbm) #' \dontrun{
#' data(agaricus.train, package='lightgbm') #' library(lightgbm)
#' train <- agaricus.train #' data(agaricus.train, package='lightgbm')
#' dtrain <- lgb.Dataset(train$data, label=train$label) #' train <- agaricus.train
#' data(agaricus.test, package='lightgbm') #' dtrain <- lgb.Dataset(train$data, label=train$label)
#' test <- agaricus.test #' data(agaricus.test, package='lightgbm')
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label=test$label) #' test <- agaricus.test
#' params <- list(objective="regression", metric="l2") #' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label=test$label)
#' valids <- list(test=dtest) #' params <- list(objective="regression", metric="l2")
#' model <- lgb.train(params, dtrain, 100, valids, min_data=1, learning_rate=1, early_stopping_rounds=10) #' valids <- list(test=dtest)
#' #' model <- lgb.train(params, dtrain, 100, valids, min_data=1, learning_rate=1, early_stopping_rounds=10)
#' }
#' @rdname lgb.train #' @rdname lgb.train
#' @export #' @export
lgb.train <- function(params=list(), data, nrounds=10, lgb.train <- function(params = list(), data, nrounds = 10,
valids=list(), valids = list(),
obj=NULL, eval=NULL, obj = NULL,
verbose=1, eval_freq=1L, eval = NULL,
init_model=NULL, verbose = 1,
colnames=NULL, eval_freq = 1L,
categorical_feature=NULL, init_model = NULL,
early_stopping_rounds=NULL, colnames = NULL,
callbacks=list(), ...) { categorical_feature = NULL,
addiction_params <- list(...) early_stopping_rounds = NULL,
params <- append(params, addiction_params) callbacks = list(), ...) {
additional_params <- list(...)
params <- append(params, additional_params)
params$verbose <- verbose params$verbose <- verbose
params <- lgb.check.obj(params, obj) params <- lgb.check.obj(params, obj)
params <- lgb.check.eval(params, eval) params <- lgb.check.eval(params, eval)
fobj <- NULL fobj <- NULL
feval <- NULL feval <- NULL
if(typeof(params$objective) == "closure"){ if (is.function(params$objective)) {
fobj <- params$objective fobj <- params$objective
params$objective <- "NONE" params$objective <- "NONE"
}
if (typeof(eval) == "closure"){
feval <- eval
} }
if (is.function(eval)) { feval <- eval }
lgb.check.params(params) lgb.check.params(params)
predictor <- NULL predictor <- NULL
if(is.character(init_model)){ if (is.character(init_model)) {
predictor <- Predictor$new(init_model) predictor <- Predictor$new(init_model)
} else if(lgb.is.Booster(init_model)) { } else if (lgb.is.Booster(init_model)) {
predictor <- init_model$to_predictor() predictor <- init_model$to_predictor()
} }
begin_iteration <- 1 begin_iteration <- 1
if(!is.null(predictor)){ if (!is.null(predictor)) {
begin_iteration <- predictor$current_iter() + 1 begin_iteration <- predictor$current_iter() + 1
} }
end_iteration <- begin_iteration + nrounds - 1 end_iteration <- begin_iteration + nrounds - 1
# check dataset # check dataset
if(!lgb.is.Dataset(data)){ if (!lgb.is.Dataset(data)) {
stop("lgb.train: data only accepts lgb.Dataset object") stop("lgb.train: data only accepts lgb.Dataset object")
} }
if (length(valids) > 0) { if (length(valids) > 0) {
if (typeof(valids) != "list" || if (!is.list(valids) || !all(sapply(valids, lgb.is.Dataset))) {
!all(sapply(valids, lgb.is.Dataset))) stop("lgb.train: valids must be a list of lgb.Dataset elements")
stop("valids must be a list of lgb.Dataset elements") }
evnames <- names(valids) evnames <- names(valids)
if (is.null(evnames) || any(evnames == "")) if (is.null(evnames) || !all(nzchar(evnames))) {
stop("each element of the valids must have a name tag") stop("lgb.train: each element of the valids must have a name tag")
}
} }
data$update_params(params) data$update_params(params)
data$.__enclos_env__$private$set_predictor(predictor) data$.__enclos_env__$private$set_predictor(predictor)
if(!is.null(colnames)){ if (!is.null(colnames)) { data$set_colnames(colnames) }
data$set_colnames(colnames)
}
data$set_categorical_feature(categorical_feature) data$set_categorical_feature(categorical_feature)
data$construct() data$construct()
vaild_contain_train <- FALSE vaild_contain_train <- FALSE
train_data_name <- "train" train_data_name <- "train"
reduced_valid_sets <- list() reduced_valid_sets <- list()
if(length(valids) > 0){ if (length(valids) > 0) {
for (key in names(valids)) { for (key in names(valids)) {
valid_data <- valids[[key]] valid_data <- valids[[key]]
if(identical(data, valid_data)){ if (identical(data, valid_data)) {
vaild_contain_train <- TRUE vaild_contain_train <- TRUE
train_data_name <- key train_data_name <- key
next next
} }
valid_data$update_params(params) valid_data$update_params(params)
...@@ -113,7 +111,7 @@ lgb.train <- function(params=list(), data, nrounds=10, ...@@ -113,7 +111,7 @@ lgb.train <- function(params=list(), data, nrounds=10,
} }
} }
# process callbacks # process callbacks
if(eval_freq > 0){ if (eval_freq > 0) {
callbacks <- add.cb(callbacks, cb.print.evaluation(eval_freq)) callbacks <- add.cb(callbacks, cb.print.evaluation(eval_freq))
} }
...@@ -123,55 +121,47 @@ lgb.train <- function(params=list(), data, nrounds=10, ...@@ -123,55 +121,47 @@ lgb.train <- function(params=list(), data, nrounds=10,
# Early stopping callback # Early stopping callback
if (!is.null(early_stopping_rounds)) { if (!is.null(early_stopping_rounds)) {
if(early_stopping_rounds > 0){ if (early_stopping_rounds > 0) {
callbacks <- add.cb(callbacks, cb.early.stop(early_stopping_rounds, verbose=verbose)) callbacks <- add.cb(callbacks, cb.early.stop(early_stopping_rounds, verbose = verbose))
} }
} }
cb <- categorize.callbacks(callbacks) cb <- categorize.callbacks(callbacks)
# construct booster # construct booster
booster <- Booster$new(params=params, train_set=data) booster <- Booster$new(params = params, train_set = data)
if(vaild_contain_train){ if (vaild_contain_train) { booster$set_train_data_name(train_data_name) }
booster$set_train_data_name(train_data_name)
}
for (key in names(reduced_valid_sets)) { for (key in names(reduced_valid_sets)) {
booster$add_valid(reduced_valid_sets[[key]], key) booster$add_valid(reduced_valid_sets[[key]], key)
} }
# callback env # callback env
env <- CB_ENV$new()
env <- CB_ENV$new() env$model <- booster
env$model <- booster
env$begin_iteration <- begin_iteration env$begin_iteration <- begin_iteration
env$end_iteration <- end_iteration env$end_iteration <- end_iteration
#start training #start training
for(i in begin_iteration:end_iteration){ for (i in seq(from = begin_iteration, to = end_iteration)) {
env$iteration <- i env$iteration <- i
env$eval_list <- list() env$eval_list <- list()
for (f in cb$pre_iter) f(env) for (f in cb$pre_iter) { f(env) }
# update one iter # update one iter
booster$update(fobj=fobj) booster$update(fobj = fobj)
# collect eval result # collect eval result
eval_list <- list() eval_list <- list()
if(length(valids) > 0){ if (length(valids) > 0) {
if(vaild_contain_train){ if (vaild_contain_train) {
eval_list <- append(eval_list, booster$eval_train(feval=feval)) eval_list <- append(eval_list, booster$eval_train(feval = feval))
} }
eval_list <- append(eval_list, booster$eval_valid(feval=feval)) eval_list <- append(eval_list, booster$eval_valid(feval = feval))
} }
env$eval_list <- eval_list env$eval_list <- eval_list
for (f in cb$post_iter) { f(env) }
for (f in cb$post_iter) f(env)
# met early stopping # met early stopping
if(env$met_early_stop) break if (env$met_early_stop) break
} }
return(booster) booster
} }
# Simple interface for training an lightgbm model. #' Simple interface for training an lightgbm model.
# Its documentation is combined with lgb.train. #' Its documentation is combined with lgb.train.
# #'
#' @rdname lgb.train #' @rdname lgb.train
#' @export #' @export
lightgbm <- function(data, label = NULL, weight = NULL, lightgbm <- function(data, label = NULL, weight = NULL,
params = list(), nrounds=10, params = list(), nrounds = 10,
verbose = 1, eval_freq = 1L, verbose = 1, eval_freq = 1L,
early_stopping_rounds = NULL, early_stopping_rounds = NULL,
save_name = "lightgbm.model", save_name = "lightgbm.model",
init_model = NULL, callbacks = list(), ...) { init_model = NULL, callbacks = list(), ...) {
dtrain <- data dtrain <- data
if(!lgb.is.Dataset(dtrain)) { if (!lgb.is.Dataset(dtrain)) {
dtrain <- lgb.Dataset(data, label=label, weight=weight) dtrain <- lgb.Dataset(data, label = label, weight = weight)
} }
valids <- list() valids <- list()
if (verbose > 0) if (verbose > 0) { valids$train = dtrain }
valids$train = dtrain
bst <- lgb.train(params, dtrain, nrounds, valids, verbose = verbose, eval_freq=eval_freq, bst <- lgb.train(params, dtrain, nrounds, valids, verbose = verbose, eval_freq = eval_freq,
early_stopping_rounds = early_stopping_rounds, early_stopping_rounds = early_stopping_rounds,
init_model = init_model, callbacks = callbacks, ...) init_model = init_model, callbacks = callbacks, ...)
bst$save_model(save_name) bst$save_model(save_name)
return(bst) bst
} }
#' Training part from Mushroom Data Set #' Training part from Mushroom Data Set
#' #'
#' This data set is originally from the Mushroom data set, #' This data set is originally from the Mushroom data set,
#' UCI Machine Learning Repository. #' UCI Machine Learning Repository.
#' #'
#' This data set includes the following fields: #' This data set includes the following fields:
#' #'
#' \itemize{ #' \itemize{
#' \item \code{label} the label for each record #' \item \code{label} the label for each record
#' \item \code{data} a sparse Matrix of \code{dgCMatrix} class, with 126 columns. #' \item \code{data} a sparse Matrix of \code{dgCMatrix} class, with 126 columns.
...@@ -39,16 +38,16 @@ lightgbm <- function(data, label = NULL, weight = NULL, ...@@ -39,16 +38,16 @@ lightgbm <- function(data, label = NULL, weight = NULL,
#' #'
#' @references #' @references
#' https://archive.ics.uci.edu/ml/datasets/Mushroom #' https://archive.ics.uci.edu/ml/datasets/Mushroom
#' #'
#' Bache, K. & Lichman, M. (2013). UCI Machine Learning Repository #' Bache, K. & Lichman, M. (2013). UCI Machine Learning Repository
#' [http://archive.ics.uci.edu/ml]. Irvine, CA: University of California, #' [http://archive.ics.uci.edu/ml]. Irvine, CA: University of California,
#' School of Information and Computer Science. #' School of Information and Computer Science.
#' #'
#' @docType data #' @docType data
#' @keywords datasets #' @keywords datasets
#' @name agaricus.train #' @name agaricus.train
#' @usage data(agaricus.train) #' @usage data(agaricus.train)
#' @format A list containing a label vector, and a dgCMatrix object with 6513 #' @format A list containing a label vector, and a dgCMatrix object with 6513
#' rows and 127 variables #' rows and 127 variables
NULL NULL
...@@ -56,9 +55,9 @@ NULL ...@@ -56,9 +55,9 @@ NULL
#' #'
#' This data set is originally from the Mushroom data set, #' This data set is originally from the Mushroom data set,
#' UCI Machine Learning Repository. #' UCI Machine Learning Repository.
#' #'
#' This data set includes the following fields: #' This data set includes the following fields:
#' #'
#' \itemize{ #' \itemize{
#' \item \code{label} the label for each record #' \item \code{label} the label for each record
#' \item \code{data} a sparse Matrix of \code{dgCMatrix} class, with 126 columns. #' \item \code{data} a sparse Matrix of \code{dgCMatrix} class, with 126 columns.
...@@ -66,16 +65,16 @@ NULL ...@@ -66,16 +65,16 @@ NULL
#' #'
#' @references #' @references
#' https://archive.ics.uci.edu/ml/datasets/Mushroom #' https://archive.ics.uci.edu/ml/datasets/Mushroom
#' #'
#' Bache, K. & Lichman, M. (2013). UCI Machine Learning Repository #' Bache, K. & Lichman, M. (2013). UCI Machine Learning Repository
#' [http://archive.ics.uci.edu/ml]. Irvine, CA: University of California, #' [http://archive.ics.uci.edu/ml]. Irvine, CA: University of California,
#' School of Information and Computer Science. #' School of Information and Computer Science.
#' #'
#' @docType data #' @docType data
#' @keywords datasets #' @keywords datasets
#' @name agaricus.test #' @name agaricus.test
#' @usage data(agaricus.test) #' @usage data(agaricus.test)
#' @format A list containing a label vector, and a dgCMatrix object with 1611 #' @format A list containing a label vector, and a dgCMatrix object with 1611
#' rows and 126 variables #' rows and 126 variables
NULL NULL
...@@ -83,4 +82,4 @@ NULL ...@@ -83,4 +82,4 @@ NULL
#' @import methods #' @import methods
#' @importFrom R6 R6Class #' @importFrom R6 R6Class
#' @useDynLib lightgbm #' @useDynLib lightgbm
NULL NULL
\ No newline at end of file
lgb.new.handle <- function() { lgb.is.Booster <- function(x) { lgb.check.r6.class(x, "lgb.Booster") }
# use 64bit data to store address
return(0.0) lgb.is.Dataset <- function(x) { lgb.check.r6.class(x, "lgb.Dataset") }
}
lgb.is.null.handle <- function(x) { # use 64bit data to store address
if (is.null(x)) { lgb.new.handle <- function() { 0.0 }
return(TRUE)
} lgb.is.null.handle <- function(x) { is.null(x) || x == 0 }
if (x == 0) {
return(TRUE)
}
return(FALSE)
}
lgb.encode.char <- function(arr, len) { lgb.encode.char <- function(arr, len) {
if (typeof(arr) != "raw") { if (!is.raw(arr)) {
stop("lgb.encode.char: only can encode from raw type") stop("lgb.encode.char: Can only encode from raw type")
} }
return(rawToChar(arr[1:len])) rawToChar(arr[seq_len(len)])
} }
lgb.call <- function(fun_name, ret, ...) { lgb.call <- function(fun_name, ret, ...) {
call_state <- as.integer(0) call_state <- 0L
if (!is.null(ret)) { if (!is.null(ret)) {
call_state <- call_state <- .Call(fun_name, ..., ret, call_state, PACKAGE = "lightgbm")
.Call(fun_name, ..., ret, call_state , PACKAGE = "lightgbm")
} else { } else {
call_state <- .Call(fun_name, ..., call_state , PACKAGE = "lightgbm") call_state <- .Call(fun_name, ..., call_state, PACKAGE = "lightgbm")
} }
if (call_state != as.integer(0)) { if (call_state != 0L) {
buf_len <- as.integer(200) buf_len <- 200L
act_len <- as.integer(0) act_len <- 0L
err_msg <- raw(buf_len) err_msg <- raw(buf_len)
err_msg <- err_msg <- .Call("LGBM_GetLastError_R", buf_len, act_len, err_msg, PACKAGE = "lightgbm")
.Call("LGBM_GetLastError_R", buf_len, act_len, err_msg, PACKAGE = "lightgbm")
if (act_len > buf_len) { if (act_len > buf_len) {
buf_len <- act_len buf_len <- act_len
err_msg <- raw(buf_len) err_msg <- raw(buf_len)
err_msg <- err_msg <- .Call("LGBM_GetLastError_R",
.Call("LGBM_GetLastError_R", buf_len,
buf_len, act_len,
act_len, err_msg,
err_msg, PACKAGE = "lightgbm")
PACKAGE = "lightgbm")
} }
stop(paste0("api error: ", lgb.encode.char(err_msg, act_len))) stop(paste0("api error: ", lgb.encode.char(err_msg, act_len)))
} }
return(ret) ret
} }
lgb.call.return.str <- function(fun_name, ...) { lgb.call.return.str <- function(fun_name, ...) {
buf_len <- as.integer(1024 * 1024) buf_len <- as.integer(1024 * 1024)
act_len <- as.integer(0) act_len <- 0L
buf <- raw(buf_len) buf <- raw(buf_len)
buf <- lgb.call(fun_name, ret = buf, ..., buf_len, act_len) buf <- lgb.call(fun_name, ret = buf, ..., buf_len, act_len)
if (act_len > buf_len) { if (act_len > buf_len) {
buf_len <- act_len buf_len <- act_len
buf <- raw(buf_len) buf <- raw(buf_len)
buf <- lgb.call(fun_name, ret = buf, ..., buf_len, act_len) buf <- lgb.call(fun_name, ret = buf, ..., buf_len, act_len)
} }
return(lgb.encode.char(buf, act_len)) lgb.encode.char(buf, act_len)
} }
lgb.params2str <- function(params, ...) { lgb.params2str <- function(params, ...) {
if (typeof(params) != "list") if (!is.list(params)) { stop("params must be a list") }
stop("params must be a list")
names(params) <- gsub("\\.", "_", names(params)) names(params) <- gsub("\\.", "_", names(params))
# merge parameters from the params and the dots-expansion # merge parameters from the params and the dots-expansion
dot_params <- list(...) dot_params <- list(...)
...@@ -72,29 +63,29 @@ lgb.params2str <- function(params, ...) { ...@@ -72,29 +63,29 @@ lgb.params2str <- function(params, ...) {
if (length(intersect(names(params), if (length(intersect(names(params),
names(dot_params))) > 0) names(dot_params))) > 0)
stop( stop(
"Same parameters in 'params' and in the call are not allowed. Please check your 'params' list." "Same parameters in ", sQuote("params"), " and in the call are not allowed. Please check your ", sQuote("params"), " list"
) )
params <- c(params, dot_params) params <- c(params, dot_params)
ret <- list() ret <- list()
for (key in names(params)) { for (key in names(params)) {
# join multi value first # join multi value first
val <- paste0(params[[key]], collapse = ",") val <- paste0(params[[key]], collapse = ",")
if(nchar(val) <= 0) next if (nchar(val) <= 0) next
# join key value # join key value
pair <- paste0(c(key, val), collapse = "=") pair <- paste0(c(key, val), collapse = "=")
ret <- c(ret, pair) ret <- c(ret, pair)
} }
if (length(ret) == 0) { if (length(ret) == 0) {
return(lgb.c_str("")) lgb.c_str("")
} else{ } else {
return(lgb.c_str(paste0(ret, collapse = " "))) lgb.c_str(paste0(ret, collapse = " "))
} }
} }
lgb.c_str <- function(x) { lgb.c_str <- function(x) {
ret <- charToRaw(as.character(x)) ret <- charToRaw(as.character(x))
ret <- c(ret, as.raw(0)) ret <- c(ret, as.raw(0))
return(ret) ret
} }
lgb.check.r6.class <- function(object, name) { lgb.check.r6.class <- function(object, name) {
...@@ -104,54 +95,47 @@ lgb.check.r6.class <- function(object, name) { ...@@ -104,54 +95,47 @@ lgb.check.r6.class <- function(object, name) {
if (!(name %in% class(object))) { if (!(name %in% class(object))) {
return(FALSE) return(FALSE)
} }
return(TRUE) TRUE
} }
lgb.check.params <- function(params){ lgb.check.params <- function(params) {
# To-do # To-do
return(params) params
} }
lgb.check.obj <- function(params, obj) { lgb.check.obj <- function(params, obj) {
if(!is.null(obj)){ OBJECTIVES <- c("regression", "binary", "multiclass", "lambdarank")
params$objective <- obj if (!is.null(obj)) { params$objective <- obj }
} if (is.character(params$objective)) {
if(is.character(params$objective)){ if (!(params$objective %in% OBJECTIVES)) {
if(!(params$objective %in% c("regression", "binary", "multiclass", "lambdarank"))){ stop("lgb.check.obj: objective name error should be one of (", paste0(OBJECTIVES, collapse = ", "), ")")
stop("lgb.check.obj: objective name error should be (regression, binary, multiclass, lambdarank)")
} }
} else if(typeof(params$objective) != "closure"){ } else if (!is.function(params$objective)) {
stop("lgb.check.obj: objective should be character or function") stop("lgb.check.obj: objective should be a character or a function")
} }
return(params) params
} }
lgb.check.eval <- function(params, eval) { lgb.check.eval <- function(params, eval) {
if(is.null(params$metric)){ if (is.null(params$metric)) { params$metric <- list() }
params$metric <- list() if (!is.null(eval)) {
}
if(!is.null(eval)){
# append metric # append metric
if(is.character(eval) || is.list(eval)){ if (is.character(eval) || is.list(eval)) {
params$metric <- append(params$metric, eval) params$metric <- append(params$metric, eval)
} }
} }
if (typeof(eval) != "closure"){ if (!is.function(eval)) {
if(is.null(params$metric) | length(params$metric) == 0) { if (length(params$metric) == 0) {
# add default metric # add default metric
if(is.character(params$objective)){ params$metric <- switch(
if(params$objective == "regression"){ params$objective,
params$metric <- "l2" regression = "l2",
} else if(params$objective == "binary"){ binary = "binary_logloss",
params$metric <- "binary_logloss" multiclass = "multi_logloss",
} else if(params$objective == "multiclass"){ lambdarank = "ndcg",
params$metric <- "multi_logloss" stop("lgb.check.eval: No default metric available for objective ", sQuote(params$objective))
} else if(params$objective == "lambdarank"){ )
params$metric <- "ndcg"
}
}
} }
} }
return(params) params
} }
basic_walkthrough Basic feature walkthrough basic_walkthrough Basic feature walkthrough
boost_from_prediction Boosting from existing prediction boost_from_prediction Boosting from existing prediction
early_stopping Early Stop in training early_stopping Early Stop in training
cross_validation Cross Validation cross_validation Cross Validation
\ No newline at end of file
...@@ -4,5 +4,3 @@ LightGBM R examples ...@@ -4,5 +4,3 @@ LightGBM R examples
* [Boosting from existing prediction](boost_from_prediction.R) * [Boosting from existing prediction](boost_from_prediction.R)
* [Early Stopping](early_stopping.R) * [Early Stopping](early_stopping.R)
* [Cross Validation](cross_validation.R) * [Cross Validation](cross_validation.R)
...@@ -26,7 +26,7 @@ bst <- lightgbm(data = as.matrix(train$data), label = train$label, num_leaves = ...@@ -26,7 +26,7 @@ bst <- lightgbm(data = as.matrix(train$data), label = train$label, num_leaves =
# you can also put in lgb.Dataset object, which stores label, data and other meta datas needed for advanced features # you can also put in lgb.Dataset object, which stores label, data and other meta datas needed for advanced features
print("Training lightgbm with lgb.Dataset") print("Training lightgbm with lgb.Dataset")
dtrain <- lgb.Dataset(data = train$data, label = train$label) dtrain <- lgb.Dataset(data = train$data, label = train$label)
bst <- lightgbm(data = dtrain, num_leaves = 4, learning_rate = 1, nrounds = 2, bst <- lightgbm(data = dtrain, num_leaves = 4, learning_rate = 1, nrounds = 2,
objective = "binary") objective = "binary")
# Verbose = 0,1,2 # Verbose = 0,1,2
...@@ -46,7 +46,7 @@ bst <- lightgbm(data = dtrain, num_leaves = 4, learning_rate = 1, nrounds = 2, ...@@ -46,7 +46,7 @@ bst <- lightgbm(data = dtrain, num_leaves = 4, learning_rate = 1, nrounds = 2,
#--------------------basic prediction using lightgbm-------------- #--------------------basic prediction using lightgbm--------------
# you can do prediction using the following line # you can do prediction using the following line
# you can put in Matrix, sparseMatrix, or lgb.Dataset # you can put in Matrix, sparseMatrix, or lgb.Dataset
pred <- predict(bst, test$data) pred <- predict(bst, test$data)
err <- mean(as.numeric(pred > 0.5) != test$label) err <- mean(as.numeric(pred > 0.5) != test$label)
print(paste("test-error=", err)) print(paste("test-error=", err))
...@@ -69,7 +69,7 @@ dtest <- lgb.Dataset(data = test$data, label=test$label, free_raw_data=FALSE) ...@@ -69,7 +69,7 @@ dtest <- lgb.Dataset(data = test$data, label=test$label, free_raw_data=FALSE)
# valids is a list of lgb.Dataset, each of them is tagged with name # valids is a list of lgb.Dataset, each of them is tagged with name
valids <- list(train=dtrain, test=dtest) valids <- list(train=dtrain, test=dtest)
# to train with valids, use lgb.train, which contains more advanced features # to train with valids, use lgb.train, which contains more advanced features
# valids allows us to monitor the evaluation result on all data in the list # valids allows us to monitor the evaluation result on all data in the list
print("Train lightgbm using lgb.train with valids") print("Train lightgbm using lgb.train with valids")
bst <- lgb.train(data=dtrain, num_leaves=4, learning_rate=1, nrounds=2, valids=valids, bst <- lgb.train(data=dtrain, num_leaves=4, learning_rate=1, nrounds=2, valids=valids,
nthread = 2, objective = "binary") nthread = 2, objective = "binary")
...@@ -90,5 +90,3 @@ label = getinfo(dtest, "label") ...@@ -90,5 +90,3 @@ label = getinfo(dtest, "label")
pred <- predict(bst, test$data) pred <- predict(bst, test$data)
err <- as.numeric(sum(as.integer(pred > 0.5) != label))/length(label) err <- as.numeric(sum(as.integer(pred > 0.5) != label))/length(label)
print(paste("test-error=", err)) print(paste("test-error=", err))
...@@ -42,4 +42,3 @@ evalerror <- function(preds, dtrain) { ...@@ -42,4 +42,3 @@ evalerror <- function(preds, dtrain) {
# train with customized objective # train with customized objective
lgb.cv(params = param, data = dtrain, nrounds = nround, obj=logregobj, eval=evalerror, nfold = 5) lgb.cv(params = param, data = dtrain, nrounds = nround, obj=logregobj, eval=evalerror, nfold = 5)
...@@ -33,7 +33,6 @@ evalerror <- function(preds, dtrain) { ...@@ -33,7 +33,6 @@ evalerror <- function(preds, dtrain) {
} }
print ('start training with early Stopping setting') print ('start training with early Stopping setting')
bst <- lgb.train(param, dtrain, num_round, valids, bst <- lgb.train(param, dtrain, num_round, valids,
objective = logregobj, eval = evalerror, objective = logregobj, eval = evalerror,
early_stopping_round = 3) early_stopping_round = 3)
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
\name{agaricus.test} \name{agaricus.test}
\alias{agaricus.test} \alias{agaricus.test}
\title{Test part from Mushroom Data Set} \title{Test part from Mushroom Data Set}
\format{A list containing a label vector, and a dgCMatrix object with 1611 \format{A list containing a label vector, and a dgCMatrix object with 1611
rows and 126 variables} rows and 126 variables}
\usage{ \usage{
data(agaricus.test) data(agaricus.test)
...@@ -24,8 +24,8 @@ This data set includes the following fields: ...@@ -24,8 +24,8 @@ This data set includes the following fields:
\references{ \references{
https://archive.ics.uci.edu/ml/datasets/Mushroom https://archive.ics.uci.edu/ml/datasets/Mushroom
Bache, K. & Lichman, M. (2013). UCI Machine Learning Repository Bache, K. & Lichman, M. (2013). UCI Machine Learning Repository
[http://archive.ics.uci.edu/ml]. Irvine, CA: University of California, [http://archive.ics.uci.edu/ml]. Irvine, CA: University of California,
School of Information and Computer Science. School of Information and Computer Science.
} }
\keyword{datasets} \keyword{datasets}
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
\name{agaricus.train} \name{agaricus.train}
\alias{agaricus.train} \alias{agaricus.train}
\title{Training part from Mushroom Data Set} \title{Training part from Mushroom Data Set}
\format{A list containing a label vector, and a dgCMatrix object with 6513 \format{A list containing a label vector, and a dgCMatrix object with 6513
rows and 127 variables} rows and 127 variables}
\usage{ \usage{
data(agaricus.train) data(agaricus.train)
...@@ -24,8 +24,8 @@ This data set includes the following fields: ...@@ -24,8 +24,8 @@ This data set includes the following fields:
\references{ \references{
https://archive.ics.uci.edu/ml/datasets/Mushroom https://archive.ics.uci.edu/ml/datasets/Mushroom
Bache, K. & Lichman, M. (2013). UCI Machine Learning Repository Bache, K. & Lichman, M. (2013). UCI Machine Learning Repository
[http://archive.ics.uci.edu/ml]. Irvine, CA: University of California, [http://archive.ics.uci.edu/ml]. Irvine, CA: University of California,
School of Information and Computer Science. School of Information and Computer Science.
} }
\keyword{datasets} \keyword{datasets}
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
% Please edit documentation in R/lgb.Dataset.R % Please edit documentation in R/lgb.Dataset.R
\name{dim.lgb.Dataset} \name{dim.lgb.Dataset}
\alias{dim.lgb.Dataset} \alias{dim.lgb.Dataset}
\title{Dimensions of lgb.Dataset} \title{Dimensions of an lgb.Dataset}
\usage{ \usage{
\method{dim}{lgb.Dataset}(x, ...) \method{dim}{lgb.Dataset}(x, ...)
} }
...@@ -15,23 +15,21 @@ ...@@ -15,23 +15,21 @@
a vector of numbers of rows and of columns a vector of numbers of rows and of columns
} }
\description{ \description{
Dimensions of lgb.Dataset Returns a vector of numbers of rows and of columns in an \code{lgb.Dataset}.
} }
\details{ \details{
Returns a vector of numbers of rows and of columns in an \code{lgb.Dataset}.
Note: since \code{nrow} and \code{ncol} internally use \code{dim}, they can also Note: since \code{nrow} and \code{ncol} internally use \code{dim}, they can also
be directly used with an \code{lgb.Dataset} object. be directly used with an \code{lgb.Dataset} object.
} }
\examples{ \examples{
data(agaricus.train, package='lightgbm') dontrun{
train <- agaricus.train data(agaricus.train, package='lightgbm')
dtrain <- lgb.Dataset(train$data, label=train$label) train <- agaricus.train
dtrain <- lgb.Dataset(train$data, label=train$label)
stopifnot(nrow(dtrain) == nrow(train$data))
stopifnot(ncol(dtrain) == ncol(train$data))
stopifnot(all(dim(dtrain) == dim(train$data)))
stopifnot(nrow(dtrain) == nrow(train$data))
stopifnot(ncol(dtrain) == ncol(train$data))
stopifnot(all(dim(dtrain) == dim(train$data)))
}
} }
...@@ -16,25 +16,23 @@ ...@@ -16,25 +16,23 @@
and the second one is column names} and the second one is column names}
} }
\description{ \description{
Handling of column names of \code{lgb.Dataset} Only column names are supported for \code{lgb.Dataset}, thus setting of
row names would have no effect and returned row names would be NULL.
} }
\details{ \details{
Only column names are supported for \code{lgb.Dataset}, thus setting of
row names would have no effect and returnten row names would be NULL.
Generic \code{dimnames} methods are used by \code{colnames}. Generic \code{dimnames} methods are used by \code{colnames}.
Since row names are irrelevant, it is recommended to use \code{colnames} directly. Since row names are irrelevant, it is recommended to use \code{colnames} directly.
} }
\examples{ \examples{
data(agaricus.train, package='lightgbm') dontrun{
train <- agaricus.train data(agaricus.train, package='lightgbm')
dtrain <- lgb.Dataset(train$data, label=train$label) train <- agaricus.train
lgb.Dataset.construct(dtrain) dtrain <- lgb.Dataset(train$data, label=train$label)
dimnames(dtrain) lgb.Dataset.construct(dtrain)
colnames(dtrain) dimnames(dtrain)
colnames(dtrain) <- make.names(1:ncol(train$data)) colnames(dtrain)
print(dtrain, verbose=TRUE) colnames(dtrain) <- make.names(1:ncol(train$data))
print(dtrain, verbose=TRUE)
}
} }
...@@ -33,14 +33,16 @@ The \code{name} field can be one of the following: ...@@ -33,14 +33,16 @@ The \code{name} field can be one of the following:
} }
} }
\examples{ \examples{
data(agaricus.train, package='lightgbm') \dontrun{
train <- agaricus.train data(agaricus.train, package='lightgbm')
dtrain <- lgb.Dataset(train$data, label=train$label) train <- agaricus.train
lgb.Dataset.construct(dtrain) dtrain <- lgb.Dataset(train$data, label=train$label)
labels <- getinfo(dtrain, 'label') lgb.Dataset.construct(dtrain)
setinfo(dtrain, 'label', 1-labels) labels <- getinfo(dtrain, 'label')
setinfo(dtrain, 'label', 1-labels)
labels2 <- getinfo(dtrain, 'label') labels2 <- getinfo(dtrain, 'label')
stopifnot(all(labels2 == 1-labels)) stopifnot(all(labels2 == 1-labels))
}
} }
...@@ -28,18 +28,17 @@ lgb.Dataset(data, params = list(), reference = NULL, colnames = NULL, ...@@ -28,18 +28,17 @@ lgb.Dataset(data, params = list(), reference = NULL, colnames = NULL,
constructed dataset constructed dataset
} }
\description{ \description{
Contruct lgb.Dataset object
}
\details{
Contruct lgb.Dataset object from dense matrix, sparse matrix Contruct lgb.Dataset object from dense matrix, sparse matrix
or local file (that was created previously by saving an \code{lgb.Dataset}). or local file (that was created previously by saving an \code{lgb.Dataset}).
} }
\examples{ \examples{
data(agaricus.train, package='lightgbm') \dontrun{
train <- agaricus.train data(agaricus.train, package='lightgbm')
dtrain <- lgb.Dataset(train$data, label=train$label) train <- agaricus.train
lgb.Dataset.save(dtrain, 'lgb.Dataset.data') dtrain <- lgb.Dataset(train$data, label=train$label)
dtrain <- lgb.Dataset('lgb.Dataset.data') lgb.Dataset.save(dtrain, 'lgb.Dataset.data')
lgb.Dataset.construct(dtrain) dtrain <- lgb.Dataset('lgb.Dataset.data')
lgb.Dataset.construct(dtrain)
}
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment