Unverified Commit d9064282 authored by James Lamb's avatar James Lamb Committed by GitHub
Browse files

[R-package] moved parameter validations up earlier in function calls (#2663)

parent 08fd53cd
......@@ -29,11 +29,10 @@ cb.reset.parameters <- function(new_params) {
# Run some checks in the beginning
init <- function(env) {
# Store boosting rounds
nrounds <<- env$end_iteration - env$begin_iteration + 1L
# Check for model environment
if (is.null(env$model)) { stop("Env should have a ", sQuote("model")) }
if (is.null(env$model)) {
stop("Env should have a ", sQuote("model"))
}
# Some parameters are not allowed to be changed,
# since changing them would simply wreck some chaos
......@@ -50,6 +49,9 @@ cb.reset.parameters <- function(new_params) {
)
}
# Store boosting rounds
nrounds <<- env$end_iteration - env$begin_iteration + 1L
# Check parameter names
for (n in pnames) {
......@@ -285,14 +287,14 @@ cb.early.stop <- function(stopping_rounds, verbose = TRUE) {
# Initialization function
init <- function(env) {
# Store evaluation length
eval_len <<- length(env$eval_list)
# Early stopping cannot work without metrics
if (eval_len == 0L) {
if (length(env$eval_list) == 0L) {
stop("For early stopping, valids must have at least one element")
}
# Store evaluation length
eval_len <<- length(env$eval_list)
# Check if verbose or not
if (isTRUE(verbose)) {
cat("Will train until there is no improvement in ", stopping_rounds, " rounds.\n\n", sep = "")
......
......@@ -781,15 +781,21 @@ lgb.load <- function(filename = NULL, model_str = NULL) {
}
# Return new booster
if (!is.null(filename) && !file.exists(filename)) stop("lgb.load: file does not exist for supplied filename")
if (!is.null(filename)) return(invisible(Booster$new(modelfile = filename)))
if (!is.null(filename) && !file.exists(filename)) {
stop("lgb.load: file does not exist for supplied filename")
}
if (!is.null(filename)) {
return(invisible(Booster$new(modelfile = filename)))
}
# Load from model_str
if (!is.null(model_str) && !is.character(model_str)) {
stop("lgb.load: model_str should be character")
}
# Return new booster
if (!is.null(model_str)) return(invisible(Booster$new(model_str = model_str)))
if (!is.null(model_str)) {
return(invisible(Booster$new(model_str = model_str)))
}
}
......@@ -831,8 +837,8 @@ lgb.save <- function(booster, filename, num_iteration = NULL) {
}
# Check if file name is character
if (!is.character(filename)) {
stop("lgb.save: filename should be a character")
if (!(is.character(filename) && length(filename) == 1L)) {
stop("lgb.save: filename should be a string")
}
# Store booster
......
......@@ -32,6 +32,14 @@ Dataset <- R6::R6Class(
info = list(),
...) {
# validate inputs early to avoid unnecessary computation
if (!(is.null(reference) || lgb.check.r6.class(reference, "lgb.Dataset"))) {
stop("lgb.Dataset: If provided, reference must be a ", sQuote("lgb.Dataset"))
}
if (!(is.null(predictor) || lgb.check.r6.class(predictor, "lgb.Predictor"))) {
stop("lgb.Dataset: If provided, predictor must be a ", sQuote("lgb.Predictor"))
}
# Check for additional parameters
additional_params <- list(...)
......@@ -56,20 +64,6 @@ Dataset <- R6::R6Class(
}
# Check for dataset reference
if (!is.null(reference)) {
if (!lgb.check.r6.class(reference, "lgb.Dataset")) {
stop("lgb.Dataset: Can only use ", sQuote("lgb.Dataset"), " as reference")
}
}
# Check for predictor reference
if (!is.null(predictor)) {
if (!lgb.check.r6.class(predictor, "lgb.Predictor")) {
stop("lgb.Dataset: Only can use ", sQuote("lgb.Predictor"), " as predictor")
}
}
# Check for matrix format
if (is.matrix(data)) {
# Check whether matrix is the correct type first ("double")
......
......@@ -22,7 +22,7 @@ CVBooster <- R6::R6Class(
#' @description Cross validation logic used by LightGBM
#' @inheritParams lgb_shared_params
#' @param nfold the original dataset is randomly partitioned into \code{nfold} equal size subsamples.
#' @param label vector of response values. Should be provided only when data is an R-matrix.
#' @param label Vector of labels, used if \code{data} is not an \code{\link{lgb.Dataset}}
#' @param weight vector of weights, one per observation. If not NULL, will set to dataset
#' @param obj objective function, can be character or custom objective function. Examples include
#' \code{regression}, \code{regression_l1}, \code{huber},
......@@ -95,6 +95,19 @@ lgb.cv <- function(params = list()
, ...
) {
# validate parameters
if (nrounds <= 0L) {
stop("nrounds should be greater than zero")
}
# If 'data' is not an lgb.Dataset, try to construct one using 'label'
if (!lgb.is.Dataset(data)) {
if (is.null(label)) {
stop("'label' must be provided for lgb.cv if 'data' is not an 'lgb.Dataset'")
}
data <- lgb.Dataset(data, label = label)
}
# Setup temporary variables
params <- append(params, list(...))
params$verbose <- verbose
......@@ -103,10 +116,6 @@ lgb.cv <- function(params = list()
fobj <- NULL
feval <- NULL
if (nrounds <= 0L) {
stop("nrounds should be greater than zero")
}
# Check for objective (function or not)
if (is.function(params$objective)) {
fobj <- params$objective
......@@ -141,14 +150,6 @@ lgb.cv <- function(params = list()
end_iteration <- begin_iteration + nrounds - 1L
}
# Check for training dataset type correctness
if (!lgb.is.Dataset(data)) {
if (is.null(label)) {
stop("Labels must be provided for lgb.cv")
}
data <- lgb.Dataset(data, label = label)
}
# Check for weights
if (!is.null(weight)) {
data$setinfo("weight", weight)
......
......@@ -36,7 +36,7 @@
lgb.importance <- function(model, percentage = TRUE) {
# Check if model is a lightgbm model
if (!inherits(model, "lgb.Booster")) {
if (!lgb.is.Booster(model)) {
stop("'model' has to be an object of class lgb.Booster")
}
......
......@@ -65,6 +65,23 @@ lgb.train <- function(params = list(),
reset_data = FALSE,
...) {
# validate inputs early to avoid unnecessary computation
if (nrounds <= 0L) {
stop("nrounds should be greater than zero")
}
if (!lgb.is.Dataset(data)) {
stop("lgb.train: data must be an lgb.Dataset instance")
}
if (length(valids) > 0L) {
if (!is.list(valids) || !all(vapply(valids, lgb.is.Dataset, logical(1L)))) {
stop("lgb.train: valids must be a list of lgb.Dataset elements")
}
evnames <- names(valids)
if (is.null(evnames) || !all(nzchar(evnames))) {
stop("lgb.train: each element of valids must have a name")
}
}
# Setup temporary variables
additional_params <- list(...)
params <- append(params, additional_params)
......@@ -74,10 +91,6 @@ lgb.train <- function(params = list(),
fobj <- NULL
feval <- NULL
if (nrounds <= 0L) {
stop("nrounds should be greater than zero")
}
# Check for objective (function or not)
if (is.function(params$objective)) {
fobj <- params$objective
......@@ -112,30 +125,6 @@ lgb.train <- function(params = list(),
end_iteration <- begin_iteration + nrounds - 1L
}
# Check for training dataset type correctness
if (!lgb.is.Dataset(data)) {
stop("lgb.train: data only accepts lgb.Dataset object")
}
# Check for validation dataset type correctness
if (length(valids) > 0L) {
# One or more validation dataset
# Check for list as input and type correctness by object
if (!is.list(valids) || !all(vapply(valids, lgb.is.Dataset, logical(1L)))) {
stop("lgb.train: valids must be a list of lgb.Dataset elements")
}
# Attempt to get names
evnames <- names(valids)
# Check for names existance
if (is.null(evnames) || !all(nzchar(evnames))) {
stop("lgb.train: each element of the valids must have a name tag")
}
}
# Update parameters with parsed parameters
data$update_params(params)
......
#' @name lgb_shared_params
#' @title Shared parameter docs
#' @description Parameter docs shared by \code{lgb.train}, \code{lgb.cv}, and \code{lightgbm}
#' @param callbacks List of callback functions that are applied at each iteration.
#' @param data a \code{lgb.Dataset} object, used for training
#' @param callbacks list of callback functions
#' List of callback functions that are applied at each iteration.
#' @param data a \code{lgb.Dataset} object, used for training. Some functions, such as \code{\link{lgb.cv}},
#' may allow you to pass other types of data like \code{matrix} and then separately supply
#' \code{label} as a keyword argument.
#' @param early_stopping_rounds int. Activates early stopping. Requires at least one validation data
#' and one metric. If there's more than one, will check all of them
#' except the training data. Returns the model with (best_iter + early_stopping_rounds).
......@@ -57,11 +60,14 @@ lightgbm <- function(data,
callbacks = list(),
...) {
# Set data to a temporary variable
dtrain <- data
# validate inputs early to avoid unnecessary computation
if (nrounds <= 0L) {
stop("nrounds should be greater than zero")
}
# Set data to a temporary variable
dtrain <- data
# Check whether data is lgb.Dataset, if not then create lgb.Dataset manually
if (!lgb.is.Dataset(dtrain)) {
dtrain <- lgb.Dataset(data, label = label, weight = weight)
......
......@@ -31,13 +31,15 @@ lgb.cv(
\arguments{
\item{params}{List of parameters}
\item{data}{a \code{lgb.Dataset} object, used for training}
\item{data}{a \code{lgb.Dataset} object, used for training. Some functions, such as \code{\link{lgb.cv}},
may allow you to pass other types of data like \code{matrix} and then separately supply
\code{label} as a keyword argument.}
\item{nrounds}{number of training rounds}
\item{nfold}{the original dataset is randomly partitioned into \code{nfold} equal size subsamples.}
\item{label}{vector of response values. Should be provided only when data is an R-matrix.}
\item{label}{Vector of labels, used if \code{data} is not an \code{\link{lgb.Dataset}}}
\item{weight}{vector of weights, one per observation. If not NULL, will set to dataset}
......
......@@ -26,7 +26,9 @@ lgb.train(
\arguments{
\item{params}{List of parameters}
\item{data}{a \code{lgb.Dataset} object, used for training}
\item{data}{a \code{lgb.Dataset} object, used for training. Some functions, such as \code{\link{lgb.cv}},
may allow you to pass other types of data like \code{matrix} and then separately supply
\code{label} as a keyword argument.}
\item{nrounds}{number of training rounds}
......
......@@ -6,7 +6,9 @@
\arguments{
\item{callbacks}{List of callback functions that are applied at each iteration.}
\item{data}{a \code{lgb.Dataset} object, used for training}
\item{data}{a \code{lgb.Dataset} object, used for training. Some functions, such as \code{\link{lgb.cv}},
may allow you to pass other types of data like \code{matrix} and then separately supply
\code{label} as a keyword argument.}
\item{early_stopping_rounds}{int. Activates early stopping. Requires at least one validation data
and one metric. If there's more than one, will check all of them
......
......@@ -20,7 +20,9 @@ lightgbm(
)
}
\arguments{
\item{data}{a \code{lgb.Dataset} object, used for training}
\item{data}{a \code{lgb.Dataset} object, used for training. Some functions, such as \code{\link{lgb.cv}},
may allow you to pass other types of data like \code{matrix} and then separately supply
\code{label} as a keyword argument.}
\item{label}{Vector of labels, used if \code{data} is not an \code{\link{lgb.Dataset}}}
......
context("basic functions")
context("lightgbm()")
data(agaricus.train, package = "lightgbm")
data(agaricus.test, package = "lightgbm")
......@@ -70,6 +70,20 @@ test_that("use of multiple eval metrics works", {
expect_false(is.null(bst$record_evals))
})
# lightgbm() must fail fast with a clear message when asked to train for a
# non-positive number of boosting rounds.
test_that("lightgbm() rejects negative or 0 value passed to nrounds", {
  dtrain <- lgb.Dataset(train$data, label = train$label)
  params <- list(objective = "regression", metric = "l2,l1")
  invalid_round_counts <- c(-10L, 0L)
  for (bad_nrounds in invalid_round_counts) {
    expect_error({
      bst <- lightgbm(
        data = dtrain
        , params = params
        , nrounds = bad_nrounds
      )
    }, "nrounds should be greater than zero")
  }
})
test_that("training continuation works", {
testthat::skip("This test is currently broken. See issue #2468 for details.")
......@@ -103,6 +117,7 @@ test_that("training continuation works", {
expect_lt(abs(err_bst - err_bst2), 0.01)
})
context("lgb.cv()")
test_that("cv works", {
dtrain <- lgb.Dataset(train$data, label = train$label)
......@@ -118,3 +133,122 @@ test_that("cv works", {
)
expect_false(is.null(bst$record_evals))
})
# lgb.cv() shares the same nrounds validation as lightgbm(); both zero and
# negative values must be rejected before any training work starts.
test_that("lgb.cv() rejects negative or 0 value passed to nrounds", {
  dtrain <- lgb.Dataset(train$data, label = train$label)
  params <- list(objective = "regression", metric = "l2,l1")
  for (bad_round_count in c(-10L, 0L)) {
    expect_error({
      bst <- lgb.cv(
        params
        , dtrain
        , bad_round_count
        , nfold = 5L
        , min_data = 1L
      )
    }, "nrounds should be greater than zero")
  }
})
# lgb.cv() can only construct an lgb.Dataset from raw data when 'label' is
# supplied; without it, the user should get an informative error rather than
# an obscure downstream failure.
# NOTE: fixed typo in the test description ("error is" -> "error if").
test_that("lgb.cv() throws an informative error if 'data' is not an lgb.Dataset and labels are not given", {
  # A spread of common non-Dataset types a user might mistakenly pass
  bad_values <- list(
    4L
    , "hello"
    , list(a = TRUE, b = seq_len(10L))
    , data.frame(x = seq_len(5L), y = seq_len(5L))
    , data.table::data.table(x = seq_len(5L), y = seq_len(5L))
    , matrix(data = seq_len(10L), 2L, 5L)
  )
  for (val in bad_values) {
    expect_error({
      bst <- lgb.cv(
        params = list(objective = "regression", metric = "l2,l1")
        , data = val
        , 10L
        , nfold = 5L
        , min_data = 1L
      )
    }, regexp = "'label' must be provided for lgb.cv if 'data' is not an 'lgb.Dataset'", fixed = TRUE)
  }
})
context("lgb.train()")
# lgb.train() validates nrounds up front; confirm both zero and a negative
# value trigger the shared error message.
test_that("lgb.train() rejects negative or 0 value passed to nrounds", {
  dtrain <- lgb.Dataset(train$data, label = train$label)
  params <- list(objective = "regression", metric = "l2,l1")
  invalid_round_values <- c(-10L, 0L)
  for (round_value in invalid_round_values) {
    expect_error({
      bst <- lgb.train(
        params
        , dtrain
        , round_value
      )
    }, "nrounds should be greater than zero")
  }
})
# Unlike lgb.cv(), lgb.train() never constructs a Dataset on the user's
# behalf, so any non-Dataset 'data' must be rejected outright.
test_that("lgb.train() throws an informative error if 'data' is not an lgb.Dataset", {
  non_dataset_inputs <- list(
    4L
    , "hello"
    , list(a = TRUE, b = seq_len(10L))
    , data.frame(x = seq_len(5L), y = seq_len(5L))
    , data.table::data.table(x = seq_len(5L), y = seq_len(5L))
    , matrix(data = seq_len(10L), 2L, 5L)
  )
  for (bad_input in non_dataset_inputs) {
    expect_error({
      bst <- lgb.train(
        params = list(objective = "regression", metric = "l2,l1")
        , data = bad_input
        , 10L
      )
    }, regexp = "data must be an lgb.Dataset instance", fixed = TRUE)
  }
})
# 'valids' entries must be lgb.Dataset objects; named data.frames should
# still be rejected by the type check.
test_that("lgb.train() throws an informative error if 'valids' is not a list of lgb.Dataset objects", {
  non_dataset_valids <- list(
    "valid1" = data.frame(x = rnorm(5L), y = rnorm(5L))
    , "valid2" = data.frame(x = rnorm(5L), y = rnorm(5L))
  )
  expect_error({
    bst <- lgb.train(
      params = list(objective = "regression", metric = "l2,l1")
      , data = lgb.Dataset(train$data, label = train$label)
      , 10L
      , valids = non_dataset_valids
    )
  }, regexp = "valids must be a list of lgb.Dataset elements")
})
# Every element of 'valids' needs a name so evaluation results can be
# reported per-dataset; a mix of named and unnamed entries must error.
test_that("lgb.train() errors if 'valids' is a list of lgb.Dataset objects but some do not have names", {
  partially_named_valids <- list(
    "valid1" = lgb.Dataset(matrix(rnorm(10L), 5L, 2L))
    , lgb.Dataset(matrix(rnorm(10L), 2L, 5L))
  )
  expect_error({
    bst <- lgb.train(
      params = list(objective = "regression", metric = "l2,l1")
      , data = lgb.Dataset(train$data, label = train$label)
      , 10L
      , valids = partially_named_valids
    )
  }, regexp = "each element of valids must have a name")
})
# Same check as above but with a fully unnamed list, which exercises the
# is.null(names(valids)) branch of the validation.
test_that("lgb.train() throws an informative error if 'valids' contains lgb.Dataset objects but none have names", {
  unnamed_valids <- list(
    lgb.Dataset(matrix(rnorm(10L), 5L, 2L))
    , lgb.Dataset(matrix(rnorm(10L), 2L, 5L))
  )
  expect_error({
    bst <- lgb.train(
      params = list(objective = "regression", metric = "l2,l1")
      , data = lgb.Dataset(train$data, label = train$label)
      , 10L
      , valids = unnamed_valids
    )
  }, regexp = "each element of valids must have a name")
})
......@@ -99,3 +99,30 @@ test_that("lgb.Dataset$setinfo() should convert 'group' to integer", {
ds$setinfo("group", group_as_numeric)
expect_identical(ds$getinfo("group"), as.integer(group_as_numeric))
})
# 'reference' must be an lgb.Dataset R6 object; a data.frame carrying the
# wrong type should be caught by the early validation in Dataset$initialize.
test_that("lgb.Dataset should throw an error if 'reference' is provided but of the wrong format", {
  data(agaricus.test, package = "lightgbm")
  test_data <- agaricus.test$data[seq_len(100L), ]
  test_label <- agaricus.test$label[seq_len(100L)]
  # Try to trick lgb.Dataset() into accepting bad input
  bogus_reference <- data.frame(x = seq_len(10L), y = seq_len(10L))
  expect_error({
    dtest <- lgb.Dataset(
      data = test_data
      , label = test_label
      , reference = bogus_reference
    )
  }, regexp = "reference must be a")
})
# 'predictor' must be an lgb.Predictor R6 object; constructing via the
# internal Dataset$new() bypasses the lgb.Dataset() wrapper and hits the
# same early validation directly.
test_that("Dataset$new() should throw an error if 'predictor' is provided but of the wrong format", {
  data(agaricus.test, package = "lightgbm")
  test_data <- agaricus.test$data[seq_len(100L), ]
  test_label <- agaricus.test$label[seq_len(100L)]
  bogus_predictor <- data.frame(x = seq_len(10L), y = seq_len(10L))
  expect_error({
    dtest <- Dataset$new(
      data = test_data
      , label = test_label
      , predictor = bogus_predictor
    )
  }, regexp = "predictor must be a", fixed = TRUE)
})
context("lgb.check.r6.class")
# lgb.check.r6.class() is used as a validation guard (e.g. for the optional
# 'reference'/'predictor' arguments of Dataset); a NULL object must yield
# FALSE rather than raising an error.
test_that("lgb.check.r6.class() should return FALSE for NULL input", {
expect_false(lgb.check.r6.class(NULL, "lgb.Dataset"))
})
# Faking the class attribute on a plain integer must not fool the check:
# the object also has to be a genuine R6 instance.
test_that("lgb.check.r6.class() should return FALSE for non-R6 inputs", {
  fake_dataset <- 5L
  class(fake_dataset) <- "lgb.Dataset"
  expect_false(lgb.check.r6.class(fake_dataset, "lgb.Dataset"))
})
# A real lgb.Dataset must match only its own class name, not the other
# LightGBM R6 classes.
test_that("lgb.check.r6.class() should correctly identify lgb.Dataset", {
  data("agaricus.train", package = "lightgbm")
  agaricus <- agaricus.train
  dataset <- lgb.Dataset(agaricus$data, label = agaricus$label)
  expect_true(lgb.check.r6.class(dataset, "lgb.Dataset"))
  expect_false(lgb.check.r6.class(dataset, "lgb.Predictor"))
  expect_false(lgb.check.r6.class(dataset, "lgb.Booster"))
})
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment