Unverified Commit 8d43356a authored by James Lamb's avatar James Lamb Committed by GitHub
Browse files

[R-package] fixed handling of multiple evaluation metrics (fixes #2913) (#2914)

parent 651375d5
...@@ -136,8 +136,17 @@ lgb.params2str <- function(params, ...) { ...@@ -136,8 +136,17 @@ lgb.params2str <- function(params, ...) {
# Perform key value join # Perform key value join
for (key in names(params)) { for (key in names(params)) {
# Join multi value first # If a parameter has multiple values, join those values together with commas.
val <- paste0(format(params[[key]], scientific = FALSE), collapse = ",") # trimws() is necessary because format() will pad to make strings the same width
val <- paste0(
trimws(
format(
x = params[[key]]
, scientific = FALSE
)
)
, collapse = ","
)
if (nchar(val) <= 0L) next # Skip join if (nchar(val) <= 0L) next # Skip join
# Join key value # Join key value
...@@ -148,17 +157,12 @@ lgb.params2str <- function(params, ...) { ...@@ -148,17 +157,12 @@ lgb.params2str <- function(params, ...) {
# Check ret length # Check ret length
if (length(ret) == 0L) { if (length(ret) == 0L) {
return(lgb.c_str(""))
# Return empty string
lgb.c_str("")
} else {
# Return string separated by a space per element
lgb.c_str(paste0(ret, collapse = " "))
} }
# Return string separated by a space per element
return(lgb.c_str(paste0(ret, collapse = " ")))
} }
lgb.c_str <- function(x) { lgb.c_str <- function(x) {
......
...@@ -58,6 +58,7 @@ test_that("train and predict softmax", { ...@@ -58,6 +58,7 @@ test_that("train and predict softmax", {
test_that("use of multiple eval metrics works", { test_that("use of multiple eval metrics works", {
metrics <- list("binary_error", "auc", "binary_logloss")
bst <- lightgbm( bst <- lightgbm(
data = train$data data = train$data
, label = train$label , label = train$label
...@@ -65,9 +66,15 @@ test_that("use of multiple eval metrics works", { ...@@ -65,9 +66,15 @@ test_that("use of multiple eval metrics works", {
, learning_rate = 1.0 , learning_rate = 1.0
, nrounds = 10L , nrounds = 10L
, objective = "binary" , objective = "binary"
, metric = list("binary_error", "auc", "binary_logloss") , metric = metrics
) )
expect_false(is.null(bst$record_evals)) expect_false(is.null(bst$record_evals))
expect_named(
bst$record_evals[["train"]]
, unlist(metrics)
, ignore.order = FALSE
, ignore.case = FALSE
)
}) })
test_that("lgb.Booster.upper_bound() and lgb.Booster.lower_bound() work as expected for binary classification", { test_that("lgb.Booster.upper_bound() and lgb.Booster.lower_bound() work as expected for binary classification", {
...@@ -204,6 +211,35 @@ test_that("lgb.cv() throws an informative error is 'data' is not an lgb.Dataset ...@@ -204,6 +211,35 @@ test_that("lgb.cv() throws an informative error is 'data' is not an lgb.Dataset
context("lgb.train()") context("lgb.train()")
test_that("lgb.train() works as expected with multiple eval metrics", {
metrics <- c("binary_error", "auc", "binary_logloss")
bst <- lgb.train(
data = lgb.Dataset(
train$data
, label = train$label
)
, learning_rate = 1.0
, nrounds = 10L
, params = list(
objective = "binary"
, metric = metrics
)
, valids = list(
"train" = lgb.Dataset(
train$data
, label = train$label
)
)
)
expect_false(is.null(bst$record_evals))
expect_named(
bst$record_evals[["train"]]
, unlist(metrics)
, ignore.order = FALSE
, ignore.case = FALSE
)
})
test_that("lgb.train() rejects negative or 0 value passed to nrounds", { test_that("lgb.train() rejects negative or 0 value passed to nrounds", {
dtrain <- lgb.Dataset(train$data, label = train$label) dtrain <- lgb.Dataset(train$data, label = train$label)
params <- list(objective = "regression", metric = "l2,l1") params <- list(objective = "regression", metric = "l2,l1")
......
...@@ -19,3 +19,32 @@ test_that("lgb.check.r6.class() should correctly identify lgb.Dataset", { ...@@ -19,3 +19,32 @@ test_that("lgb.check.r6.class() should correctly identify lgb.Dataset", {
expect_false(lgb.check.r6.class(ds, "lgb.Predictor")) expect_false(lgb.check.r6.class(ds, "lgb.Predictor"))
expect_false(lgb.check.r6.class(ds, "lgb.Booster")) expect_false(lgb.check.r6.class(ds, "lgb.Booster"))
}) })
context("lgb.params2str")
test_that("lgb.params2str() works as expected for empty lists", {
out_str <- lgb.params2str(
params = list()
)
expect_identical(class(out_str), "raw")
expect_equal(out_str, lgb.c_str(""))
})
test_that("lgb.params2str() works as expected for a key in params with multiple different-length elements", {
metrics <- c("a", "ab", "abc", "abcdefg")
params <- list(
objective = "magic"
, metric = metrics
, nrounds = 10L
, learning_rate = 0.0000001
)
out_str <- lgb.params2str(
params = params
)
expect_identical(class(out_str), "raw")
out_as_char <- rawToChar(out_str)
expect_identical(
out_as_char
, "objective=magic metric=a,ab,abc,abcdefg nrounds=10 learning_rate=0.0000001"
)
})
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment