Commit 455ba60f authored by Guolin Ke's avatar Guolin Ke
Browse files

[R-package] separate the verbose and record

parent 0ddb3efd
...@@ -92,7 +92,8 @@ cb.print.evaluation <- function(period = 1){ ...@@ -92,7 +92,8 @@ cb.print.evaluation <- function(period = 1){
if ( (i - 1) %% period == 0 if ( (i - 1) %% period == 0
| i == env$begin_iteration | i == env$begin_iteration
| i == env$end_iteration ) { | i == env$end_iteration ) {
cat(merge.eval.string(env), "\n") msg <- merge.eval.string(env)
if (nchar(msg) > 0) { cat(merge.eval.string(env), "\n") }
} }
} }
} }
......
...@@ -6,7 +6,6 @@ Booster <- R6Class( ...@@ -6,7 +6,6 @@ Booster <- R6Class(
record_evals = list(), record_evals = list(),
finalize = function() { finalize = function() {
if (!lgb.is.null.handle(private$handle)) { if (!lgb.is.null.handle(private$handle)) {
cat("freeing booster handle\n")
lgb.call("LGBM_BoosterFree_R", ret = NULL, private$handle) lgb.call("LGBM_BoosterFree_R", ret = NULL, private$handle)
private$handle <- NULL private$handle <- NULL
} }
...@@ -50,7 +49,7 @@ Booster <- R6Class( ...@@ -50,7 +49,7 @@ Booster <- R6Class(
} }
class(handle) <- "lgb.Booster.handle" class(handle) <- "lgb.Booster.handle"
private$handle <- handle private$handle <- handle
private$num_class <- as.integer(1) private$num_class <- 1L
private$num_class <- private$num_class <-
lgb.call("LGBM_BoosterGetNumClasses_R", ret = private$num_class, private$handle) lgb.call("LGBM_BoosterGetNumClasses_R", ret = private$num_class, private$handle)
}, },
...@@ -107,6 +106,10 @@ Booster <- R6Class( ...@@ -107,6 +106,10 @@ Booster <- R6Class(
} else { } else {
if (!is.function(fobj)) { stop("lgb.Booster.update: fobj should be a function") } if (!is.function(fobj)) { stop("lgb.Booster.update: fobj should be a function") }
gpair <- fobj(private$inner_predict(1), private$train_set) gpair <- fobj(private$inner_predict(1), private$train_set)
if(is.null(gpair$grad) | is.null(gpair$hess)){
stop("lgb.Booster.update: custom objective should
return a list with attributes (hess, grad)")
}
ret <- lgb.call( ret <- lgb.call(
"LGBM_BoosterUpdateOneIterCustom_R", ret = NULL, "LGBM_BoosterUpdateOneIterCustom_R", ret = NULL,
private$handle, private$handle,
...@@ -128,7 +131,7 @@ Booster <- R6Class( ...@@ -128,7 +131,7 @@ Booster <- R6Class(
self self
}, },
current_iter = function() { current_iter = function() {
cur_iter <- as.integer(0) cur_iter <- 0L
lgb.call("LGBM_BoosterGetCurrentIteration_R", ret = cur_iter, private$handle) lgb.call("LGBM_BoosterGetCurrentIteration_R", ret = cur_iter, private$handle)
}, },
eval = function(data, name, feval = NULL) { eval = function(data, name, feval = NULL) {
...@@ -214,7 +217,7 @@ Booster <- R6Class( ...@@ -214,7 +217,7 @@ Booster <- R6Class(
stop("data_idx should not be greater than num_dataset") stop("data_idx should not be greater than num_dataset")
} }
if (is.null(private$predict_buffer[[data_name]])) { if (is.null(private$predict_buffer[[data_name]])) {
npred <- as.integer(0) npred <- 0L
npred <- lgb.call("LGBM_BoosterGetNumPredict_R", npred <- lgb.call("LGBM_BoosterGetNumPredict_R",
ret = npred, ret = npred,
private$handle, private$handle,
...@@ -276,6 +279,11 @@ Booster <- R6Class( ...@@ -276,6 +279,11 @@ Booster <- R6Class(
data <- private$train_set data <- private$train_set
if (data_idx > 1) { data <- private$valid_sets[[data_idx - 1]] } if (data_idx > 1) { data <- private$valid_sets[[data_idx - 1]] }
res <- feval(private$inner_predict(data_idx), data) res <- feval(private$inner_predict(data_idx), data)
if(is.null(res$name) | is.null(res$value) |
is.null(res$higher_better)) {
stop("lgb.Booster.eval: custom eval function should return a
list with attribute (name, value, higher_better)");
}
res$data_name <- data_name res$data_name <- data_name
ret <- append(ret, list(res)) ret <- append(ret, list(res))
} }
......
...@@ -4,7 +4,6 @@ Dataset <- R6Class( ...@@ -4,7 +4,6 @@ Dataset <- R6Class(
public = list( public = list(
finalize = function() { finalize = function() {
if (!lgb.is.null.handle(private$handle)) { if (!lgb.is.null.handle(private$handle)) {
cat("free dataset handle\n")
lgb.call("LGBM_DatasetFree_R", ret = NULL, private$handle) lgb.call("LGBM_DatasetFree_R", ret = NULL, private$handle)
private$handle <- NULL private$handle <- NULL
} }
...@@ -200,8 +199,8 @@ Dataset <- R6Class( ...@@ -200,8 +199,8 @@ Dataset <- R6Class(
}, },
dim = function() { dim = function() {
if (!lgb.is.null.handle(private$handle)) { if (!lgb.is.null.handle(private$handle)) {
num_row <- as.integer(0) num_row <- 0L
num_col <- as.integer(0) num_col <- 0L
c( c(
lgb.call("LGBM_DatasetGetNumData_R", ret = num_row, private$handle), lgb.call("LGBM_DatasetGetNumData_R", ret = num_row, private$handle),
...@@ -252,7 +251,7 @@ Dataset <- R6Class( ...@@ -252,7 +251,7 @@ Dataset <- R6Class(
) )
} }
if (is.null(private$info[[name]]) && !lgb.is.null.handle(private$handle)) { if (is.null(private$info[[name]]) && !lgb.is.null.handle(private$handle)) {
info_len <- as.integer(0) info_len <- 0L
info_len <- lgb.call("LGBM_DatasetGetFieldSize_R", info_len <- lgb.call("LGBM_DatasetGetFieldSize_R",
ret = info_len, ret = info_len,
private$handle, private$handle,
......
...@@ -31,7 +31,7 @@ Predictor <- R6Class( ...@@ -31,7 +31,7 @@ Predictor <- R6Class(
predleaf = FALSE, header = FALSE, reshape = FALSE) { predleaf = FALSE, header = FALSE, reshape = FALSE) {
if (is.null(num_iteration)) { num_iteration <- -1 } if (is.null(num_iteration)) { num_iteration <- -1 }
num_row <- 0 num_row <- 0L
if (is.character(data)) { if (is.character(data)) {
tmp_filename <- tempfile(pattern = "lightgbm_") tmp_filename <- tempfile(pattern = "lightgbm_")
on.exit(unlink(tmp_filename), add = TRUE) on.exit(unlink(tmp_filename), add = TRUE)
...@@ -46,7 +46,7 @@ Predictor <- R6Class( ...@@ -46,7 +46,7 @@ Predictor <- R6Class(
preds <- as.vector(t(preds)) preds <- as.vector(t(preds))
} else { } else {
num_row <- nrow(data) num_row <- nrow(data)
npred <- as.integer(0) npred <- 0L
npred <- lgb.call("LGBM_BoosterCalcNumPredict_R", ret = npred, npred <- lgb.call("LGBM_BoosterCalcNumPredict_R", ret = npred,
private$handle, private$handle,
as.integer(num_row), as.integer(num_row),
......
...@@ -28,9 +28,9 @@ CVBooster <- R6Class( ...@@ -28,9 +28,9 @@ CVBooster <- R6Class(
#' @param weight vector of response values. If not NULL, will set to dataset #' @param weight vector of response values. If not NULL, will set to dataset
#' @param obj objective function, can be character or custom objective function #' @param obj objective function, can be character or custom objective function
#' @param eval evaluation function, can be (list of) character or custom eval function #' @param eval evaluation function, can be (list of) character or custom eval function
#' @param verbose verbosity for output #' @param verbose verbosity for output, if <= 0, also will disable the print of evalutaion during training
#' if verbose > 0 , also will record iteration message to booster$record_evals #' @param record Boolean, TRUE will record iteration message to \code{booster$record_evals}
#' @param eval_freq evalutaion output frequence #' @param eval_freq evalutaion output frequence, only effect when verbose > 0
#' @param showsd \code{boolean}, whether to show standard deviation of cross validation #' @param showsd \code{boolean}, whether to show standard deviation of cross validation
#' @param stratified a \code{boolean} indicating whether sampling of folds should be stratified #' @param stratified a \code{boolean} indicating whether sampling of folds should be stratified
#' by the values of outcome labels. #' by the values of outcome labels.
...@@ -51,7 +51,7 @@ CVBooster <- R6Class( ...@@ -51,7 +51,7 @@ CVBooster <- R6Class(
#' @param callbacks list of callback functions #' @param callbacks list of callback functions
#' List of callback functions that are applied at each iteration. #' List of callback functions that are applied at each iteration.
#' @param ... other parameters, see parameters.md for more informations #' @param ... other parameters, see parameters.md for more informations
#' @return a trained booster model \code{lgb.Booster}. #' @return a trained model \code{lgb.CVBooster}.
#' @examples #' @examples
#' \dontrun{ #' \dontrun{
#' library(lightgbm) #' library(lightgbm)
...@@ -63,13 +63,20 @@ CVBooster <- R6Class( ...@@ -63,13 +63,20 @@ CVBooster <- R6Class(
#' } #' }
#' @rdname lgb.train #' @rdname lgb.train
#' @export #' @export
lgb.cv <- function(params=list(), data, nrounds = 10, nfold = 3, lgb.cv <- function(params=list(), data, nrounds = 10,
label = NULL, weight = NULL, nfold = 3,
obj = NULL, eval = NULL, label = NULL,
verbose = 1, eval_freq = 1L, showsd = TRUE, weight = NULL,
stratified = TRUE, folds = NULL, obj = NULL,
eval = NULL,
verbose = 1,
record = TRUE,
eval_freq = 1L,
showsd = TRUE,
stratified = TRUE,
folds = NULL,
init_model = NULL, init_model = NULL,
colnames= NULL, colnames = NULL,
categorical_feature = NULL, categorical_feature = NULL,
early_stopping_rounds = NULL, early_stopping_rounds = NULL,
callbacks = list(), ...) { callbacks = list(), ...) {
...@@ -112,7 +119,7 @@ lgb.cv <- function(params=list(), data, nrounds = 10, nfold = 3, ...@@ -112,7 +119,7 @@ lgb.cv <- function(params=list(), data, nrounds = 10, nfold = 3,
data$construct() data$construct()
if (!is.null(folds)) { if (!is.null(folds)) {
if (!is.list(folds) || length(folds) < 2) if (!is.list(folds) | length(folds) < 2)
stop(sQuote("folds"), " must be a list with 2 or more elements that are vectors of indices for each CV-fold") stop(sQuote("folds"), " must be a list with 2 or more elements that are vectors of indices for each CV-fold")
nfold <- length(folds) nfold <- length(folds)
} else { } else {
...@@ -120,11 +127,11 @@ lgb.cv <- function(params=list(), data, nrounds = 10, nfold = 3, ...@@ -120,11 +127,11 @@ lgb.cv <- function(params=list(), data, nrounds = 10, nfold = 3,
folds <- generate.cv.folds(nfold, nrow(data), stratified, getinfo(data, 'label'), params) folds <- generate.cv.folds(nfold, nrow(data), stratified, getinfo(data, 'label'), params)
} }
if (eval_freq > 0) { if (verbose > 0 & eval_freq > 0) {
callbacks <- add.cb(callbacks, cb.print.evaluation(eval_freq)) callbacks <- add.cb(callbacks, cb.print.evaluation(eval_freq))
} }
if (verbose > 0) { callbacks <- add.cb(callbacks, cb.record.evaluation()) } if (record) { callbacks <- add.cb(callbacks, cb.record.evaluation()) }
if (!is.null(early_stopping_rounds)) { if (!is.null(early_stopping_rounds)) {
if (early_stopping_rounds > 0) { if (early_stopping_rounds > 0) {
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
#' @param valids a list of \code{lgb.Dataset} objects, used for validation #' @param valids a list of \code{lgb.Dataset} objects, used for validation
#' @param obj objective function, can be character or custom objective function #' @param obj objective function, can be character or custom objective function
#' @param eval evaluation function, can be (a list of) character or custom eval function #' @param eval evaluation function, can be (a list of) character or custom eval function
#' @param verbose verbosity for output #' @param verbose verbosity for output, if <= 0, also will disable the print of evalutaion during training
#' if \code{verbose > 0}, also will record iteration message to \code{booster$record_evals} #' @param record Boolean, TRUE will record iteration message to \code{booster$record_evals}
#' @param eval_freq evalutaion output frequency #' @param eval_freq evalutaion output frequency, only effect when verbose > 0
#' @param init_model path of model file of \code{lgb.Booster} object, will continue training from this model #' @param init_model path of model file of \code{lgb.Booster} object, will continue training from this model
#' @param colnames feature names, if not null, will use this to overwrite the names in dataset #' @param colnames feature names, if not null, will use this to overwrite the names in dataset
#' @param categorical_feature list of str or int #' @param categorical_feature list of str or int
...@@ -44,6 +44,7 @@ lgb.train <- function(params = list(), data, nrounds = 10, ...@@ -44,6 +44,7 @@ lgb.train <- function(params = list(), data, nrounds = 10,
obj = NULL, obj = NULL,
eval = NULL, eval = NULL,
verbose = 1, verbose = 1,
record = TRUE,
eval_freq = 1L, eval_freq = 1L,
init_model = NULL, init_model = NULL,
colnames = NULL, colnames = NULL,
...@@ -111,11 +112,11 @@ lgb.train <- function(params = list(), data, nrounds = 10, ...@@ -111,11 +112,11 @@ lgb.train <- function(params = list(), data, nrounds = 10,
} }
} }
# process callbacks # process callbacks
if (eval_freq > 0) { if (verbose > 0 & eval_freq > 0) {
callbacks <- add.cb(callbacks, cb.print.evaluation(eval_freq)) callbacks <- add.cb(callbacks, cb.print.evaluation(eval_freq))
} }
if (verbose > 0 && length(valids) > 0) { if (record & length(valids) > 0) {
callbacks <- add.cb(callbacks, cb.record.evaluation()) callbacks <- add.cb(callbacks, cb.record.evaluation())
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment