Unverified Commit 9e73cee8 authored by James Lamb's avatar James Lamb Committed by GitHub
Browse files

[R-package] prefer params to keyword argument in `lgb.train()` (#5007)



* [R-package] prefer params to keyword argument in lgb.train()

* make test stricter

* Apply suggestions from code review
Co-authored-by: default avatarNikita Titov <nekit94-08@mail.ru>
Co-authored-by: default avatarNikita Titov <nekit94-08@mail.ru>
parent cb8c61e4
......@@ -127,7 +127,7 @@ lgb.cv <- function(params = list()
params <- lgb.check.wrapper_param(
main_param_name = "objective"
, params = params
, alternative_kwarg_value = NULL
, alternative_kwarg_value = obj
)
params <- lgb.check.wrapper_param(
main_param_name = "early_stopping_round"
......@@ -137,7 +137,7 @@ lgb.cv <- function(params = list()
early_stopping_rounds <- params[["early_stopping_round"]]
# extract any function objects passed for objective or metric
params <- lgb.check.obj(params = params, obj = obj)
params <- lgb.check.obj(params = params)
fobj <- NULL
if (is.function(params$objective)) {
fobj <- params$objective
......
......@@ -95,7 +95,7 @@ lgb.train <- function(params = list(),
params <- lgb.check.wrapper_param(
main_param_name = "objective"
, params = params
, alternative_kwarg_value = NULL
, alternative_kwarg_value = obj
)
params <- lgb.check.wrapper_param(
main_param_name = "early_stopping_round"
......@@ -105,7 +105,7 @@ lgb.train <- function(params = list(),
early_stopping_rounds <- params[["early_stopping_round"]]
# extract any function objects passed for objective or metric
params <- lgb.check.obj(params = params, obj = obj)
params <- lgb.check.obj(params = params)
fobj <- NULL
if (is.function(params$objective)) {
fobj <- params$objective
......
......@@ -117,7 +117,7 @@ lgb.check_interaction_constraints <- function(interaction_constraints, column_na
}
lgb.check.obj <- function(params, obj) {
lgb.check.obj <- function(params) {
# List known objectives in a vector
OBJECTIVES <- c(
......@@ -158,25 +158,18 @@ lgb.check.obj <- function(params, obj) {
, "xendcg_mart"
)
# Check whether the objective is empty or not, and take it from params if needed
if (!is.null(obj)) {
params$objective <- obj
if (is.null(params$objective)) {
stop("lgb.check.obj: objective should be a character or a function")
}
# Check whether the objective is a character
if (is.character(params$objective)) {
# If the objective is a character, check if it is a known objective
if (!(params$objective %in% OBJECTIVES)) {
stop("lgb.check.obj: objective name error should be one of (", paste0(OBJECTIVES, collapse = ", "), ")")
}
} else if (!is.function(params$objective)) {
stop("lgb.check.obj: objective should be a character or a function")
}
return(params)
......
......@@ -542,6 +542,34 @@ test_that("lgb.cv() respects parameter aliases for objective", {
expect_length(cv_bst$boosters, nfold)
})
test_that("lgb.cv() prefers objective in params to keyword argument", {
data("EuStockMarkets")
cv_bst <- lgb.cv(
data = lgb.Dataset(
data = EuStockMarkets[, c("SMI", "CAC", "FTSE")]
, label = EuStockMarkets[, "DAX"]
)
, params = list(
application = "regression_l1"
, verbosity = VERBOSITY
)
, nrounds = 5L
, obj = "regression_l2"
)
for (bst_list in cv_bst$boosters) {
bst <- bst_list[["booster"]]
expect_equal(bst$params$objective, "regression_l1")
# NOTE: using save_model_to_string() since that is the simplest public API in the R package
# allowing access to the "objective" attribute of the Booster object on the C++ side
model_txt_lines <- strsplit(
x = bst$save_model_to_string()
, split = "\n"
)[[1L]]
expect_true(any(model_txt_lines == "objective=regression_l1"))
expect_false(any(model_txt_lines == "objective=regression_l2"))
}
})
test_that("lgb.cv() respects parameter aliases for metric", {
nrounds <- 3L
nfold <- 4L
......@@ -657,6 +685,31 @@ test_that("lgb.train() respects parameter aliases for objective", {
expect_equal(bst$params[["objective"]], "binary")
})
test_that("lgb.train() prefers objective in params to keyword argument", {
data("EuStockMarkets")
bst <- lgb.train(
data = lgb.Dataset(
data = EuStockMarkets[, c("SMI", "CAC", "FTSE")]
, label = EuStockMarkets[, "DAX"]
)
, params = list(
loss = "regression_l1"
, verbosity = VERBOSITY
)
, nrounds = 5L
, obj = "regression_l2"
)
expect_equal(bst$params$objective, "regression_l1")
# NOTE: using save_model_to_string() since that is the simplest public API in the R package
# allowing access to the "objective" attribute of the Booster object on the C++ side
model_txt_lines <- strsplit(
x = bst$save_model_to_string()
, split = "\n"
)[[1L]]
expect_true(any(model_txt_lines == "objective=regression_l1"))
expect_false(any(model_txt_lines == "objective=regression_l2"))
})
test_that("lgb.train() respects parameter aliases for metric", {
nrounds <- 3L
dtrain <- lgb.Dataset(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment