Unverified Commit 675b552d authored by James Lamb's avatar James Lamb Committed by GitHub
Browse files

[R-package] fix best_score using custom evaluation (fixes #3112) (#3117)

parent 9d433033
...@@ -298,7 +298,15 @@ lgb.train <- function(params = list(), ...@@ -298,7 +298,15 @@ lgb.train <- function(params = list(),
# When early stopping is not activated, we compute the best iteration / score ourselves by # When early stopping is not activated, we compute the best iteration / score ourselves by
# selecting the first metric and the first dataset # selecting the first metric and the first dataset
if (record && length(non_train_valid_names) > 0L && is.na(env$best_score)) { if (record && length(non_train_valid_names) > 0L && is.na(env$best_score)) {
# when using a custom eval function, the metric name is returned from the
# function, so figure it out from record_evals
if (!is.null(feval)) {
first_metric <- names(booster$record_evals[[first_valid_name]])[1L]
} else {
first_metric <- booster$.__enclos_env__$private$eval_names[1L] first_metric <- booster$.__enclos_env__$private$eval_names[1L]
}
.find_best <- which.min .find_best <- which.min
if (isTRUE(env$eval_list[[1L]]$higher_better[1L])) { if (isTRUE(env$eval_list[[1L]]$higher_better[1L])) {
.find_best <- which.max .find_best <- which.max
......
...@@ -6,6 +6,8 @@ dtrain <- lgb.Dataset(agaricus.train$data, label = agaricus.train$label) ...@@ -6,6 +6,8 @@ dtrain <- lgb.Dataset(agaricus.train$data, label = agaricus.train$label)
dtest <- lgb.Dataset(agaricus.test$data, label = agaricus.test$label) dtest <- lgb.Dataset(agaricus.test$data, label = agaricus.test$label)
watchlist <- list(eval = dtest, train = dtrain) watchlist <- list(eval = dtest, train = dtrain)
# Tolerance used for approximate floating-point comparisons in the tests below
TOLERANCE <- 1e-6
logregobj <- function(preds, dtrain) { logregobj <- function(preds, dtrain) {
labels <- getinfo(dtrain, "label") labels <- getinfo(dtrain, "label")
preds <- 1.0 / (1.0 + exp(-preds)) preds <- 1.0 / (1.0 + exp(-preds))
...@@ -41,3 +43,27 @@ test_that("custom objective works", { ...@@ -41,3 +43,27 @@ test_that("custom objective works", {
bst <- lgb.train(param, dtrain, num_round, watchlist, eval = evalerror) bst <- lgb.train(param, dtrain, num_round, watchlist, eval = evalerror)
expect_false(is.null(bst$record_evals)) expect_false(is.null(bst$record_evals))
}) })
test_that("using a custom objective, custom eval, and no other metrics works", {
  set.seed(708L)
  n_rounds <- 4L
  # Train with both a custom objective and a custom eval function, and no
  # built-in metric configured in params
  bst <- lgb.train(
    params = list(
      num_leaves = 8L
      , learning_rate = 1.0
    )
    , data = dtrain
    , nrounds = n_rounds
    , valids = watchlist
    , obj = logregobj
    , eval = evalerror
  )

  # evaluation history should have been recorded
  expect_false(is.null(bst$record_evals))
  # without early stopping, the best iteration is the final one
  expect_equal(bst$best_iter, n_rounds)
  expect_true(abs(bst$best_score - 0.000621) < TOLERANCE)

  # the custom eval function's results should be retrievable, with the
  # metric name coming from the function itself
  res <- bst$eval_valid(feval = evalerror)[[1L]]
  expect_true(res[["data_name"]] == "eval")
  expect_true(res[["name"]] == "error")
  expect_true(abs(res[["value"]] - 0.0006207325) < TOLERANCE)
  expect_false(res[["higher_better"]])
})
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment