Commit 455ba60f authored by Guolin Ke's avatar Guolin Ke
Browse files

[R-package] separate the verbose and record

parent 0ddb3efd
......@@ -92,7 +92,8 @@ cb.print.evaluation <- function(period = 1){
if ( (i - 1) %% period == 0
| i == env$begin_iteration
| i == env$end_iteration ) {
cat(merge.eval.string(env), "\n")
msg <- merge.eval.string(env)
if (nchar(msg) > 0) { cat(merge.eval.string(env), "\n") }
}
}
}
......
......@@ -6,7 +6,6 @@ Booster <- R6Class(
record_evals = list(),
finalize = function() {
if (!lgb.is.null.handle(private$handle)) {
cat("freeing booster handle\n")
lgb.call("LGBM_BoosterFree_R", ret = NULL, private$handle)
private$handle <- NULL
}
......@@ -50,7 +49,7 @@ Booster <- R6Class(
}
class(handle) <- "lgb.Booster.handle"
private$handle <- handle
private$num_class <- as.integer(1)
private$num_class <- 1L
private$num_class <-
lgb.call("LGBM_BoosterGetNumClasses_R", ret = private$num_class, private$handle)
},
......@@ -107,6 +106,10 @@ Booster <- R6Class(
} else {
if (!is.function(fobj)) { stop("lgb.Booster.update: fobj should be a function") }
gpair <- fobj(private$inner_predict(1), private$train_set)
if(is.null(gpair$grad) | is.null(gpair$hess)){
stop("lgb.Booster.update: custom objective should
return a list with attributes (hess, grad)")
}
ret <- lgb.call(
"LGBM_BoosterUpdateOneIterCustom_R", ret = NULL,
private$handle,
......@@ -128,7 +131,7 @@ Booster <- R6Class(
self
},
current_iter = function() {
cur_iter <- as.integer(0)
cur_iter <- 0L
lgb.call("LGBM_BoosterGetCurrentIteration_R", ret = cur_iter, private$handle)
},
eval = function(data, name, feval = NULL) {
......@@ -214,7 +217,7 @@ Booster <- R6Class(
stop("data_idx should not be greater than num_dataset")
}
if (is.null(private$predict_buffer[[data_name]])) {
npred <- as.integer(0)
npred <- 0L
npred <- lgb.call("LGBM_BoosterGetNumPredict_R",
ret = npred,
private$handle,
......@@ -275,7 +278,12 @@ Booster <- R6Class(
}
data <- private$train_set
if (data_idx > 1) { data <- private$valid_sets[[data_idx - 1]] }
res <- feval(private$inner_predict(data_idx), data)
res <- feval(private$inner_predict(data_idx), data)
if(is.null(res$name) | is.null(res$value) |
is.null(res$higher_better)) {
stop("lgb.Booster.eval: custom eval function should return a
list with attribute (name, value, higher_better)");
}
res$data_name <- data_name
ret <- append(ret, list(res))
}
......
......@@ -4,7 +4,6 @@ Dataset <- R6Class(
public = list(
finalize = function() {
if (!lgb.is.null.handle(private$handle)) {
cat("free dataset handle\n")
lgb.call("LGBM_DatasetFree_R", ret = NULL, private$handle)
private$handle <- NULL
}
......@@ -200,8 +199,8 @@ Dataset <- R6Class(
},
dim = function() {
if (!lgb.is.null.handle(private$handle)) {
num_row <- as.integer(0)
num_col <- as.integer(0)
num_row <- 0L
num_col <- 0L
c(
lgb.call("LGBM_DatasetGetNumData_R", ret = num_row, private$handle),
......@@ -252,7 +251,7 @@ Dataset <- R6Class(
)
}
if (is.null(private$info[[name]]) && !lgb.is.null.handle(private$handle)) {
info_len <- as.integer(0)
info_len <- 0L
info_len <- lgb.call("LGBM_DatasetGetFieldSize_R",
ret = info_len,
private$handle,
......
......@@ -31,7 +31,7 @@ Predictor <- R6Class(
predleaf = FALSE, header = FALSE, reshape = FALSE) {
if (is.null(num_iteration)) { num_iteration <- -1 }
num_row <- 0
num_row <- 0L
if (is.character(data)) {
tmp_filename <- tempfile(pattern = "lightgbm_")
on.exit(unlink(tmp_filename), add = TRUE)
......@@ -46,7 +46,7 @@ Predictor <- R6Class(
preds <- as.vector(t(preds))
} else {
num_row <- nrow(data)
npred <- as.integer(0)
npred <- 0L
npred <- lgb.call("LGBM_BoosterCalcNumPredict_R", ret = npred,
private$handle,
as.integer(num_row),
......
......@@ -28,9 +28,9 @@ CVBooster <- R6Class(
#' @param weight vector of response values. If not NULL, will set to dataset
#' @param obj objective function, can be character or custom objective function
#' @param eval evaluation function, can be (list of) character or custom eval function
#' @param verbose verbosity for output
#' if verbose > 0 , also will record iteration message to booster$record_evals
#' @param eval_freq evalutaion output frequence
#' @param verbose verbosity for output, if <= 0, also will disable the print of evalutaion during training
#' @param record Boolean, TRUE will record iteration message to \code{booster$record_evals}
#' @param eval_freq evalutaion output frequence, only effect when verbose > 0
#' @param showsd \code{boolean}, whether to show standard deviation of cross validation
#' @param stratified a \code{boolean} indicating whether sampling of folds should be stratified
#' by the values of outcome labels.
......@@ -51,7 +51,7 @@ CVBooster <- R6Class(
#' @param callbacks list of callback functions
#' List of callback functions that are applied at each iteration.
#' @param ... other parameters, see parameters.md for more informations
#' @return a trained booster model \code{lgb.Booster}.
#' @return a trained model \code{lgb.CVBooster}.
#' @examples
#' \dontrun{
#' library(lightgbm)
......@@ -63,16 +63,23 @@ CVBooster <- R6Class(
#' }
#' @rdname lgb.train
#' @export
lgb.cv <- function(params=list(), data, nrounds = 10, nfold = 3,
label = NULL, weight = NULL,
obj = NULL, eval = NULL,
verbose = 1, eval_freq = 1L, showsd = TRUE,
stratified = TRUE, folds = NULL,
init_model = NULL,
colnames= NULL,
categorical_feature = NULL,
early_stopping_rounds = NULL,
callbacks = list(), ...) {
lgb.cv <- function(params=list(), data, nrounds = 10,
nfold = 3,
label = NULL,
weight = NULL,
obj = NULL,
eval = NULL,
verbose = 1,
record = TRUE,
eval_freq = 1L,
showsd = TRUE,
stratified = TRUE,
folds = NULL,
init_model = NULL,
colnames = NULL,
categorical_feature = NULL,
early_stopping_rounds = NULL,
callbacks = list(), ...) {
addiction_params <- list(...)
params <- append(params, addiction_params)
params$verbose <- verbose
......@@ -112,7 +119,7 @@ lgb.cv <- function(params=list(), data, nrounds = 10, nfold = 3,
data$construct()
if (!is.null(folds)) {
if (!is.list(folds) || length(folds) < 2)
if (!is.list(folds) | length(folds) < 2)
stop(sQuote("folds"), " must be a list with 2 or more elements that are vectors of indices for each CV-fold")
nfold <- length(folds)
} else {
......@@ -120,11 +127,11 @@ lgb.cv <- function(params=list(), data, nrounds = 10, nfold = 3,
folds <- generate.cv.folds(nfold, nrow(data), stratified, getinfo(data, 'label'), params)
}
if (eval_freq > 0) {
if (verbose > 0 & eval_freq > 0) {
callbacks <- add.cb(callbacks, cb.print.evaluation(eval_freq))
}
if (verbose > 0) { callbacks <- add.cb(callbacks, cb.record.evaluation()) }
if (record) { callbacks <- add.cb(callbacks, cb.record.evaluation()) }
if (!is.null(early_stopping_rounds)) {
if (early_stopping_rounds > 0) {
......
......@@ -6,9 +6,9 @@
#' @param valids a list of \code{lgb.Dataset} objects, used for validation
#' @param obj objective function, can be character or custom objective function
#' @param eval evaluation function, can be (a list of) character or custom eval function
#' @param verbose verbosity for output
#' if \code{verbose > 0}, also will record iteration message to \code{booster$record_evals}
#' @param eval_freq evalutaion output frequency
#' @param verbose verbosity for output, if <= 0, also will disable the print of evalutaion during training
#' @param record Boolean, TRUE will record iteration message to \code{booster$record_evals}
#' @param eval_freq evalutaion output frequency, only effect when verbose > 0
#' @param init_model path of model file of \code{lgb.Booster} object, will continue training from this model
#' @param colnames feature names, if not null, will use this to overwrite the names in dataset
#' @param categorical_feature list of str or int
......@@ -44,6 +44,7 @@ lgb.train <- function(params = list(), data, nrounds = 10,
obj = NULL,
eval = NULL,
verbose = 1,
record = TRUE,
eval_freq = 1L,
init_model = NULL,
colnames = NULL,
......@@ -111,11 +112,11 @@ lgb.train <- function(params = list(), data, nrounds = 10,
}
}
# process callbacks
if (eval_freq > 0) {
if (verbose > 0 & eval_freq > 0) {
callbacks <- add.cb(callbacks, cb.print.evaluation(eval_freq))
}
if (verbose > 0 && length(valids) > 0) {
if (record & length(valids) > 0) {
callbacks <- add.cb(callbacks, cb.record.evaluation())
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment