Unverified Commit c6245ed4 authored by James Lamb's avatar James Lamb Committed by GitHub
Browse files

[R-package] fixed evaluation on valids in lightgbm() (fixes #2915) (#2916)

parent eb027e1b
...@@ -73,24 +73,31 @@ lightgbm <- function(data, ...@@ -73,24 +73,31 @@ lightgbm <- function(data,
dtrain <- lgb.Dataset(data, label = label, weight = weight) dtrain <- lgb.Dataset(data, label = label, weight = weight)
} }
train_args <- list(
"params" = params
, "data" = dtrain
, "nrounds" = nrounds
, "verbose" = verbose
, "eval_freq" = eval_freq
, "early_stopping_rounds" = early_stopping_rounds
, "init_model" = init_model
, "callbacks" = callbacks
)
train_args <- append(train_args, list(...))
if (! "valids" %in% names(train_args)) {
train_args[["valids"]] <- list()
}
# Set validation as oneself # Set validation as oneself
valids <- list()
if (verbose > 0L) { if (verbose > 0L) {
valids$train <- dtrain train_args[["valids"]][["train"]] <- dtrain
} }
# Train a model using the regular way # Train a model using the regular way
bst <- lgb.train( bst <- do.call(
params = params what = lgb.train
, data = dtrain , args = train_args
, nrounds = nrounds
, valids = valids
, verbose = verbose
, eval_freq = eval_freq
, early_stopping_rounds = early_stopping_rounds
, init_model = init_model
, callbacks = callbacks
, ...
) )
# Store model under a specific name # Store model under a specific name
......
...@@ -121,6 +121,47 @@ test_that("lightgbm() rejects negative or 0 value passed to nrounds", { ...@@ -121,6 +121,47 @@ test_that("lightgbm() rejects negative or 0 value passed to nrounds", {
} }
}) })
test_that("lightgbm() performs evaluation on validation sets if they are provided", {
set.seed(708L)
dvalid1 <- lgb.Dataset(
data = train$data
, labels = train$label
)
dvalid2 <- lgb.Dataset(
data = train$data
, labels = train$label
)
nrounds <- 10L
bst <- lightgbm(
data = train$data
, label = train$label
, num_leaves = 5L
, nrounds = nrounds
, objective = "binary"
, metric = "binary_error"
, valids = list(
"valid1" = dvalid1
, "valid2" = dvalid2
)
)
expect_named(
bst$record_evals
, c("train", "valid1", "valid2", "start_iter")
, ignore.order = TRUE
, ignore.case = FALSE
)
for (valid_name in c("train", "valid1", "valid2")) {
eval_results <- bst$record_evals[[valid_name]][["binary_error"]]
expect_length(eval_results[["eval"]], nrounds)
}
expect_true(abs(bst$record_evals[["train"]][["binary_error"]][["eval"]][[1L]] - 0.02226317) < TOLERANCE)
expect_true(abs(bst$record_evals[["valid1"]][["binary_error"]][["eval"]][[1L]] - 0.4825733) < TOLERANCE)
expect_true(abs(bst$record_evals[["valid2"]][["binary_error"]][["eval"]][[1L]] - 0.4825733) < TOLERANCE)
})
context("training continuation")
test_that("training continuation works", { test_that("training continuation works", {
testthat::skip("This test is currently broken. See issue #2468 for details.") testthat::skip("This test is currently broken. See issue #2468 for details.")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment