Unverified Commit b99a8f01 authored by James Lamb's avatar James Lamb Committed by GitHub
Browse files

[R-package] raise informative errors directly when Booster creation fails (#5014)

parent 2f27d4b2
...@@ -26,108 +26,90 @@ Booster <- R6::R6Class( ...@@ -26,108 +26,90 @@ Booster <- R6::R6Class(
modelfile = NULL, modelfile = NULL,
model_str = NULL) { model_str = NULL) {
# Create parameters and handle
handle <- NULL handle <- NULL
# Attempts to create a handle for the dataset if (!is.null(train_set)) {
try({
# Check if training dataset is not null
if (!is.null(train_set)) {
# Check if training dataset is lgb.Dataset or not
if (!lgb.is.Dataset(train_set)) {
stop("lgb.Booster: Can only use lgb.Dataset as training data")
}
train_set_handle <- train_set$.__enclos_env__$private$get_handle()
params <- utils::modifyList(params, train_set$get_params())
params_str <- lgb.params2str(params = params)
# Store booster handle
handle <- .Call(
LGBM_BoosterCreate_R
, train_set_handle
, params_str
)
# Create private booster information
private$train_set <- train_set
private$train_set_version <- train_set$.__enclos_env__$private$version
private$num_dataset <- 1L
private$init_predictor <- train_set$.__enclos_env__$private$predictor
# Check if predictor is existing
if (!is.null(private$init_predictor)) {
# Merge booster
.Call(
LGBM_BoosterMerge_R
, handle
, private$init_predictor$.__enclos_env__$private$handle
)
}
# Check current iteration
private$is_predicted_cur_iter <- c(private$is_predicted_cur_iter, FALSE)
} else if (!is.null(modelfile)) { if (!lgb.is.Dataset(train_set)) {
stop("lgb.Booster: Can only use lgb.Dataset as training data")
}
train_set_handle <- train_set$.__enclos_env__$private$get_handle()
params <- utils::modifyList(params, train_set$get_params())
params_str <- lgb.params2str(params = params)
# Store booster handle
handle <- .Call(
LGBM_BoosterCreate_R
, train_set_handle
, params_str
)
# Do we have a model file as character? # Create private booster information
if (!is.character(modelfile)) { private$train_set <- train_set
stop("lgb.Booster: Can only use a string as model file path") private$train_set_version <- train_set$.__enclos_env__$private$version
} private$num_dataset <- 1L
private$init_predictor <- train_set$.__enclos_env__$private$predictor
modelfile <- path.expand(modelfile) if (!is.null(private$init_predictor)) {
# Create booster from model # Merge booster
handle <- .Call( .Call(
LGBM_BoosterCreateFromModelfile_R LGBM_BoosterMerge_R
, modelfile , handle
, private$init_predictor$.__enclos_env__$private$handle
) )
} else if (!is.null(model_str)) { }
# Do we have a model_str as character/raw? # Check current iteration
if (!is.raw(model_str) && !is.character(model_str)) { private$is_predicted_cur_iter <- c(private$is_predicted_cur_iter, FALSE)
stop("lgb.Booster: Can only use a character/raw vector as model_str")
}
# Create booster from model } else if (!is.null(modelfile)) {
handle <- .Call(
LGBM_BoosterLoadModelFromString_R
, model_str
)
} else { # Do we have a model file as character?
if (!is.character(modelfile)) {
stop("lgb.Booster: Can only use a string as model file path")
}
# Booster non existent modelfile <- path.expand(modelfile)
stop(
"lgb.Booster: Need at least either training dataset, "
, "model file, or model_str to create booster instance"
)
} # Create booster from model
handle <- .Call(
LGBM_BoosterCreateFromModelfile_R
, modelfile
)
}) } else if (!is.null(model_str)) {
# Check whether the handle was created properly if it was not stopped earlier by a stop call # Do we have a model_str as character/raw?
if (isTRUE(lgb.is.null.handle(x = handle))) { if (!is.raw(model_str) && !is.character(model_str)) {
stop("lgb.Booster: Can only use a character/raw vector as model_str")
}
stop("lgb.Booster: cannot create Booster handle") # Create booster from model
handle <- .Call(
LGBM_BoosterLoadModelFromString_R
, model_str
)
} else { } else {
# Create class # Booster non existent
class(handle) <- "lgb.Booster.handle" stop(
private$handle <- handle "lgb.Booster: Need at least either training dataset, "
private$num_class <- 1L , "model file, or model_str to create booster instance"
.Call(
LGBM_BoosterGetNumClasses_R
, private$handle
, private$num_class
) )
} }
class(handle) <- "lgb.Booster.handle"
private$handle <- handle
private$num_class <- 1L
.Call(
LGBM_BoosterGetNumClasses_R
, private$handle
, private$num_class
)
self$params <- params self$params <- params
return(invisible(NULL)) return(invisible(NULL))
......
...@@ -947,7 +947,76 @@ test_that("Booster$new() using a Dataset with a null handle should raise an info ...@@ -947,7 +947,76 @@ test_that("Booster$new() using a Dataset with a null handle should raise an info
verbose = VERBOSITY verbose = VERBOSITY
) )
) )
}, regexp = "lgb.Booster: cannot create Booster handle") }, regexp = "Attempting to create a Dataset without any raw data")
})
# Checks that Booster$new() surfaces a specific, informative error message
# for each kind of malformed input, instead of a generic failure message.
# Each expect_error() below covers one input problem and pins the message
# that is expected to reach the user for it.
test_that("Booster$new() raises informative errors for malformed inputs", {
    data(agaricus.train, package = "lightgbm")
    train <- agaricus.train
    dtrain <- lgb.Dataset(train$data, label = train$label)

    # no inputs
    expect_error({
        Booster$new()
    }, regexp = "lgb.Booster: Need at least either training dataset, model file, or model_str")

    # unrecognized objective
    expect_error({
        Booster$new(
            params = list(objective = "not_a_real_objective")
            , train_set = dtrain
        )
    }, regexp = "Unknown objective type name: not_a_real_objective")

    # train_set is not a Dataset
    # (rnorm(10L) draws 10 values; passing a vector as `n` only works because
    # rnorm() falls back to the vector's length)
    expect_error({
        Booster$new(
            train_set = data.table::data.table(rnorm(10L))
        )
    }, regexp = "lgb.Booster: Can only use lgb.Dataset as training data")

    # model file isn't a string
    expect_error({
        Booster$new(
            modelfile = list()
        )
    }, regexp = "lgb.Booster: Can only use a string as model file path")

    # model file doesn't exist
    expect_error({
        Booster$new(
            params = list()
            , modelfile = "file-that-does-not-exist.model"
        )
    }, regexp = "Could not open file-that-does-not-exist.model")

    # model file doesn't contain a valid LightGBM model
    model_file <- tempfile(fileext = ".model")
    writeLines(
        text = c("make", "good", "predictions")
        , con = model_file
    )
    expect_error({
        Booster$new(
            params = list()
            , modelfile = model_file
        )
    }, regexp = "Unknown model format or submodel type in model file")
    # remove the temporary file so the test leaves nothing behind
    unlink(model_file)

    # malformed model string
    expect_error({
        Booster$new(
            params = list()
            , model_str = "a\nb\n"
        )
    }, regexp = "Model file doesn't specify the number of classes")

    # model string isn't character or raw
    expect_error({
        Booster$new(
            model_str = numeric()
        )
    }, regexp = "lgb.Booster: Can only use a character/raw vector as model_str")
})
# this is almost identical to the test above it, but for lgb.cv(). A lot of code # this is almost identical to the test above it, but for lgb.cv(). A lot of code
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment