Unverified Commit 9e73cee8 authored by James Lamb's avatar James Lamb Committed by GitHub
Browse files

[R-package] prefer params to keyword argument in `lgb.train()` (#5007)



* [R-package] prefer params to keyword argument in lgb.train()

* make test stricter

* Apply suggestions from code review
Co-authored-by: default avatarNikita Titov <nekit94-08@mail.ru>
Co-authored-by: default avatarNikita Titov <nekit94-08@mail.ru>
parent cb8c61e4
...@@ -127,7 +127,7 @@ lgb.cv <- function(params = list() ...@@ -127,7 +127,7 @@ lgb.cv <- function(params = list()
params <- lgb.check.wrapper_param( params <- lgb.check.wrapper_param(
main_param_name = "objective" main_param_name = "objective"
, params = params , params = params
, alternative_kwarg_value = NULL , alternative_kwarg_value = obj
) )
params <- lgb.check.wrapper_param( params <- lgb.check.wrapper_param(
main_param_name = "early_stopping_round" main_param_name = "early_stopping_round"
...@@ -137,7 +137,7 @@ lgb.cv <- function(params = list() ...@@ -137,7 +137,7 @@ lgb.cv <- function(params = list()
early_stopping_rounds <- params[["early_stopping_round"]] early_stopping_rounds <- params[["early_stopping_round"]]
# extract any function objects passed for objective or metric # extract any function objects passed for objective or metric
params <- lgb.check.obj(params = params, obj = obj) params <- lgb.check.obj(params = params)
fobj <- NULL fobj <- NULL
if (is.function(params$objective)) { if (is.function(params$objective)) {
fobj <- params$objective fobj <- params$objective
......
...@@ -95,7 +95,7 @@ lgb.train <- function(params = list(), ...@@ -95,7 +95,7 @@ lgb.train <- function(params = list(),
params <- lgb.check.wrapper_param( params <- lgb.check.wrapper_param(
main_param_name = "objective" main_param_name = "objective"
, params = params , params = params
, alternative_kwarg_value = NULL , alternative_kwarg_value = obj
) )
params <- lgb.check.wrapper_param( params <- lgb.check.wrapper_param(
main_param_name = "early_stopping_round" main_param_name = "early_stopping_round"
...@@ -105,7 +105,7 @@ lgb.train <- function(params = list(), ...@@ -105,7 +105,7 @@ lgb.train <- function(params = list(),
early_stopping_rounds <- params[["early_stopping_round"]] early_stopping_rounds <- params[["early_stopping_round"]]
# extract any function objects passed for objective or metric # extract any function objects passed for objective or metric
params <- lgb.check.obj(params = params, obj = obj) params <- lgb.check.obj(params = params)
fobj <- NULL fobj <- NULL
if (is.function(params$objective)) { if (is.function(params$objective)) {
fobj <- params$objective fobj <- params$objective
......
...@@ -117,7 +117,7 @@ lgb.check_interaction_constraints <- function(interaction_constraints, column_na ...@@ -117,7 +117,7 @@ lgb.check_interaction_constraints <- function(interaction_constraints, column_na
} }
lgb.check.obj <- function(params, obj) { lgb.check.obj <- function(params) {
# List known objectives in a vector # List known objectives in a vector
OBJECTIVES <- c( OBJECTIVES <- c(
...@@ -158,25 +158,18 @@ lgb.check.obj <- function(params, obj) { ...@@ -158,25 +158,18 @@ lgb.check.obj <- function(params, obj) {
, "xendcg_mart" , "xendcg_mart"
) )
# Check whether the objective is empty or not, and take it from params if needed if (is.null(params$objective)) {
if (!is.null(obj)) { stop("lgb.check.obj: objective should be a character or a function")
params$objective <- obj
} }
# Check whether the objective is a character
if (is.character(params$objective)) { if (is.character(params$objective)) {
# If the objective is a character, check if it is a known objective
if (!(params$objective %in% OBJECTIVES)) { if (!(params$objective %in% OBJECTIVES)) {
stop("lgb.check.obj: objective name error should be one of (", paste0(OBJECTIVES, collapse = ", "), ")") stop("lgb.check.obj: objective name error should be one of (", paste0(OBJECTIVES, collapse = ", "), ")")
} }
} else if (!is.function(params$objective)) {
stop("lgb.check.obj: objective should be a character or a function")
} }
return(params) return(params)
......
...@@ -542,6 +542,34 @@ test_that("lgb.cv() respects parameter aliases for objective", { ...@@ -542,6 +542,34 @@ test_that("lgb.cv() respects parameter aliases for objective", {
expect_length(cv_bst$boosters, nfold) expect_length(cv_bst$boosters, nfold)
}) })
test_that("lgb.cv() prefers objective in params to keyword argument", {
data("EuStockMarkets")
cv_bst <- lgb.cv(
data = lgb.Dataset(
data = EuStockMarkets[, c("SMI", "CAC", "FTSE")]
, label = EuStockMarkets[, "DAX"]
)
, params = list(
application = "regression_l1"
, verbosity = VERBOSITY
)
, nrounds = 5L
, obj = "regression_l2"
)
for (bst_list in cv_bst$boosters) {
bst <- bst_list[["booster"]]
expect_equal(bst$params$objective, "regression_l1")
# NOTE: using save_model_to_string() since that is the simplest public API in the R package
# allowing access to the "objective" attribute of the Booster object on the C++ side
model_txt_lines <- strsplit(
x = bst$save_model_to_string()
, split = "\n"
)[[1L]]
expect_true(any(model_txt_lines == "objective=regression_l1"))
expect_false(any(model_txt_lines == "objective=regression_l2"))
}
})
test_that("lgb.cv() respects parameter aliases for metric", { test_that("lgb.cv() respects parameter aliases for metric", {
nrounds <- 3L nrounds <- 3L
nfold <- 4L nfold <- 4L
...@@ -657,6 +685,31 @@ test_that("lgb.train() respects parameter aliases for objective", { ...@@ -657,6 +685,31 @@ test_that("lgb.train() respects parameter aliases for objective", {
expect_equal(bst$params[["objective"]], "binary") expect_equal(bst$params[["objective"]], "binary")
}) })
test_that("lgb.train() prefers objective in params to keyword argument", {
data("EuStockMarkets")
bst <- lgb.train(
data = lgb.Dataset(
data = EuStockMarkets[, c("SMI", "CAC", "FTSE")]
, label = EuStockMarkets[, "DAX"]
)
, params = list(
loss = "regression_l1"
, verbosity = VERBOSITY
)
, nrounds = 5L
, obj = "regression_l2"
)
expect_equal(bst$params$objective, "regression_l1")
# NOTE: using save_model_to_string() since that is the simplest public API in the R package
# allowing access to the "objective" attribute of the Booster object on the C++ side
model_txt_lines <- strsplit(
x = bst$save_model_to_string()
, split = "\n"
)[[1L]]
expect_true(any(model_txt_lines == "objective=regression_l1"))
expect_false(any(model_txt_lines == "objective=regression_l2"))
})
test_that("lgb.train() respects parameter aliases for metric", { test_that("lgb.train() respects parameter aliases for metric", {
nrounds <- 3L nrounds <- 3L
dtrain <- lgb.Dataset( dtrain <- lgb.Dataset(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment