Unverified Commit ac821f0c authored by James Lamb's avatar James Lamb Committed by GitHub
Browse files

[R-package] respect aliases for objective and metric and lgb.train() and lgb.cv() (#4913)

* [R-package] respect aliases for objective and metric

* move eval code closer to eval processing

* remove unnecessary diff

* Update R-package/tests/testthat/test_basic.R
parent af5b40e1
...@@ -105,12 +105,6 @@ lgb.cv <- function(params = list() ...@@ -105,12 +105,6 @@ lgb.cv <- function(params = list()
data <- lgb.Dataset(data = data, label = label) data <- lgb.Dataset(data = data, label = label)
} }
# Setup temporary variables
params <- lgb.check.obj(params = params, obj = obj)
params <- lgb.check.eval(params = params, eval = eval)
fobj <- NULL
eval_functions <- list(NULL)
# set some parameters, resolving the way they were passed in with other parameters # set some parameters, resolving the way they were passed in with other parameters
# in `params`. # in `params`.
# this ensures that the model stored with Booster$save() correctly represents # this ensures that the model stored with Booster$save() correctly represents
...@@ -125,6 +119,16 @@ lgb.cv <- function(params = list() ...@@ -125,6 +119,16 @@ lgb.cv <- function(params = list()
, params = params , params = params
, alternative_kwarg_value = nrounds , alternative_kwarg_value = nrounds
) )
params <- lgb.check.wrapper_param(
main_param_name = "metric"
, params = params
, alternative_kwarg_value = NULL
)
params <- lgb.check.wrapper_param(
main_param_name = "objective"
, params = params
, alternative_kwarg_value = NULL
)
params <- lgb.check.wrapper_param( params <- lgb.check.wrapper_param(
main_param_name = "early_stopping_round" main_param_name = "early_stopping_round"
, params = params , params = params
...@@ -132,7 +136,9 @@ lgb.cv <- function(params = list() ...@@ -132,7 +136,9 @@ lgb.cv <- function(params = list()
) )
early_stopping_rounds <- params[["early_stopping_round"]] early_stopping_rounds <- params[["early_stopping_round"]]
# Check for objective (function or not) # extract any function objects passed for objective or metric
params <- lgb.check.obj(params = params, obj = obj)
fobj <- NULL
if (is.function(params$objective)) { if (is.function(params$objective)) {
fobj <- params$objective fobj <- params$objective
params$objective <- "NONE" params$objective <- "NONE"
...@@ -142,6 +148,8 @@ lgb.cv <- function(params = list() ...@@ -142,6 +148,8 @@ lgb.cv <- function(params = list()
# (for backwards compatibility). If it is a list of functions, store # (for backwards compatibility). If it is a list of functions, store
# all of them. This makes it possible to pass any mix of strings like "auc" # all of them. This makes it possible to pass any mix of strings like "auc"
# and custom functions to eval # and custom functions to eval
params <- lgb.check.eval(params = params, eval = eval)
eval_functions <- list(NULL)
if (is.function(eval)) { if (is.function(eval)) {
eval_functions <- list(eval) eval_functions <- list(eval)
} }
......
...@@ -73,12 +73,6 @@ lgb.train <- function(params = list(), ...@@ -73,12 +73,6 @@ lgb.train <- function(params = list(),
} }
} }
# Setup temporary variables
params <- lgb.check.obj(params = params, obj = obj)
params <- lgb.check.eval(params = params, eval = eval)
fobj <- NULL
eval_functions <- list(NULL)
# set some parameters, resolving the way they were passed in with other parameters # set some parameters, resolving the way they were passed in with other parameters
# in `params`. # in `params`.
# this ensures that the model stored with Booster$save() correctly represents # this ensures that the model stored with Booster$save() correctly represents
...@@ -93,6 +87,16 @@ lgb.train <- function(params = list(), ...@@ -93,6 +87,16 @@ lgb.train <- function(params = list(),
, params = params , params = params
, alternative_kwarg_value = nrounds , alternative_kwarg_value = nrounds
) )
params <- lgb.check.wrapper_param(
main_param_name = "metric"
, params = params
, alternative_kwarg_value = NULL
)
params <- lgb.check.wrapper_param(
main_param_name = "objective"
, params = params
, alternative_kwarg_value = NULL
)
params <- lgb.check.wrapper_param( params <- lgb.check.wrapper_param(
main_param_name = "early_stopping_round" main_param_name = "early_stopping_round"
, params = params , params = params
...@@ -100,7 +104,9 @@ lgb.train <- function(params = list(), ...@@ -100,7 +104,9 @@ lgb.train <- function(params = list(),
) )
early_stopping_rounds <- params[["early_stopping_round"]] early_stopping_rounds <- params[["early_stopping_round"]]
# Check for objective (function or not) # extract any function objects passed for objective or metric
params <- lgb.check.obj(params = params, obj = obj)
fobj <- NULL
if (is.function(params$objective)) { if (is.function(params$objective)) {
fobj <- params$objective fobj <- params$objective
params$objective <- "NONE" params$objective <- "NONE"
...@@ -110,6 +116,8 @@ lgb.train <- function(params = list(), ...@@ -110,6 +116,8 @@ lgb.train <- function(params = list(),
# (for backwards compatibility). If it is a list of functions, store # (for backwards compatibility). If it is a list of functions, store
# all of them. This makes it possible to pass any mix of strings like "auc" # all of them. This makes it possible to pass any mix of strings like "auc"
# and custom functions to eval # and custom functions to eval
params <- lgb.check.eval(params = params, eval = eval)
eval_functions <- list(NULL)
if (is.function(eval)) { if (is.function(eval)) {
eval_functions <- list(eval) eval_functions <- list(eval)
} }
......
...@@ -547,6 +547,52 @@ test_that("lgb.cv() respects showsd argument", { ...@@ -547,6 +547,52 @@ test_that("lgb.cv() respects showsd argument", {
expect_identical(evals_no_showsd[["eval_err"]], list()) expect_identical(evals_no_showsd[["eval_err"]], list())
}) })
test_that("lgb.cv() respects parameter aliases for objective", {
nrounds <- 3L
nfold <- 4L
dtrain <- lgb.Dataset(
data = train$data
, label = train$label
)
cv_bst <- lgb.cv(
data = dtrain
, params = list(
num_leaves = 5L
, application = "binary"
, num_iterations = nrounds
)
, nfold = nfold
)
expect_equal(cv_bst$best_iter, nrounds)
expect_named(cv_bst$record_evals[["valid"]], "binary_logloss")
expect_length(cv_bst$record_evals[["valid"]][["binary_logloss"]][["eval"]], nrounds)
expect_length(cv_bst$boosters, nfold)
})
test_that("lgb.cv() respects parameter aliases for metric", {
nrounds <- 3L
nfold <- 4L
dtrain <- lgb.Dataset(
data = train$data
, label = train$label
)
cv_bst <- lgb.cv(
data = dtrain
, params = list(
num_leaves = 5L
, objective = "binary"
, num_iterations = nrounds
, metric_types = c("auc", "binary_logloss")
)
, nfold = nfold
)
expect_equal(cv_bst$best_iter, nrounds)
expect_named(cv_bst$record_evals[["valid"]], c("auc", "binary_logloss"))
expect_length(cv_bst$record_evals[["valid"]][["binary_logloss"]][["eval"]], nrounds)
expect_length(cv_bst$record_evals[["valid"]][["auc"]][["eval"]], nrounds)
expect_length(cv_bst$boosters, nfold)
})
test_that("lgb.cv() respects eval_train_metric argument", { test_that("lgb.cv() respects eval_train_metric argument", {
dtrain <- lgb.Dataset(train$data, label = train$label) dtrain <- lgb.Dataset(train$data, label = train$label)
params <- list( params <- list(
...@@ -616,6 +662,53 @@ test_that("lgb.train() works as expected with multiple eval metrics", { ...@@ -616,6 +662,53 @@ test_that("lgb.train() works as expected with multiple eval metrics", {
) )
}) })
test_that("lgb.train() respects parameter aliases for objective", {
nrounds <- 3L
dtrain <- lgb.Dataset(
data = train$data
, label = train$label
)
bst <- lgb.train(
data = dtrain
, params = list(
num_leaves = 5L
, application = "binary"
, num_iterations = nrounds
)
, valids = list(
"the_training_data" = dtrain
)
)
expect_named(bst$record_evals[["the_training_data"]], "binary_logloss")
expect_length(bst$record_evals[["the_training_data"]][["binary_logloss"]][["eval"]], nrounds)
expect_equal(bst$params[["objective"]], "binary")
})
test_that("lgb.train() respects parameter aliases for metric", {
nrounds <- 3L
dtrain <- lgb.Dataset(
data = train$data
, label = train$label
)
bst <- lgb.train(
data = dtrain
, params = list(
num_leaves = 5L
, objective = "binary"
, num_iterations = nrounds
, metric_types = c("auc", "binary_logloss")
)
, valids = list(
"train" = dtrain
)
)
record_results <- bst$record_evals[["train"]]
expect_equal(sort(names(record_results)), c("auc", "binary_logloss"))
expect_length(record_results[["auc"]][["eval"]], nrounds)
expect_length(record_results[["binary_logloss"]][["eval"]], nrounds)
expect_equal(bst$params[["metric"]], list("auc", "binary_logloss"))
})
test_that("lgb.train() rejects negative or 0 value passed to nrounds", { test_that("lgb.train() rejects negative or 0 value passed to nrounds", {
dtrain <- lgb.Dataset(train$data, label = train$label) dtrain <- lgb.Dataset(train$data, label = train$label)
params <- list( params <- list(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment