Unverified Commit 1d8bfd8c authored by James Lamb's avatar James Lamb Committed by GitHub
Browse files

[R-package] respect 'verbose' argument in lgb.cv() (fixes #4667) (#4903)



* fixes

* revert debugging code

* add test

* check for LightGBM explicitly

* empty commit

* revert unnecessary line deletion

* respect verbose everywhere and update params for constructted dataset

* Update R-package/R/lgb.Dataset.R
Co-authored-by: default avatarNikita Titov <nekit94-08@mail.ru>
Co-authored-by: default avatarNikita Titov <nekit94-08@mail.ru>
parent 100a10da
...@@ -109,7 +109,7 @@ Dataset <- R6::R6Class( ...@@ -109,7 +109,7 @@ Dataset <- R6::R6Class(
params = list()) { params = list()) {
# the Dataset's existing parameters should be overwritten by any passed in to this call # the Dataset's existing parameters should be overwritten by any passed in to this call
params <- modifyList(self$get_params(), params) params <- modifyList(private$params, params)
# Create new dataset # Create new dataset
ret <- Dataset$new( ret <- Dataset$new(
...@@ -535,7 +535,7 @@ Dataset <- R6::R6Class( ...@@ -535,7 +535,7 @@ Dataset <- R6::R6Class(
return( return(
Dataset$new( Dataset$new(
data = NULL data = NULL
, params = self$get_params() , params = private$params
, reference = self , reference = self
, colnames = private$colnames , colnames = private$colnames
, categorical_feature = private$categorical_feature , categorical_feature = private$categorical_feature
...@@ -554,15 +554,17 @@ Dataset <- R6::R6Class( ...@@ -554,15 +554,17 @@ Dataset <- R6::R6Class(
if (length(params) == 0L) { if (length(params) == 0L) {
return(invisible(self)) return(invisible(self))
} }
new_params <- utils::modifyList(private$params, params)
if (lgb.is.null.handle(x = private$handle)) { if (lgb.is.null.handle(x = private$handle)) {
private$params <- utils::modifyList(private$params, params) private$params <- new_params
} else { } else {
tryCatch({ tryCatch({
.Call( .Call(
LGBM_DatasetUpdateParamChecking_R LGBM_DatasetUpdateParamChecking_R
, lgb.params2str(params = private$params) , lgb.params2str(params = private$params)
, lgb.params2str(params = params) , lgb.params2str(params = new_params)
) )
private$params <- new_params
}, error = function(e) { }, error = function(e) {
# If updating failed but raw data is not available, raise an error because # If updating failed but raw data is not available, raise an error because
# achieving what the user asked for is not possible # achieving what the user asked for is not possible
...@@ -572,7 +574,7 @@ Dataset <- R6::R6Class( ...@@ -572,7 +574,7 @@ Dataset <- R6::R6Class(
# If updating failed but raw data is available, modify the params # If updating failed but raw data is available, modify the params
# on the R side and re-set ("deconstruct") the Dataset # on the R side and re-set ("deconstruct") the Dataset
private$params <- utils::modifyList(private$params, params) private$params <- new_params
self$finalize() self$finalize()
}) })
} }
...@@ -580,6 +582,11 @@ Dataset <- R6::R6Class( ...@@ -580,6 +582,11 @@ Dataset <- R6::R6Class(
}, },
# [description] Get only Dataset-specific parameters. This is primarily used by
# Booster to update its parameters based on the characteristics of
# a Dataset. It should not be used by other methods in this class,
# since "verbose" is not a Dataset parameter and needs to be passed
# through to avoid globally re-setting verbosity.
get_params = function() { get_params = function() {
dataset_params <- unname(unlist(.DATASET_PARAMETERS())) dataset_params <- unname(unlist(.DATASET_PARAMETERS()))
ret <- list() ret <- list()
......
...@@ -2155,6 +2155,56 @@ test_that("early stopping works with lgb.cv()", { ...@@ -2155,6 +2155,56 @@ test_that("early stopping works with lgb.cv()", {
) )
}) })
test_that("lgb.cv() respects changes to logging verbosity", {
dtrain <- lgb.Dataset(
data = train$data
, label = train$label
)
# (verbose = 1) should be INFO and WARNING level logs
lgb_cv_logs <- capture.output({
cv_bst <- lgb.cv(
params = list()
, nfold = 2L
, nrounds = 5L
, data = dtrain
, obj = "binary"
, verbose = 1L
)
})
expect_true(any(grepl("\\[LightGBM\\] \\[Info\\]", lgb_cv_logs)))
expect_true(any(grepl("\\[LightGBM\\] \\[Warning\\]", lgb_cv_logs)))
# (verbose = 0) should be WARNING level logs only
lgb_cv_logs <- capture.output({
cv_bst <- lgb.cv(
params = list()
, nfold = 2L
, nrounds = 5L
, data = dtrain
, obj = "binary"
, verbose = 0L
)
})
expect_false(any(grepl("\\[LightGBM\\] \\[Info\\]", lgb_cv_logs)))
expect_true(any(grepl("\\[LightGBM\\] \\[Warning\\]", lgb_cv_logs)))
# (verbose = -1) no logs
lgb_cv_logs <- capture.output({
cv_bst <- lgb.cv(
params = list()
, nfold = 2L
, nrounds = 5L
, data = dtrain
, obj = "binary"
, verbose = -1L
)
})
# NOTE: this is not length(lgb_cv_logs) == 0 because lightgbm's
# dependencies might print other messages
expect_false(any(grepl("\\[LightGBM\\] \\[Info\\]", lgb_cv_logs)))
expect_false(any(grepl("\\[LightGBM\\] \\[Warning\\]", lgb_cv_logs)))
})
test_that("lgb.cv() updates params based on keyword arguments", { test_that("lgb.cv() updates params based on keyword arguments", {
dtrain <- lgb.Dataset( dtrain <- lgb.Dataset(
data = matrix(rnorm(400L), ncol = 4L) data = matrix(rnorm(400L), ncol = 4L)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment