Unverified commit 532fa914, authored by James Lamb, committed by GitHub

[R-package] allow access to params in Booster (#3662)

* [R-package] allow access to params in Booster

* remove unnecessary whitespace

* fix test on resetting params

* remove pytest_cache

* Update R-package/tests/testthat/test_custom_objective.R
parent d7a384fa
@@ -6,6 +6,7 @@ Booster <- R6::R6Class(
         best_iter = -1L,
         best_score = NA_real_,
+        params = list(),
         record_evals = list(),
         # Finalize will free up the handles
@@ -134,6 +135,8 @@ Booster <- R6::R6Class(
             }
+
+            self$params <- params
         },
         # Set training data name
@@ -187,17 +190,20 @@ Booster <- R6::R6Class(
         # Reset parameters of booster
         reset_parameter = function(params, ...) {
-            # Append parameters
-            params <- append(params, list(...))
+            if (methods::is(self$params, "list")) {
+                params <- modifyList(self$params, params)
+            }
+            params <- modifyList(params, list(...))
             params_str <- lgb.params2str(params = params)
-            # Reset parameters
             lgb.call(
                 fun_name = "LGBM_BoosterResetParameter_R"
                 , ret = NULL
                 , private$handle
                 , params_str
             )
+            self$params <- params
             return(invisible(self))
...
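The merge above relies on base R's modifyList(), which keeps every entry of the first list and overwrites only the keys supplied in the second; keyword arguments passed through ... are merged last, so they take highest precedence. A minimal standalone sketch of that semantics:

old_params <- list(objective = "binary", bagging_fraction = 0.8)
# only bagging_fraction is overwritten; objective is carried over unchanged
new_params <- modifyList(old_params, list(bagging_fraction = 0.9))
str(new_params)
# List of 2
#  $ objective       : chr "binary"
#  $ bagging_fraction: num 0.9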
@@ -44,6 +44,7 @@ readRDS.lgb.Booster <- function(file = "", refhook = NULL) {
     # Restore best iteration and recorded evaluations
     object2$best_iter <- object$best_iter
     object2$record_evals <- object$record_evals
+    object2$params <- object$params
     # Return newly loaded object
     return(object2)
...
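Together with the Booster changes above, this makes a save/load round trip preserve params. A hedged sketch of that round trip, assuming this commit is applied (the file path and params are illustrative; saveRDS.lgb.Booster() and readRDS.lgb.Booster() are the package's own helpers):

library(lightgbm)
data(agaricus.train, package = "lightgbm")
dtrain <- lgb.Dataset(agaricus.train$data, label = agaricus.train$label)
# train a tiny model, then write it to an .rds file
bst <- lgb.train(params = list(objective = "binary"), data = dtrain, nrounds = 2L)
model_file <- tempfile(fileext = ".rds")
saveRDS.lgb.Booster(bst, file = model_file)
# the reloaded Booster now carries the same params list
bst2 <- readRDS.lgb.Booster(file = model_file)
identical(bst2$params, bst$params)  # expected TRUE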
@@ -386,6 +386,75 @@ test_that("Booster$update() throws an informative error if you provide a non-Dat
     }, regexp = "lgb.Booster.update: Only can use lgb.Dataset", fixed = TRUE)
 })
 
+test_that("Booster should store parameters and Booster$reset_parameter() should update them", {
+    data(agaricus.train, package = "lightgbm")
+    dtrain <- lgb.Dataset(
+        agaricus.train$data
+        , label = agaricus.train$label
+    )
+    # testing that this works for some cases that could break it:
+    #     - multiple metrics
+    #     - using "metric", "boosting", "num_class" in params
+    params <- list(
+        objective = "multiclass"
+        , max_depth = 4L
+        , bagging_fraction = 0.8
+        , metric = c("multi_logloss", "multi_error")
+        , boosting = "gbdt"
+        , num_class = 5L
+    )
+    bst <- Booster$new(
+        params = params
+        , train_set = dtrain
+    )
+    expect_identical(bst$params, params)
+
+    params[["bagging_fraction"]] <- 0.9
+    ret_bst <- bst$reset_parameter(params = params)
+    expect_identical(ret_bst$params, params)
+    expect_identical(bst$params, params)
+})
+
+test_that("Booster$params should include dataset params, before and after Booster$reset_parameter()", {
+    data(agaricus.train, package = "lightgbm")
+    dtrain <- lgb.Dataset(
+        agaricus.train$data
+        , label = agaricus.train$label
+        , params = list(
+            max_bin = 17L
+        )
+    )
+    params <- list(
+        objective = "binary"
+        , max_depth = 4L
+        , bagging_fraction = 0.8
+    )
+    bst <- Booster$new(
+        params = params
+        , train_set = dtrain
+    )
+    expect_identical(
+        bst$params
+        , list(
+            objective = "binary"
+            , max_depth = 4L
+            , bagging_fraction = 0.8
+            , max_bin = 17L
+        )
+    )
+
+    params[["bagging_fraction"]] <- 0.9
+    ret_bst <- bst$reset_parameter(params = params)
+    expected_params <- list(
+        objective = "binary"
+        , max_depth = 4L
+        , bagging_fraction = 0.9
+        , max_bin = 17L
+    )
+    expect_identical(ret_bst$params, expected_params)
+    expect_identical(bst$params, expected_params)
+})
+
 context("save_model")
 
 test_that("Saving a model with different feature importance types works", {
@@ -626,3 +695,38 @@ test_that("lgb.cv() correctly handles passing through params to the model file",
     }
 })
+
+context("saveRDS.lgb.Booster() and readRDS.lgb.Booster()")
+
+test_that("params (including dataset params) should be stored in .rds file for Booster", {
+    data(agaricus.train, package = "lightgbm")
+    dtrain <- lgb.Dataset(
+        agaricus.train$data
+        , label = agaricus.train$label
+        , params = list(
+            max_bin = 17L
+        )
+    )
+    params <- list(
+        objective = "binary"
+        , max_depth = 4L
+        , bagging_fraction = 0.8
+    )
+    bst <- Booster$new(
+        params = params
+        , train_set = dtrain
+    )
+    bst_file <- tempfile(fileext = ".rds")
+    saveRDS.lgb.Booster(bst, file = bst_file)
+    bst_from_file <- readRDS.lgb.Booster(file = bst_file)
+    expect_identical(
+        bst_from_file$params
+        , list(
+            objective = "binary"
+            , max_depth = 4L
+            , bagging_fraction = 0.8
+            , max_bin = 17L
+        )
+    )
+})
@@ -297,13 +297,15 @@ class Booster {
   void ResetConfig(const char* parameters) {
     UNIQUE_LOCK(mutex_)
     auto param = Config::Str2Map(parameters);
-    if (param.count("num_class")) {
+    Config new_config;
+    new_config.Set(param);
+    if (param.count("num_class") && new_config.num_class != config_.num_class) {
       Log::Fatal("Cannot change num_class during training");
     }
-    if (param.count("boosting")) {
+    if (param.count("boosting") && new_config.boosting != config_.boosting) {
       Log::Fatal("Cannot change boosting during training");
     }
-    if (param.count("metric")) {
+    if (param.count("metric") && new_config.metric != config_.metric) {
       Log::Fatal("Cannot change metric during training");
     }
     CheckDatasetResetConfig(config_, param);
...
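With the relaxed checks above, re-sending an unchanged metric, boosting, or num_class during a parameter reset is now accepted (necessary, since the R Booster re-sends its full stored params list); only an actual change remains fatal. A sketch of the intended behavior from the R side, assuming a dtrain like the one built in the tests above (values are illustrative):

params <- list(
    objective = "multiclass"
    , num_class = 5L
    , boosting = "gbdt"
    , metric = "multi_logloss"
)
bst <- Booster$new(params = params, train_set = dtrain)
# unchanged num_class / boosting / metric: accepted after this change
bst$reset_parameter(params = modifyList(params, list(bagging_fraction = 0.9)))
# actually changing one still fails, e.g. "Cannot change num_class during training"
# bst$reset_parameter(params = modifyList(params, list(num_class = 6L)))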
@@ -2609,3 +2609,37 @@ class TestEngine(unittest.TestCase):
         lgb_X = lgb.Dataset(X, label=y)
         lgb.train(params, lgb_X, num_boost_round=1, valid_sets=[lgb_X], evals_result=res)
         self.assertAlmostEqual(res['training']['average_precision'][-1], 1)
+
+    def test_reset_params_works_with_metric_num_class_and_boosting(self):
+        X, y = load_breast_cancer(return_X_y=True)
+        params = {
+            'objective': 'multiclass',
+            'max_depth': 4,
+            'bagging_fraction': 0.8,
+            'metric': ['multi_logloss', 'multi_error'],
+            'boosting': 'gbdt',
+            'num_class': 5
+        }
+        dtrain = lgb.Dataset(X, y, params={"max_bin": 150})
+        bst = lgb.Booster(
+            params=params,
+            train_set=dtrain
+        )
+        expected_params = {
+            'objective': 'multiclass',
+            'max_depth': 4,
+            'bagging_fraction': 0.8,
+            'metric': ['multi_logloss', 'multi_error'],
+            'boosting': 'gbdt',
+            'num_class': 5,
+            'max_bin': 150
+        }
+        assert bst.params == expected_params
+
+        params['bagging_fraction'] = 0.9
+        ret_bst = bst.reset_parameter(params)
+
+        expected_params['bagging_fraction'] = 0.9
+        assert bst.params == expected_params
+        assert ret_bst.params == expected_params