Unverified Commit aa647d47 authored by Michael Mayer's avatar Michael Mayer Committed by GitHub
Browse files

[R-package] added argument eval_train_metric to lgb.cv() (fixes #4911) (#4918)



* added argument eval_train_metric

* remove unnecessary whitespace

* removed further trailing whitespace

* move new argument to the last position

* update R documentation

* unit tests for eval_train_metric

* Update R-package/tests/testthat/test_basic.R
Co-authored-by: default avatarJames Lamb <jaylamb20@gmail.com>
Co-authored-by: default avatarJames Lamb <jaylamb20@gmail.com>
parent ce486e5b
......@@ -43,6 +43,9 @@ CVBooster <- R6::R6Class(
#' @param callbacks List of callback functions that are applied at each iteration.
#' @param reset_data Boolean, setting it to TRUE (not the default value) will transform the booster model
#' into a predictor model which frees up memory and the original datasets
#' @param eval_train_metric \code{boolean}, whether to also record the cross validation metrics
#'        computed on the training data of each fold. This parameter defaults to \code{FALSE}.
#'        Setting it to \code{TRUE} will increase run time.
#' @inheritSection lgb_shared_params Early Stopping
#' @return a trained model \code{lgb.CVBooster}.
#'
......@@ -87,6 +90,7 @@ lgb.cv <- function(params = list()
, callbacks = list()
, reset_data = FALSE
, serializable = TRUE
, eval_train_metric = FALSE
) {
if (nrounds <= 0L) {
......@@ -336,6 +340,9 @@ lgb.cv <- function(params = list()
}
booster <- Booster$new(params = params, train_set = dtrain)
if (isTRUE(eval_train_metric)) {
booster$add_valid(data = dtrain, name = "train")
}
booster$add_valid(data = dtest, name = "valid")
return(
list(booster = booster)
......
......@@ -25,7 +25,8 @@ lgb.cv(
early_stopping_rounds = NULL,
callbacks = list(),
reset_data = FALSE,
serializable = TRUE
serializable = TRUE,
eval_train_metric = FALSE
)
}
\arguments{
......@@ -120,6 +121,10 @@ into a predictor model which frees up memory and the original datasets}
\item{serializable}{whether to make the resulting objects serializable through functions such as
\code{save} or \code{saveRDS} (see section "Model serialization").}
\item{eval_train_metric}{\code{boolean}, whether to also record the cross validation metrics
computed on the training data of each fold. This parameter defaults to \code{FALSE}.
Setting it to \code{TRUE} will increase run time.}
}
\value{
a trained model \code{lgb.CVBooster}.
......
......@@ -554,6 +554,45 @@ test_that("lgb.cv() respects showsd argument", {
expect_identical(evals_no_showsd[["eval_err"]], list())
})
# Verify that lgb.cv() records training-fold metrics only when
# eval_train_metric = TRUE, and that validation metrics are unaffected
# by the flag.
test_that("lgb.cv() respects eval_train_metric argument", {
  dtrain <- lgb.Dataset(train$data, label = train$label)
  cv_params <- list(
    objective = "regression"
    , metric = "l2"
    , min_data = 1L
  )
  num_rounds <- 5L
  # Run CV with a fixed seed so the only difference between the two
  # runs is the eval_train_metric flag.
  run_cv <- function(with_train_metric) {
    set.seed(708L)
    lgb.cv(
      params = cv_params
      , data = dtrain
      , nrounds = num_rounds
      , nfold = 3L
      , showsd = FALSE
      , eval_train_metric = with_train_metric
    )
  }
  cv_with_train <- run_cv(TRUE)
  cv_without_train <- run_cv(FALSE)
  # The flag must not change the validation results.
  expect_equal(
    cv_with_train$record_evals[["valid"]][["l2"]]
    , cv_without_train$record_evals[["valid"]][["l2"]]
  )
  # Training metrics are recorded only when requested.
  expect_true("train" %in% names(cv_with_train$record_evals))
  expect_false("train" %in% names(cv_without_train$record_evals))
  # One recorded training metric value per boosting round.
  expect_true(methods::is(cv_with_train$record_evals[["train"]][["l2"]][["eval"]], "list"))
  expect_equal(
    length(cv_with_train$record_evals[["train"]][["l2"]][["eval"]])
    , num_rounds
  )
})
context("lgb.train()")
test_that("lgb.train() works as expected with multiple eval metrics", {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment