Unverified Commit aa647d47 authored by Michael Mayer's avatar Michael Mayer Committed by GitHub
Browse files

[R-package] added argument eval_train_metric to lgb.cv() (fixes #4911) (#4918)



* added argument eval_train_metric

* remove unnecessary whitespace

* removed further trailing whitespace

* move new argument to the last position

* update R docu

* unit tests for eval_train_metric

* Update R-package/tests/testthat/test_basic.R
Co-authored-by: James Lamb <jaylamb20@gmail.com>
Co-authored-by: James Lamb <jaylamb20@gmail.com>
parent ce486e5b
...@@ -43,6 +43,9 @@ CVBooster <- R6::R6Class( ...@@ -43,6 +43,9 @@ CVBooster <- R6::R6Class(
#' @param callbacks List of callback functions that are applied at each iteration. #' @param callbacks List of callback functions that are applied at each iteration.
#' @param reset_data Boolean, setting it to TRUE (not the default value) will transform the booster model #' @param reset_data Boolean, setting it to TRUE (not the default value) will transform the booster model
#' into a predictor model which frees up memory and the original datasets #' into a predictor model which frees up memory and the original datasets
#' @param eval_train_metric \code{boolean}, whether to add the cross validation results on the
#' training data. This parameter defaults to \code{FALSE}. Setting it to \code{TRUE}
#' will increase run time.
#' @inheritSection lgb_shared_params Early Stopping #' @inheritSection lgb_shared_params Early Stopping
#' @return a trained model \code{lgb.CVBooster}. #' @return a trained model \code{lgb.CVBooster}.
#' #'
...@@ -87,6 +90,7 @@ lgb.cv <- function(params = list() ...@@ -87,6 +90,7 @@ lgb.cv <- function(params = list()
, callbacks = list() , callbacks = list()
, reset_data = FALSE , reset_data = FALSE
, serializable = TRUE , serializable = TRUE
, eval_train_metric = FALSE
) { ) {
if (nrounds <= 0L) { if (nrounds <= 0L) {
...@@ -336,6 +340,9 @@ lgb.cv <- function(params = list() ...@@ -336,6 +340,9 @@ lgb.cv <- function(params = list()
} }
booster <- Booster$new(params = params, train_set = dtrain) booster <- Booster$new(params = params, train_set = dtrain)
if (isTRUE(eval_train_metric)) {
booster$add_valid(data = dtrain, name = "train")
}
booster$add_valid(data = dtest, name = "valid") booster$add_valid(data = dtest, name = "valid")
return( return(
list(booster = booster) list(booster = booster)
......
...@@ -25,7 +25,8 @@ lgb.cv( ...@@ -25,7 +25,8 @@ lgb.cv(
early_stopping_rounds = NULL, early_stopping_rounds = NULL,
callbacks = list(), callbacks = list(),
reset_data = FALSE, reset_data = FALSE,
serializable = TRUE serializable = TRUE,
eval_train_metric = FALSE
) )
} }
\arguments{ \arguments{
...@@ -120,6 +121,10 @@ into a predictor model which frees up memory and the original datasets} ...@@ -120,6 +121,10 @@ into a predictor model which frees up memory and the original datasets}
\item{serializable}{whether to make the resulting objects serializable through functions such as \item{serializable}{whether to make the resulting objects serializable through functions such as
\code{save} or \code{saveRDS} (see section "Model serialization").} \code{save} or \code{saveRDS} (see section "Model serialization").}
\item{eval_train_metric}{\code{boolean}, whether to add the cross validation results on the
training data. This parameter defaults to \code{FALSE}. Setting it to \code{TRUE}
will increase run time.}
} }
\value{ \value{
a trained model \code{lgb.CVBooster}. a trained model \code{lgb.CVBooster}.
......
...@@ -554,6 +554,45 @@ test_that("lgb.cv() respects showsd argument", { ...@@ -554,6 +554,45 @@ test_that("lgb.cv() respects showsd argument", {
expect_identical(evals_no_showsd[["eval_err"]], list()) expect_identical(evals_no_showsd[["eval_err"]], list())
}) })
test_that("lgb.cv() respects eval_train_metric argument", {
  # Fit two CV models that differ only in eval_train_metric: the validation
  # results must be identical, and training metrics must be recorded if and
  # only if eval_train_metric = TRUE.
  dtrain <- lgb.Dataset(train$data, label = train$label)
  params <- list(
    objective = "regression"
    , metric = "l2"
    , min_data = 1L
  )
  nrounds <- 5L
  set.seed(708L)
  bst_train <- lgb.cv(
    params = params
    , data = dtrain
    , nrounds = nrounds
    , nfold = 3L
    , showsd = FALSE
    , eval_train_metric = TRUE
  )
  set.seed(708L)
  bst_no_train <- lgb.cv(
    params = params
    , data = dtrain
    , nrounds = nrounds
    , nfold = 3L
    , showsd = FALSE
    , eval_train_metric = FALSE
  )
  # evaluating on the training folds must not change the validation results
  expect_equal(
    bst_train$record_evals[["valid"]][["l2"]]
    , bst_no_train$record_evals[["valid"]][["l2"]]
  )
  # "train" entry exists only when eval_train_metric = TRUE
  expect_true("train" %in% names(bst_train$record_evals))
  expect_false("train" %in% names(bst_no_train$record_evals))
  # one recorded training metric value per boosting round
  expect_type(bst_train$record_evals[["train"]][["l2"]][["eval"]], "list")
  expect_length(bst_train$record_evals[["train"]][["l2"]][["eval"]], nrounds)
})
context("lgb.train()") context("lgb.train()")
test_that("lgb.train() works as expected with multiple eval metrics", { test_that("lgb.train() works as expected with multiple eval metrics", {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment