Unverified Commit d9064282 authored by James Lamb's avatar James Lamb Committed by GitHub
Browse files

[R-package] moved parameter validations up earlier in function calls (#2663)

parent 08fd53cd
...@@ -29,11 +29,10 @@ cb.reset.parameters <- function(new_params) { ...@@ -29,11 +29,10 @@ cb.reset.parameters <- function(new_params) {
# Run some checks in the beginning # Run some checks in the beginning
init <- function(env) { init <- function(env) {
# Store boosting rounds
nrounds <<- env$end_iteration - env$begin_iteration + 1L
# Check for model environment # Check for model environment
if (is.null(env$model)) { stop("Env should have a ", sQuote("model")) } if (is.null(env$model)) {
stop("Env should have a ", sQuote("model"))
}
# Some parameters are not allowed to be changed, # Some parameters are not allowed to be changed,
# since changing them would simply wreck some chaos # since changing them would simply wreck some chaos
...@@ -50,6 +49,9 @@ cb.reset.parameters <- function(new_params) { ...@@ -50,6 +49,9 @@ cb.reset.parameters <- function(new_params) {
) )
} }
# Store boosting rounds
nrounds <<- env$end_iteration - env$begin_iteration + 1L
# Check parameter names # Check parameter names
for (n in pnames) { for (n in pnames) {
...@@ -285,14 +287,14 @@ cb.early.stop <- function(stopping_rounds, verbose = TRUE) { ...@@ -285,14 +287,14 @@ cb.early.stop <- function(stopping_rounds, verbose = TRUE) {
# Initialization function # Initialization function
init <- function(env) { init <- function(env) {
# Store evaluation length
eval_len <<- length(env$eval_list)
# Early stopping cannot work without metrics # Early stopping cannot work without metrics
if (eval_len == 0L) { if (length(env$eval_list) == 0L) {
stop("For early stopping, valids must have at least one element") stop("For early stopping, valids must have at least one element")
} }
# Store evaluation length
eval_len <<- length(env$eval_list)
# Check if verbose or not # Check if verbose or not
if (isTRUE(verbose)) { if (isTRUE(verbose)) {
cat("Will train until there is no improvement in ", stopping_rounds, " rounds.\n\n", sep = "") cat("Will train until there is no improvement in ", stopping_rounds, " rounds.\n\n", sep = "")
......
...@@ -781,15 +781,21 @@ lgb.load <- function(filename = NULL, model_str = NULL) { ...@@ -781,15 +781,21 @@ lgb.load <- function(filename = NULL, model_str = NULL) {
} }
# Return new booster # Return new booster
if (!is.null(filename) && !file.exists(filename)) stop("lgb.load: file does not exist for supplied filename") if (!is.null(filename) && !file.exists(filename)) {
if (!is.null(filename)) return(invisible(Booster$new(modelfile = filename))) stop("lgb.load: file does not exist for supplied filename")
}
if (!is.null(filename)) {
return(invisible(Booster$new(modelfile = filename)))
}
# Load from model_str # Load from model_str
if (!is.null(model_str) && !is.character(model_str)) { if (!is.null(model_str) && !is.character(model_str)) {
stop("lgb.load: model_str should be character") stop("lgb.load: model_str should be character")
} }
# Return new booster # Return new booster
if (!is.null(model_str)) return(invisible(Booster$new(model_str = model_str))) if (!is.null(model_str)) {
return(invisible(Booster$new(model_str = model_str)))
}
} }
...@@ -831,8 +837,8 @@ lgb.save <- function(booster, filename, num_iteration = NULL) { ...@@ -831,8 +837,8 @@ lgb.save <- function(booster, filename, num_iteration = NULL) {
} }
# Check if file name is character # Check if file name is character
if (!is.character(filename)) { if (!(is.character(filename) && length(filename) == 1L)) {
stop("lgb.save: filename should be a character") stop("lgb.save: filename should be a string")
} }
# Store booster # Store booster
......
...@@ -32,6 +32,14 @@ Dataset <- R6::R6Class( ...@@ -32,6 +32,14 @@ Dataset <- R6::R6Class(
info = list(), info = list(),
...) { ...) {
# validate inputs early to avoid unnecessary computation
if (!(is.null(reference) || lgb.check.r6.class(reference, "lgb.Dataset"))) {
stop("lgb.Dataset: If provided, reference must be a ", sQuote("lgb.Dataset"))
}
if (!(is.null(predictor) || lgb.check.r6.class(predictor, "lgb.Predictor"))) {
stop("lgb.Dataset: If provided, predictor must be a ", sQuote("lgb.Predictor"))
}
# Check for additional parameters # Check for additional parameters
additional_params <- list(...) additional_params <- list(...)
...@@ -56,20 +64,6 @@ Dataset <- R6::R6Class( ...@@ -56,20 +64,6 @@ Dataset <- R6::R6Class(
} }
# Check for dataset reference
if (!is.null(reference)) {
if (!lgb.check.r6.class(reference, "lgb.Dataset")) {
stop("lgb.Dataset: Can only use ", sQuote("lgb.Dataset"), " as reference")
}
}
# Check for predictor reference
if (!is.null(predictor)) {
if (!lgb.check.r6.class(predictor, "lgb.Predictor")) {
stop("lgb.Dataset: Only can use ", sQuote("lgb.Predictor"), " as predictor")
}
}
# Check for matrix format # Check for matrix format
if (is.matrix(data)) { if (is.matrix(data)) {
# Check whether matrix is the correct type first ("double") # Check whether matrix is the correct type first ("double")
......
...@@ -22,7 +22,7 @@ CVBooster <- R6::R6Class( ...@@ -22,7 +22,7 @@ CVBooster <- R6::R6Class(
#' @description Cross validation logic used by LightGBM #' @description Cross validation logic used by LightGBM
#' @inheritParams lgb_shared_params #' @inheritParams lgb_shared_params
#' @param nfold the original dataset is randomly partitioned into \code{nfold} equal size subsamples. #' @param nfold the original dataset is randomly partitioned into \code{nfold} equal size subsamples.
#' @param label vector of response values. Should be provided only when data is an R-matrix. #' @param label Vector of labels, used if \code{data} is not an \code{\link{lgb.Dataset}}
#' @param weight vector of response values. If not NULL, will set to dataset #' @param weight vector of response values. If not NULL, will set to dataset
#' @param obj objective function, can be character or custom objective function. Examples include #' @param obj objective function, can be character or custom objective function. Examples include
#' \code{regression}, \code{regression_l1}, \code{huber}, #' \code{regression}, \code{regression_l1}, \code{huber},
...@@ -95,6 +95,19 @@ lgb.cv <- function(params = list() ...@@ -95,6 +95,19 @@ lgb.cv <- function(params = list()
, ... , ...
) { ) {
# validate parameters
if (nrounds <= 0L) {
stop("nrounds should be greater than zero")
}
# If 'data' is not an lgb.Dataset, try to construct one using 'label'
if (!lgb.is.Dataset(data)) {
if (is.null(label)) {
stop("'label' must be provided for lgb.cv if 'data' is not an 'lgb.Dataset'")
}
data <- lgb.Dataset(data, label = label)
}
# Setup temporary variables # Setup temporary variables
params <- append(params, list(...)) params <- append(params, list(...))
params$verbose <- verbose params$verbose <- verbose
...@@ -103,10 +116,6 @@ lgb.cv <- function(params = list() ...@@ -103,10 +116,6 @@ lgb.cv <- function(params = list()
fobj <- NULL fobj <- NULL
feval <- NULL feval <- NULL
if (nrounds <= 0L) {
stop("nrounds should be greater than zero")
}
# Check for objective (function or not) # Check for objective (function or not)
if (is.function(params$objective)) { if (is.function(params$objective)) {
fobj <- params$objective fobj <- params$objective
...@@ -141,14 +150,6 @@ lgb.cv <- function(params = list() ...@@ -141,14 +150,6 @@ lgb.cv <- function(params = list()
end_iteration <- begin_iteration + nrounds - 1L end_iteration <- begin_iteration + nrounds - 1L
} }
# Check for training dataset type correctness
if (!lgb.is.Dataset(data)) {
if (is.null(label)) {
stop("Labels must be provided for lgb.cv")
}
data <- lgb.Dataset(data, label = label)
}
# Check for weights # Check for weights
if (!is.null(weight)) { if (!is.null(weight)) {
data$setinfo("weight", weight) data$setinfo("weight", weight)
......
...@@ -36,7 +36,7 @@ ...@@ -36,7 +36,7 @@
lgb.importance <- function(model, percentage = TRUE) { lgb.importance <- function(model, percentage = TRUE) {
# Check if model is a lightgbm model # Check if model is a lightgbm model
if (!inherits(model, "lgb.Booster")) { if (!lgb.is.Booster(model)) {
stop("'model' has to be an object of class lgb.Booster") stop("'model' has to be an object of class lgb.Booster")
} }
......
...@@ -65,6 +65,23 @@ lgb.train <- function(params = list(), ...@@ -65,6 +65,23 @@ lgb.train <- function(params = list(),
reset_data = FALSE, reset_data = FALSE,
...) { ...) {
# validate inputs early to avoid unnecessary computation
if (nrounds <= 0L) {
stop("nrounds should be greater than zero")
}
if (!lgb.is.Dataset(data)) {
stop("lgb.train: data must be an lgb.Dataset instance")
}
if (length(valids) > 0L) {
if (!is.list(valids) || !all(vapply(valids, lgb.is.Dataset, logical(1L)))) {
stop("lgb.train: valids must be a list of lgb.Dataset elements")
}
evnames <- names(valids)
if (is.null(evnames) || !all(nzchar(evnames))) {
stop("lgb.train: each element of valids must have a name")
}
}
# Setup temporary variables # Setup temporary variables
additional_params <- list(...) additional_params <- list(...)
params <- append(params, additional_params) params <- append(params, additional_params)
...@@ -74,10 +91,6 @@ lgb.train <- function(params = list(), ...@@ -74,10 +91,6 @@ lgb.train <- function(params = list(),
fobj <- NULL fobj <- NULL
feval <- NULL feval <- NULL
if (nrounds <= 0L) {
stop("nrounds should be greater than zero")
}
# Check for objective (function or not) # Check for objective (function or not)
if (is.function(params$objective)) { if (is.function(params$objective)) {
fobj <- params$objective fobj <- params$objective
...@@ -112,30 +125,6 @@ lgb.train <- function(params = list(), ...@@ -112,30 +125,6 @@ lgb.train <- function(params = list(),
end_iteration <- begin_iteration + nrounds - 1L end_iteration <- begin_iteration + nrounds - 1L
} }
# Check for training dataset type correctness
if (!lgb.is.Dataset(data)) {
stop("lgb.train: data only accepts lgb.Dataset object")
}
# Check for validation dataset type correctness
if (length(valids) > 0L) {
# One or more validation dataset
# Check for list as input and type correctness by object
if (!is.list(valids) || !all(vapply(valids, lgb.is.Dataset, logical(1L)))) {
stop("lgb.train: valids must be a list of lgb.Dataset elements")
}
# Attempt to get names
evnames <- names(valids)
# Check for names existance
if (is.null(evnames) || !all(nzchar(evnames))) {
stop("lgb.train: each element of the valids must have a name tag")
}
}
# Update parameters with parsed parameters # Update parameters with parsed parameters
data$update_params(params) data$update_params(params)
......
#' @name lgb_shared_params #' @name lgb_shared_params
#' @title Shared parameter docs #' @title Shared parameter docs
#' @description Parameter docs shared by \code{lgb.train}, \code{lgb.cv}, and \code{lightgbm} #' @description Parameter docs shared by \code{lgb.train}, \code{lgb.cv}, and \code{lightgbm}
#' @param callbacks List of callback functions that are applied at each iteration. #' @param callbacks list of callback functions
#' @param data a \code{lgb.Dataset} object, used for training #' List of callback functions that are applied at each iteration.
#' @param data a \code{lgb.Dataset} object, used for training. Some functions, such as \code{\link{lgb.cv}},
#' may allow you to pass other types of data like \code{matrix} and then separately supply
#' \code{label} as a keyword argument.
#' @param early_stopping_rounds int. Activates early stopping. Requires at least one validation data #' @param early_stopping_rounds int. Activates early stopping. Requires at least one validation data
#' and one metric. If there's more than one, will check all of them #' and one metric. If there's more than one, will check all of them
#' except the training data. Returns the model with (best_iter + early_stopping_rounds). #' except the training data. Returns the model with (best_iter + early_stopping_rounds).
...@@ -57,11 +60,14 @@ lightgbm <- function(data, ...@@ -57,11 +60,14 @@ lightgbm <- function(data,
callbacks = list(), callbacks = list(),
...) { ...) {
# Set data to a temporary variable # validate inputs early to avoid unnecessary computation
dtrain <- data
if (nrounds <= 0L) { if (nrounds <= 0L) {
stop("nrounds should be greater than zero") stop("nrounds should be greater than zero")
} }
# Set data to a temporary variable
dtrain <- data
# Check whether data is lgb.Dataset, if not then create lgb.Dataset manually # Check whether data is lgb.Dataset, if not then create lgb.Dataset manually
if (!lgb.is.Dataset(dtrain)) { if (!lgb.is.Dataset(dtrain)) {
dtrain <- lgb.Dataset(data, label = label, weight = weight) dtrain <- lgb.Dataset(data, label = label, weight = weight)
......
...@@ -31,13 +31,15 @@ lgb.cv( ...@@ -31,13 +31,15 @@ lgb.cv(
\arguments{ \arguments{
\item{params}{List of parameters} \item{params}{List of parameters}
\item{data}{a \code{lgb.Dataset} object, used for training} \item{data}{a \code{lgb.Dataset} object, used for training. Some functions, such as \code{\link{lgb.cv}},
may allow you to pass other types of data like \code{matrix} and then separately supply
\code{label} as a keyword argument.}
\item{nrounds}{number of training rounds} \item{nrounds}{number of training rounds}
\item{nfold}{the original dataset is randomly partitioned into \code{nfold} equal size subsamples.} \item{nfold}{the original dataset is randomly partitioned into \code{nfold} equal size subsamples.}
\item{label}{vector of response values. Should be provided only when data is an R-matrix.} \item{label}{Vector of labels, used if \code{data} is not an \code{\link{lgb.Dataset}}}
\item{weight}{vector of response values. If not NULL, will set to dataset} \item{weight}{vector of response values. If not NULL, will set to dataset}
......
...@@ -26,7 +26,9 @@ lgb.train( ...@@ -26,7 +26,9 @@ lgb.train(
\arguments{ \arguments{
\item{params}{List of parameters} \item{params}{List of parameters}
\item{data}{a \code{lgb.Dataset} object, used for training} \item{data}{a \code{lgb.Dataset} object, used for training. Some functions, such as \code{\link{lgb.cv}},
may allow you to pass other types of data like \code{matrix} and then separately supply
\code{label} as a keyword argument.}
\item{nrounds}{number of training rounds} \item{nrounds}{number of training rounds}
......
...@@ -6,7 +6,9 @@ ...@@ -6,7 +6,9 @@
\arguments{ \arguments{
\item{callbacks}{List of callback functions that are applied at each iteration.} \item{callbacks}{List of callback functions that are applied at each iteration.}
\item{data}{a \code{lgb.Dataset} object, used for training} \item{data}{a \code{lgb.Dataset} object, used for training. Some functions, such as \code{\link{lgb.cv}},
may allow you to pass other types of data like \code{matrix} and then separately supply
\code{label} as a keyword argument.}
\item{early_stopping_rounds}{int. Activates early stopping. Requires at least one validation data \item{early_stopping_rounds}{int. Activates early stopping. Requires at least one validation data
and one metric. If there's more than one, will check all of them and one metric. If there's more than one, will check all of them
......
...@@ -20,7 +20,9 @@ lightgbm( ...@@ -20,7 +20,9 @@ lightgbm(
) )
} }
\arguments{ \arguments{
\item{data}{a \code{lgb.Dataset} object, used for training} \item{data}{a \code{lgb.Dataset} object, used for training. Some functions, such as \code{\link{lgb.cv}},
may allow you to pass other types of data like \code{matrix} and then separately supply
\code{label} as a keyword argument.}
\item{label}{Vector of labels, used if \code{data} is not an \code{\link{lgb.Dataset}}} \item{label}{Vector of labels, used if \code{data} is not an \code{\link{lgb.Dataset}}}
......
context("basic functions") context("lightgbm()")
data(agaricus.train, package = "lightgbm") data(agaricus.train, package = "lightgbm")
data(agaricus.test, package = "lightgbm") data(agaricus.test, package = "lightgbm")
...@@ -70,6 +70,20 @@ test_that("use of multiple eval metrics works", { ...@@ -70,6 +70,20 @@ test_that("use of multiple eval metrics works", {
expect_false(is.null(bst$record_evals)) expect_false(is.null(bst$record_evals))
}) })
test_that("lightgbm() rejects negative or 0 value passed to nrounds", {
dtrain <- lgb.Dataset(train$data, label = train$label)
params <- list(objective = "regression", metric = "l2,l1")
for (nround_value in c(-10L, 0L)) {
expect_error({
bst <- lightgbm(
data = dtrain
, params = params
, nrounds = nround_value
)
}, "nrounds should be greater than zero")
}
})
test_that("training continuation works", { test_that("training continuation works", {
testthat::skip("This test is currently broken. See issue #2468 for details.") testthat::skip("This test is currently broken. See issue #2468 for details.")
...@@ -103,6 +117,7 @@ test_that("training continuation works", { ...@@ -103,6 +117,7 @@ test_that("training continuation works", {
expect_lt(abs(err_bst - err_bst2), 0.01) expect_lt(abs(err_bst - err_bst2), 0.01)
}) })
context("lgb.cv()")
test_that("cv works", { test_that("cv works", {
dtrain <- lgb.Dataset(train$data, label = train$label) dtrain <- lgb.Dataset(train$data, label = train$label)
...@@ -118,3 +133,122 @@ test_that("cv works", { ...@@ -118,3 +133,122 @@ test_that("cv works", {
) )
expect_false(is.null(bst$record_evals)) expect_false(is.null(bst$record_evals))
}) })
test_that("lgb.cv() rejects negative or 0 value passed to nrounds", {
dtrain <- lgb.Dataset(train$data, label = train$label)
params <- list(objective = "regression", metric = "l2,l1")
for (nround_value in c(-10L, 0L)) {
expect_error({
bst <- lgb.cv(
params
, dtrain
, nround_value
, nfold = 5L
, min_data = 1L
)
}, "nrounds should be greater than zero")
}
})
test_that("lgb.cv() throws an informative error is 'data' is not an lgb.Dataset and labels are not given", {
bad_values <- list(
4L
, "hello"
, list(a = TRUE, b = seq_len(10L))
, data.frame(x = seq_len(5L), y = seq_len(5L))
, data.table::data.table(x = seq_len(5L), y = seq_len(5L))
, matrix(data = seq_len(10L), 2L, 5L)
)
for (val in bad_values) {
expect_error({
bst <- lgb.cv(
params = list(objective = "regression", metric = "l2,l1")
, data = val
, 10L
, nfold = 5L
, min_data = 1L
)
}, regexp = "'label' must be provided for lgb.cv if 'data' is not an 'lgb.Dataset'", fixed = TRUE)
}
})
context("lgb.train()")
test_that("lgb.train() rejects negative or 0 value passed to nrounds", {
dtrain <- lgb.Dataset(train$data, label = train$label)
params <- list(objective = "regression", metric = "l2,l1")
for (nround_value in c(-10L, 0L)) {
expect_error({
bst <- lgb.train(
params
, dtrain
, nround_value
)
}, "nrounds should be greater than zero")
}
})
test_that("lgb.train() throws an informative error if 'data' is not an lgb.Dataset", {
bad_values <- list(
4L
, "hello"
, list(a = TRUE, b = seq_len(10L))
, data.frame(x = seq_len(5L), y = seq_len(5L))
, data.table::data.table(x = seq_len(5L), y = seq_len(5L))
, matrix(data = seq_len(10L), 2L, 5L)
)
for (val in bad_values) {
expect_error({
bst <- lgb.train(
params = list(objective = "regression", metric = "l2,l1")
, data = val
, 10L
)
}, regexp = "data must be an lgb.Dataset instance", fixed = TRUE)
}
})
test_that("lgb.train() throws an informative error if 'valids' is not a list of lgb.Dataset objects", {
valids <- list(
"valid1" = data.frame(x = rnorm(5L), y = rnorm(5L))
, "valid2" = data.frame(x = rnorm(5L), y = rnorm(5L))
)
expect_error({
bst <- lgb.train(
params = list(objective = "regression", metric = "l2,l1")
, data = lgb.Dataset(train$data, label = train$label)
, 10L
, valids = valids
)
}, regexp = "valids must be a list of lgb.Dataset elements")
})
test_that("lgb.train() errors if 'valids' is a list of lgb.Dataset objects but some do not have names", {
valids <- list(
"valid1" = lgb.Dataset(matrix(rnorm(10L), 5L, 2L))
, lgb.Dataset(matrix(rnorm(10L), 2L, 5L))
)
expect_error({
bst <- lgb.train(
params = list(objective = "regression", metric = "l2,l1")
, data = lgb.Dataset(train$data, label = train$label)
, 10L
, valids = valids
)
}, regexp = "each element of valids must have a name")
})
test_that("lgb.train() throws an informative error if 'valids' contains lgb.Dataset objects but none have names", {
valids <- list(
lgb.Dataset(matrix(rnorm(10L), 5L, 2L))
, lgb.Dataset(matrix(rnorm(10L), 2L, 5L))
)
expect_error({
bst <- lgb.train(
params = list(objective = "regression", metric = "l2,l1")
, data = lgb.Dataset(train$data, label = train$label)
, 10L
, valids = valids
)
}, regexp = "each element of valids must have a name")
})
...@@ -99,3 +99,30 @@ test_that("lgb.Dataset$setinfo() should convert 'group' to integer", { ...@@ -99,3 +99,30 @@ test_that("lgb.Dataset$setinfo() should convert 'group' to integer", {
ds$setinfo("group", group_as_numeric) ds$setinfo("group", group_as_numeric)
expect_identical(ds$getinfo("group"), as.integer(group_as_numeric)) expect_identical(ds$getinfo("group"), as.integer(group_as_numeric))
}) })
test_that("lgb.Dataset should throw an error if 'reference' is provided but of the wrong format", {
data(agaricus.test, package = "lightgbm")
test_data <- agaricus.test$data[1L:100L, ]
test_label <- agaricus.test$label[1L:100L]
# Try to trick lgb.Dataset() into accepting bad input
expect_error({
dtest <- lgb.Dataset(
data = test_data
, label = test_label
, reference = data.frame(x = seq_len(10L), y = seq_len(10L))
)
}, regexp = "reference must be a")
})
test_that("Dataset$new() should throw an error if 'predictor' is provided but of the wrong format", {
data(agaricus.test, package = "lightgbm")
test_data <- agaricus.test$data[1L:100L, ]
test_label <- agaricus.test$label[1L:100L]
expect_error({
dtest <- Dataset$new(
data = test_data
, label = test_label
, predictor = data.frame(x = seq_len(10L), y = seq_len(10L))
)
}, regexp = "predictor must be a", fixed = TRUE)
})
context("lgb.check.r6.class")
test_that("lgb.check.r6.class() should return FALSE for NULL input", {
expect_false(lgb.check.r6.class(NULL, "lgb.Dataset"))
})
test_that("lgb.check.r6.class() should return FALSE for non-R6 inputs", {
x <- 5L
class(x) <- "lgb.Dataset"
expect_false(lgb.check.r6.class(x, "lgb.Dataset"))
})
test_that("lgb.check.r6.class() should correctly identify lgb.Dataset", {
data("agaricus.train", package = "lightgbm")
train <- agaricus.train
ds <- lgb.Dataset(train$data, label = train$label)
expect_true(lgb.check.r6.class(ds, "lgb.Dataset"))
expect_false(lgb.check.r6.class(ds, "lgb.Predictor"))
expect_false(lgb.check.r6.class(ds, "lgb.Booster"))
})
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment