Unverified commit f8010d61, authored by James Lamb and committed by GitHub
Browse files

[R-package] fix segfaults caused by missing Booster and Dataset handles (fixes #4208) (#4586)



* [R-package] fix segfaults caused by missing Booster and Dataset handles (fixes #4208)

* fix test errors

* fixes for cpplint

* Update R-package/tests/testthat/test_dataset.R
Co-authored-by: Nikita Titov <nekit94-08@mail.ru>

* fix tests

* Apply suggestions from code review
Co-authored-by: Nikita Titov <nekit94-08@mail.ru>

* move asserts inside try-catch
Co-authored-by: Nikita Titov <nekit94-08@mail.ru>
parent d4629727
......@@ -226,9 +226,18 @@ Dataset <- R6::R6Class(
ref_handle <- private$reference$.__enclos_env__$private$get_handle()
}
# Not subsetting
# not subsetting, constructing from raw data
if (is.null(private$used_indices)) {
if (is.null(private$raw_data)) {
stop(paste0(
"Attempting to create a Dataset without any raw data. "
, "This can happen if you have called Dataset$finalize() or if this Dataset was saved with saveRDS(). "
, "To avoid this error in the future, use lgb.Dataset.save() or "
, "Dataset$save_binary() to save lightgbm Datasets."
))
}
# Are we using a data file?
if (is.character(private$raw_data)) {
......
......@@ -90,6 +90,24 @@ void _DatasetFinalizer(SEXP handle) {
LGBM_DatasetFree_R(handle);
}
// Guard for entry points that receive a Booster handle from R.
// Raises an R error (via Rf_error) when `handle` is R's NULL or is an
// external pointer whose stored address is null; otherwise does nothing.
void _AssertBoosterHandleNotNull(SEXP handle) {
  // The NULL test comes first so that R_ExternalPtrAddr() is never evaluated
  // on R's NULL object.
  const bool booster_is_usable = !Rf_isNull(handle) && R_ExternalPtrAddr(handle) != nullptr;
  if (booster_is_usable) {
    return;
  }
  Rf_error(
    "Attempting to use a Booster which no longer exists. "
    "This can happen if you have called Booster$finalize() or if this Booster was saved with saveRDS(). "
    "To avoid this error in the future, use saveRDS.lgb.Booster() or Booster$save_model() to save lightgbm Boosters.");
}
// Guard for entry points that receive a Dataset handle from R.
// Raises an R error (via Rf_error) when `handle` is R's NULL or is an
// external pointer whose stored address is null; otherwise does nothing.
void _AssertDatasetHandleNotNull(SEXP handle) {
  // The NULL test comes first so that R_ExternalPtrAddr() is never evaluated
  // on R's NULL object.
  const bool dataset_is_usable = !Rf_isNull(handle) && R_ExternalPtrAddr(handle) != nullptr;
  if (dataset_is_usable) {
    return;
  }
  Rf_error(
    "Attempting to use a Dataset which no longer exists. "
    "This can happen if you have called Dataset$finalize() or if this Dataset was saved with saveRDS(). "
    "To avoid this error in the future, use lgb.Dataset.save() or Dataset$save_binary() to save lightgbm Datasets.");
}
SEXP LGBM_DatasetCreateFromFile_R(SEXP filename,
SEXP parameters,
SEXP reference) {
......@@ -172,6 +190,7 @@ SEXP LGBM_DatasetGetSubset_R(SEXP handle,
SEXP len_used_row_indices,
SEXP parameters) {
R_API_BEGIN();
_AssertDatasetHandleNotNull(handle);
SEXP ret = PROTECT(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue));
int32_t len = static_cast<int32_t>(Rf_asInteger(len_used_row_indices));
std::vector<int32_t> idxvec(len);
......@@ -195,6 +214,7 @@ SEXP LGBM_DatasetGetSubset_R(SEXP handle,
SEXP LGBM_DatasetSetFeatureNames_R(SEXP handle,
SEXP feature_names) {
R_API_BEGIN();
_AssertDatasetHandleNotNull(handle);
auto vec_names = Split(CHAR(PROTECT(Rf_asChar(feature_names))), '\t');
std::vector<const char*> vec_sptr;
int len = static_cast<int>(vec_names.size());
......@@ -211,6 +231,7 @@ SEXP LGBM_DatasetSetFeatureNames_R(SEXP handle,
SEXP LGBM_DatasetGetFeatureNames_R(SEXP handle) {
SEXP cont_token = PROTECT(R_MakeUnwindCont());
R_API_BEGIN();
_AssertDatasetHandleNotNull(handle);
SEXP feature_names;
int len = 0;
CHECK_CALL(LGBM_DatasetGetNumFeature(R_ExternalPtrAddr(handle), &len));
......@@ -258,6 +279,7 @@ SEXP LGBM_DatasetGetFeatureNames_R(SEXP handle) {
SEXP LGBM_DatasetSaveBinary_R(SEXP handle,
SEXP filename) {
R_API_BEGIN();
_AssertDatasetHandleNotNull(handle);
const char* filename_ptr = CHAR(PROTECT(Rf_asChar(filename)));
CHECK_CALL(LGBM_DatasetSaveBinary(R_ExternalPtrAddr(handle),
filename_ptr));
......@@ -281,6 +303,7 @@ SEXP LGBM_DatasetSetField_R(SEXP handle,
SEXP field_data,
SEXP num_element) {
R_API_BEGIN();
_AssertDatasetHandleNotNull(handle);
int len = Rf_asInteger(num_element);
const char* name = CHAR(PROTECT(Rf_asChar(field_name)));
if (!strcmp("group", name) || !strcmp("query", name)) {
......@@ -309,6 +332,7 @@ SEXP LGBM_DatasetGetField_R(SEXP handle,
SEXP field_name,
SEXP field_data) {
R_API_BEGIN();
_AssertDatasetHandleNotNull(handle);
const char* name = CHAR(PROTECT(Rf_asChar(field_name)));
int out_len = 0;
int out_type = 0;
......@@ -343,6 +367,7 @@ SEXP LGBM_DatasetGetFieldSize_R(SEXP handle,
SEXP field_name,
SEXP out) {
R_API_BEGIN();
_AssertDatasetHandleNotNull(handle);
const char* name = CHAR(PROTECT(Rf_asChar(field_name)));
int out_len = 0;
int out_type = 0;
......@@ -370,6 +395,7 @@ SEXP LGBM_DatasetUpdateParamChecking_R(SEXP old_params,
SEXP LGBM_DatasetGetNumData_R(SEXP handle, SEXP out) {
R_API_BEGIN();
_AssertDatasetHandleNotNull(handle);
int nrow;
CHECK_CALL(LGBM_DatasetGetNumData(R_ExternalPtrAddr(handle), &nrow));
INTEGER(out)[0] = nrow;
......@@ -380,6 +406,7 @@ SEXP LGBM_DatasetGetNumData_R(SEXP handle, SEXP out) {
SEXP LGBM_DatasetGetNumFeature_R(SEXP handle,
SEXP out) {
R_API_BEGIN();
_AssertDatasetHandleNotNull(handle);
int nfeature;
CHECK_CALL(LGBM_DatasetGetNumFeature(R_ExternalPtrAddr(handle), &nfeature));
INTEGER(out)[0] = nfeature;
......@@ -406,6 +433,7 @@ SEXP LGBM_BoosterFree_R(SEXP handle) {
SEXP LGBM_BoosterCreate_R(SEXP train_data,
SEXP parameters) {
R_API_BEGIN();
_AssertDatasetHandleNotNull(train_data);
SEXP ret = PROTECT(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue));
const char* parameters_ptr = CHAR(PROTECT(Rf_asChar(parameters)));
BoosterHandle handle = nullptr;
......@@ -448,6 +476,8 @@ SEXP LGBM_BoosterLoadModelFromString_R(SEXP model_str) {
SEXP LGBM_BoosterMerge_R(SEXP handle,
SEXP other_handle) {
R_API_BEGIN();
_AssertBoosterHandleNotNull(handle);
_AssertBoosterHandleNotNull(other_handle);
CHECK_CALL(LGBM_BoosterMerge(R_ExternalPtrAddr(handle), R_ExternalPtrAddr(other_handle)));
return R_NilValue;
R_API_END();
......@@ -456,6 +486,8 @@ SEXP LGBM_BoosterMerge_R(SEXP handle,
SEXP LGBM_BoosterAddValidData_R(SEXP handle,
SEXP valid_data) {
R_API_BEGIN();
_AssertBoosterHandleNotNull(handle);
_AssertDatasetHandleNotNull(valid_data);
CHECK_CALL(LGBM_BoosterAddValidData(R_ExternalPtrAddr(handle), R_ExternalPtrAddr(valid_data)));
return R_NilValue;
R_API_END();
......@@ -464,6 +496,8 @@ SEXP LGBM_BoosterAddValidData_R(SEXP handle,
SEXP LGBM_BoosterResetTrainingData_R(SEXP handle,
SEXP train_data) {
R_API_BEGIN();
_AssertBoosterHandleNotNull(handle);
_AssertDatasetHandleNotNull(train_data);
CHECK_CALL(LGBM_BoosterResetTrainingData(R_ExternalPtrAddr(handle), R_ExternalPtrAddr(train_data)));
return R_NilValue;
R_API_END();
......@@ -472,6 +506,7 @@ SEXP LGBM_BoosterResetTrainingData_R(SEXP handle,
SEXP LGBM_BoosterResetParameter_R(SEXP handle,
SEXP parameters) {
R_API_BEGIN();
_AssertBoosterHandleNotNull(handle);
const char* parameters_ptr = CHAR(PROTECT(Rf_asChar(parameters)));
CHECK_CALL(LGBM_BoosterResetParameter(R_ExternalPtrAddr(handle), parameters_ptr));
UNPROTECT(1);
......@@ -482,6 +517,7 @@ SEXP LGBM_BoosterResetParameter_R(SEXP handle,
SEXP LGBM_BoosterGetNumClasses_R(SEXP handle,
SEXP out) {
R_API_BEGIN();
_AssertBoosterHandleNotNull(handle);
int num_class;
CHECK_CALL(LGBM_BoosterGetNumClasses(R_ExternalPtrAddr(handle), &num_class));
INTEGER(out)[0] = num_class;
......@@ -491,6 +527,7 @@ SEXP LGBM_BoosterGetNumClasses_R(SEXP handle,
SEXP LGBM_BoosterUpdateOneIter_R(SEXP handle) {
R_API_BEGIN();
_AssertBoosterHandleNotNull(handle);
int is_finished = 0;
CHECK_CALL(LGBM_BoosterUpdateOneIter(R_ExternalPtrAddr(handle), &is_finished));
return R_NilValue;
......@@ -502,6 +539,7 @@ SEXP LGBM_BoosterUpdateOneIterCustom_R(SEXP handle,
SEXP hess,
SEXP len) {
R_API_BEGIN();
_AssertBoosterHandleNotNull(handle);
int is_finished = 0;
int int_len = Rf_asInteger(len);
std::vector<float> tgrad(int_len), thess(int_len);
......@@ -517,6 +555,7 @@ SEXP LGBM_BoosterUpdateOneIterCustom_R(SEXP handle,
SEXP LGBM_BoosterRollbackOneIter_R(SEXP handle) {
R_API_BEGIN();
_AssertBoosterHandleNotNull(handle);
CHECK_CALL(LGBM_BoosterRollbackOneIter(R_ExternalPtrAddr(handle)));
return R_NilValue;
R_API_END();
......@@ -525,6 +564,7 @@ SEXP LGBM_BoosterRollbackOneIter_R(SEXP handle) {
SEXP LGBM_BoosterGetCurrentIteration_R(SEXP handle,
SEXP out) {
R_API_BEGIN();
_AssertBoosterHandleNotNull(handle);
int out_iteration;
CHECK_CALL(LGBM_BoosterGetCurrentIteration(R_ExternalPtrAddr(handle), &out_iteration));
INTEGER(out)[0] = out_iteration;
......@@ -535,6 +575,7 @@ SEXP LGBM_BoosterGetCurrentIteration_R(SEXP handle,
SEXP LGBM_BoosterGetUpperBoundValue_R(SEXP handle,
SEXP out_result) {
R_API_BEGIN();
_AssertBoosterHandleNotNull(handle);
double* ptr_ret = REAL(out_result);
CHECK_CALL(LGBM_BoosterGetUpperBoundValue(R_ExternalPtrAddr(handle), ptr_ret));
return R_NilValue;
......@@ -544,6 +585,7 @@ SEXP LGBM_BoosterGetUpperBoundValue_R(SEXP handle,
SEXP LGBM_BoosterGetLowerBoundValue_R(SEXP handle,
SEXP out_result) {
R_API_BEGIN();
_AssertBoosterHandleNotNull(handle);
double* ptr_ret = REAL(out_result);
CHECK_CALL(LGBM_BoosterGetLowerBoundValue(R_ExternalPtrAddr(handle), ptr_ret));
return R_NilValue;
......@@ -553,6 +595,7 @@ SEXP LGBM_BoosterGetLowerBoundValue_R(SEXP handle,
SEXP LGBM_BoosterGetEvalNames_R(SEXP handle) {
SEXP cont_token = PROTECT(R_MakeUnwindCont());
R_API_BEGIN();
_AssertBoosterHandleNotNull(handle);
SEXP eval_names;
int len;
CHECK_CALL(LGBM_BoosterGetEvalCounts(R_ExternalPtrAddr(handle), &len));
......@@ -602,6 +645,7 @@ SEXP LGBM_BoosterGetEval_R(SEXP handle,
SEXP data_idx,
SEXP out_result) {
R_API_BEGIN();
_AssertBoosterHandleNotNull(handle);
int len;
CHECK_CALL(LGBM_BoosterGetEvalCounts(R_ExternalPtrAddr(handle), &len));
double* ptr_ret = REAL(out_result);
......@@ -616,6 +660,7 @@ SEXP LGBM_BoosterGetNumPredict_R(SEXP handle,
SEXP data_idx,
SEXP out) {
R_API_BEGIN();
_AssertBoosterHandleNotNull(handle);
int64_t len;
CHECK_CALL(LGBM_BoosterGetNumPredict(R_ExternalPtrAddr(handle), Rf_asInteger(data_idx), &len));
INTEGER(out)[0] = static_cast<int>(len);
......@@ -627,6 +672,7 @@ SEXP LGBM_BoosterGetPredict_R(SEXP handle,
SEXP data_idx,
SEXP out_result) {
R_API_BEGIN();
_AssertBoosterHandleNotNull(handle);
double* ptr_ret = REAL(out_result);
int64_t out_len;
CHECK_CALL(LGBM_BoosterGetPredict(R_ExternalPtrAddr(handle), Rf_asInteger(data_idx), &out_len, ptr_ret));
......@@ -659,6 +705,7 @@ SEXP LGBM_BoosterPredictForFile_R(SEXP handle,
SEXP parameter,
SEXP result_filename) {
R_API_BEGIN();
_AssertBoosterHandleNotNull(handle);
const char* data_filename_ptr = CHAR(PROTECT(Rf_asChar(data_filename)));
const char* parameter_ptr = CHAR(PROTECT(Rf_asChar(parameter)));
const char* result_filename_ptr = CHAR(PROTECT(Rf_asChar(result_filename)));
......@@ -680,6 +727,7 @@ SEXP LGBM_BoosterCalcNumPredict_R(SEXP handle,
SEXP num_iteration,
SEXP out_len) {
R_API_BEGIN();
_AssertBoosterHandleNotNull(handle);
int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
int64_t len = 0;
CHECK_CALL(LGBM_BoosterCalcNumPredict(R_ExternalPtrAddr(handle), Rf_asInteger(num_row),
......@@ -704,6 +752,7 @@ SEXP LGBM_BoosterPredictForCSC_R(SEXP handle,
SEXP parameter,
SEXP out_result) {
R_API_BEGIN();
_AssertBoosterHandleNotNull(handle);
int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
const int* p_indptr = INTEGER(indptr);
const int32_t* p_indices = reinterpret_cast<const int32_t*>(INTEGER(indices));
......@@ -735,6 +784,7 @@ SEXP LGBM_BoosterPredictForMat_R(SEXP handle,
SEXP parameter,
SEXP out_result) {
R_API_BEGIN();
_AssertBoosterHandleNotNull(handle);
int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
int32_t nrow = static_cast<int32_t>(Rf_asInteger(num_row));
int32_t ncol = static_cast<int32_t>(Rf_asInteger(num_col));
......@@ -755,6 +805,7 @@ SEXP LGBM_BoosterSaveModel_R(SEXP handle,
SEXP feature_importance_type,
SEXP filename) {
R_API_BEGIN();
_AssertBoosterHandleNotNull(handle);
const char* filename_ptr = CHAR(PROTECT(Rf_asChar(filename)));
CHECK_CALL(LGBM_BoosterSaveModel(R_ExternalPtrAddr(handle), 0, Rf_asInteger(num_iteration), Rf_asInteger(feature_importance_type), filename_ptr));
UNPROTECT(1);
......@@ -767,6 +818,7 @@ SEXP LGBM_BoosterSaveModelToString_R(SEXP handle,
SEXP feature_importance_type) {
SEXP cont_token = PROTECT(R_MakeUnwindCont());
R_API_BEGIN();
_AssertBoosterHandleNotNull(handle);
SEXP model_str;
int64_t out_len = 0;
int64_t buf_len = 1024 * 1024;
......@@ -791,6 +843,7 @@ SEXP LGBM_BoosterDumpModel_R(SEXP handle,
SEXP feature_importance_type) {
SEXP cont_token = PROTECT(R_MakeUnwindCont());
R_API_BEGIN();
_AssertBoosterHandleNotNull(handle);
SEXP model_str;
int64_t out_len = 0;
int64_t buf_len = 1024 * 1024;
......
......@@ -526,3 +526,51 @@ test_that("lgb.Dataset: should be able to create a Dataset from a text file with
expect_identical(dtrain$get_params(), list(header = FALSE))
expect_identical(dtrain$dim(), c(100L, 2L))
})
# Regression test: Dataset methods called on a Dataset that was round-tripped
# through saveRDS() / readRDS() must fail with an R-level error rather than
# crash the R session.
test_that("Dataset: method calls on a Dataset with a null handle should raise an informative error and not segfault", {
data(agaricus.train, package = "lightgbm")
train <- agaricus.train
# Build and construct a training Dataset plus a validation Dataset derived from it.
dtrain <- lgb.Dataset(train$data, label = train$label)
dtrain$construct()
dvalid <- dtrain$create_valid(
data = train$data[seq_len(100L), ]
, label = train$label[seq_len(100L)]
)
dvalid$construct()
# Serialize the constructed Dataset with saveRDS(), drop the live object, and
# read it back. Per the test title, the restored object has a null handle.
tmp_file <- tempfile(fileext = ".rds")
saveRDS(dtrain, tmp_file)
rm(dtrain)
dtrain <- readRDS(tmp_file)
# Each call below on the restored Dataset must raise the specific message shown.
expect_error({
dtrain$construct()
}, regexp = "Attempting to create a Dataset without any raw data")
expect_error({
dtrain$dim()
}, regexp = "cannot get dimensions before dataset has been constructed")
expect_error({
dtrain$get_colnames()
}, regexp = "cannot get column names before dataset has been constructed")
# save_binary() is expected to surface the same message as construct().
expect_error({
dtrain$save_binary(fname = tempfile(fileext = ".bin"))
}, regexp = "Attempting to create a Dataset without any raw data")
expect_error({
dtrain$set_categorical_feature(categorical_feature = 1L)
}, regexp = "cannot set categorical feature after freeing raw data")
expect_error({
dtrain$set_reference(reference = dvalid)
}, regexp = "cannot set reference after freeing raw data")
# Second scenario: round-trip the validation Dataset through saveRDS() /
# readRDS() as well, then try to use it as the reference of a freshly
# constructed Dataset that kept its raw data (free_raw_data = FALSE).
tmp_valid_file <- tempfile(fileext = ".rds")
saveRDS(dvalid, tmp_valid_file)
rm(dvalid)
dvalid <- readRDS(tmp_valid_file)
dtrain <- lgb.Dataset(
train$data
, label = train$label
, free_raw_data = FALSE
)
dtrain$construct()
# With a live dtrain and a restored dvalid, set_reference() must raise the
# "cannot get column names" message.
expect_error({
dtrain$set_reference(reference = dvalid)
}, regexp = "cannot get column names before dataset has been constructed")
})
......@@ -819,6 +819,93 @@ test_that("early_stopping, num_iterations are stored correctly in model string e
})
# Regression test: Booster methods called on a Booster that was round-tripped
# through plain saveRDS() / readRDS() must fail with the
# "Attempting to use a Booster which no longer exists" error rather than
# crash the R session.
test_that("Booster: method calls Booster with a null handle should raise an informative error and not segfault", {
data(agaricus.train, package = "lightgbm")
train <- agaricus.train
dtrain <- lgb.Dataset(train$data, label = train$label)
# Train a small regression model (5 rounds, 8 leaves) to get a live Booster.
bst <- lgb.train(
params = list(
objective = "regression"
, metric = "l2"
, num_leaves = 8L
)
, data = dtrain
, verbose = -1L
, nrounds = 5L
, valids = list(
train = dtrain
)
)
# Serialize with saveRDS(), drop the live object, and read it back. Per the
# test title, the restored Booster has a null handle.
tmp_file <- tempfile(fileext = ".rds")
saveRDS(bst, tmp_file)
rm(bst)
bst <- readRDS(tmp_file)
# Helper: asserts that evaluating `object` raises the null-Booster error.
# `object` is forwarded unevaluated to expect_error(), so the braced
# expression supplied by each caller only runs inside expect_error().
.expect_booster_error <- function(object) {
error_regexp <- "Attempting to use a Booster which no longer exists"
expect_error(object, regexp = error_regexp)
}
# Every Booster method below must raise the null-Booster error.
.expect_booster_error({
bst$current_iter()
})
.expect_booster_error({
bst$dump_model()
})
.expect_booster_error({
bst$eval(data = dtrain, name = "valid")
})
.expect_booster_error({
bst$eval_train()
})
.expect_booster_error({
bst$lower_bound()
})
.expect_booster_error({
bst$predict(data = train$data[seq_len(5L), ])
})
.expect_booster_error({
bst$reset_parameter(params = list(learning_rate = 0.123))
})
.expect_booster_error({
bst$rollback_one_iter()
})
.expect_booster_error({
bst$save()
})
.expect_booster_error({
bst$save_model(filename = tempfile(fileext = ".model"))
})
.expect_booster_error({
bst$save_model_to_string()
})
.expect_booster_error({
bst$update()
})
.expect_booster_error({
bst$upper_bound()
})
# to_predictor() is called outside the helper, so it is expected to succeed;
# methods on the resulting predictor must then raise the same error.
predictor <- bst$to_predictor()
.expect_booster_error({
predictor$current_iter()
})
.expect_booster_error({
predictor$predict(data = train$data[seq_len(5L), ])
})
})
# Regression test: creating a Booster from a Dataset that was round-tripped
# through saveRDS() / readRDS() must fail with an R-level error rather than
# crash the R session.
test_that("Booster$new() using a Dataset with a null handle should raise an informative error and not segfault", {
    data(agaricus.train, package = "lightgbm")
    train <- agaricus.train
    dtrain <- lgb.Dataset(train$data, label = train$label)
    dtrain$construct()
    # The file is written by saveRDS(), so it gets an ".rds" extension like the
    # other saveRDS() temp files in these tests (".bin" would suggest the
    # binary Dataset format written by Dataset$save_binary()).
    tmp_file <- tempfile(fileext = ".rds")
    saveRDS(dtrain, tmp_file)
    rm(dtrain)
    # Per the test title, the restored Dataset has a null handle.
    dtrain <- readRDS(tmp_file)
    expect_error({
        bst <- Booster$new(train_set = dtrain)
    }, regexp = "lgb.Booster: cannot create Booster handle")
})
# this is almost identical to the test above it, but for lgb.cv(). A lot of code
# is duplicated between lgb.train() and lgb.cv(), and this will catch cases where
# one is updated and the other isn't
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment