Commit 12b15273 (unverified), authored by david-cortes and committed by GitHub
Browse files

[R-package] enable saving Booster with saveRDS() and loading it with readRDS()...


[R-package] enable saving Booster with saveRDS() and loading it with readRDS() (fixes #4296) (#4685)

* idiomatic serialization

* linter

* linter, namespace

* comments, linter, fix failing test

* standardize error messages for null handles

* auto-restore handle in more functions

* linter

* missing declaration

* correct wrong signature

* fix docs

* Update R-package/R/lgb.train.R
Co-authored-by: James Lamb <jaylamb20@gmail.com>

* Update R-package/R/lgb.drop_serialized.R
Co-authored-by: James Lamb <jaylamb20@gmail.com>

* Update R-package/R/lgb.restore_handle.R
Co-authored-by: James Lamb <jaylamb20@gmail.com>

* Update R-package/R/lgb.restore_handle.R
Co-authored-by: James Lamb <jaylamb20@gmail.com>

* Update R-package/R/lgb.make_serializable.R
Co-authored-by: James Lamb <jaylamb20@gmail.com>

* move 'restore_handle' from feature importance to dump method

* missing header

* move arguments order, update docs

* linter

* avoid leaving files in working directory

* add test for save_name=NULL

* missing comma

* Update R-package/R/lgb.restore_handle.R
Co-authored-by: Nikita Titov <nekit94-08@mail.ru>

* Update R-package/src/lightgbm_R.cpp
Co-authored-by: Nikita Titov <nekit94-08@mail.ru>

* change name of error function

* update comment

* restore old serialization functions but set as deprecated

* Update R-package/R/readRDS.lgb.Booster.R
Co-authored-by: Nikita Titov <nekit94-08@mail.ru>

* Update R-package/R/saveRDS.lgb.Booster.R
Co-authored-by: Nikita Titov <nekit94-08@mail.ru>

* update docs

* Update R-package/R/readRDS.lgb.Booster.R
Co-authored-by: James Lamb <jaylamb20@gmail.com>

* Update R-package/R/saveRDS.lgb.Booster.R
Co-authored-by: James Lamb <jaylamb20@gmail.com>

* Update R-package/tests/testthat/test_basic.R
Co-authored-by: James Lamb <jaylamb20@gmail.com>

* Update R-package/R/readRDS.lgb.Booster.R
Co-authored-by: James Lamb <jaylamb20@gmail.com>

* comments

* fix variable name

* restore serialization test for linear models

* Update R-package/R/lightgbm.R
Co-authored-by: James Lamb <jaylamb20@gmail.com>

* update docs

* fix issues with null terminator
Co-authored-by: James Lamb <jaylamb20@gmail.com>
Co-authored-by: Nikita Titov <nekit94-08@mail.ru>
parent f54e32f7
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
% Please edit documentation in R/saveRDS.lgb.Booster.R % Please edit documentation in R/saveRDS.lgb.Booster.R
\name{saveRDS.lgb.Booster} \name{saveRDS.lgb.Booster}
\alias{saveRDS.lgb.Booster} \alias{saveRDS.lgb.Booster}
\title{saveRDS for \code{lgb.Booster} models} \title{saveRDS for \code{lgb.Booster} models (DEPRECATED)}
\usage{ \usage{
saveRDS.lgb.Booster( saveRDS.lgb.Booster(
object, object,
...@@ -38,8 +38,10 @@ compression to be used. Ignored if file is a connection.} ...@@ -38,8 +38,10 @@ compression to be used. Ignored if file is a connection.}
NULL invisibly. NULL invisibly.
} }
\description{ \description{
Attempts to save a model using RDS. Has an additional parameter (\code{raw}) Calls \code{saveRDS} on an \code{lgb.Booster} object, making it serializable before the call if
which decides whether to save the raw model or not. it isn't already.
\bold{This function throws a warning and will be removed in future versions.}
} }
\examples{ \examples{
\donttest{ \donttest{
......
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
#include <memory> #include <memory>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include <algorithm>
#define COL_MAJOR (0) #define COL_MAJOR (0)
...@@ -60,6 +61,10 @@ SEXP wrapped_R_string(void *len) { ...@@ -60,6 +61,10 @@ SEXP wrapped_R_string(void *len) {
return Rf_allocVector(STRSXP, *(reinterpret_cast<R_xlen_t*>(len))); return Rf_allocVector(STRSXP, *(reinterpret_cast<R_xlen_t*>(len)));
} }
SEXP wrapped_R_raw(void *len) {
  // 'len' is type-erased; it is read back as a pointer to the R_xlen_t
  // holding the requested length of the raw vector.
  const R_xlen_t num_bytes = *reinterpret_cast<R_xlen_t*>(len);
  return Rf_allocVector(RAWSXP, num_bytes);
}
SEXP wrapped_Rf_mkChar(void *txt) { SEXP wrapped_Rf_mkChar(void *txt) {
return Rf_mkChar(reinterpret_cast<char*>(txt)); return Rf_mkChar(reinterpret_cast<char*>(txt));
} }
...@@ -75,6 +80,10 @@ SEXP safe_R_string(R_xlen_t len, SEXP *cont_token) { ...@@ -75,6 +80,10 @@ SEXP safe_R_string(R_xlen_t len, SEXP *cont_token) {
return R_UnwindProtect(wrapped_R_string, reinterpret_cast<void*>(&len), throw_R_memerr, cont_token, *cont_token); return R_UnwindProtect(wrapped_R_string, reinterpret_cast<void*>(&len), throw_R_memerr, cont_token, *cont_token);
} }
SEXP safe_R_raw(R_xlen_t len, SEXP *cont_token) {
  // Allocate a raw vector of length 'len' through R_UnwindProtect, with
  // throw_R_memerr as the cleanup callback and 'cont_token' as its data.
  void *len_arg = reinterpret_cast<void*>(&len);
  return R_UnwindProtect(wrapped_R_raw, len_arg, throw_R_memerr, cont_token, *cont_token);
}
SEXP safe_R_mkChar(char *txt, SEXP *cont_token) { SEXP safe_R_mkChar(char *txt, SEXP *cont_token) {
return R_UnwindProtect(wrapped_Rf_mkChar, reinterpret_cast<void*>(txt), throw_R_memerr, cont_token, *cont_token); return R_UnwindProtect(wrapped_Rf_mkChar, reinterpret_cast<void*>(txt), throw_R_memerr, cont_token, *cont_token);
} }
...@@ -90,12 +99,17 @@ void _DatasetFinalizer(SEXP handle) { ...@@ -90,12 +99,17 @@ void _DatasetFinalizer(SEXP handle) {
LGBM_DatasetFree_R(handle); LGBM_DatasetFree_R(handle);
} }
SEXP LGBM_NullBoosterHandleError_R() {
  // Single source of truth for the "null Booster handle" error text.
  const char* const null_handle_message =
      "Attempting to use a Booster which no longer exists and/or cannot be restored. "
      "This can happen if you have called Booster$finalize() "
      "or if this Booster was saved through saveRDS() using 'serializable=FALSE'.";
  Rf_error("%s", null_handle_message);
  // Kept so the function has a value on every path; Rf_error is not expected to return.
  return R_NilValue;
}
void _AssertBoosterHandleNotNull(SEXP handle) { void _AssertBoosterHandleNotNull(SEXP handle) {
if (Rf_isNull(handle) || !R_ExternalPtrAddr(handle)) { if (Rf_isNull(handle) || !R_ExternalPtrAddr(handle)) {
Rf_error( LGBM_NullBoosterHandleError_R();
"Attempting to use a Booster which no longer exists. "
"This can happen if you have called Booster$finalize() or if this Booster was saved with saveRDS(). "
"To avoid this error in the future, use saveRDS.lgb.Booster() or Booster$save_model() to save lightgbm Boosters.");
} }
} }
...@@ -462,13 +476,30 @@ SEXP LGBM_BoosterCreateFromModelfile_R(SEXP filename) { ...@@ -462,13 +476,30 @@ SEXP LGBM_BoosterCreateFromModelfile_R(SEXP filename) {
SEXP LGBM_BoosterLoadModelFromString_R(SEXP model_str) { SEXP LGBM_BoosterLoadModelFromString_R(SEXP model_str) {
R_API_BEGIN(); R_API_BEGIN();
SEXP ret = PROTECT(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue)); SEXP ret = PROTECT(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue));
SEXP temp = NULL;
int n_protected = 1;
int out_num_iterations = 0; int out_num_iterations = 0;
const char* model_str_ptr = CHAR(PROTECT(Rf_asChar(model_str))); const char* model_str_ptr = nullptr;
switch (TYPEOF(model_str)) {
case RAWSXP: {
model_str_ptr = reinterpret_cast<const char*>(RAW(model_str));
break;
}
case CHARSXP: {
model_str_ptr = reinterpret_cast<const char*>(CHAR(model_str));
break;
}
case STRSXP: {
temp = PROTECT(STRING_ELT(model_str, 0));
n_protected++;
model_str_ptr = reinterpret_cast<const char*>(CHAR(temp));
}
}
BoosterHandle handle = nullptr; BoosterHandle handle = nullptr;
CHECK_CALL(LGBM_BoosterLoadModelFromString(model_str_ptr, &out_num_iterations, &handle)); CHECK_CALL(LGBM_BoosterLoadModelFromString(model_str_ptr, &out_num_iterations, &handle));
R_SetExternalPtrAddr(ret, handle); R_SetExternalPtrAddr(ret, handle);
R_RegisterCFinalizerEx(ret, _BoosterFinalizer, TRUE); R_RegisterCFinalizerEx(ret, _BoosterFinalizer, TRUE);
UNPROTECT(2); UNPROTECT(n_protected);
return ret; return ret;
R_API_END(); R_API_END();
} }
...@@ -828,20 +859,19 @@ SEXP LGBM_BoosterSaveModelToString_R(SEXP handle, ...@@ -828,20 +859,19 @@ SEXP LGBM_BoosterSaveModelToString_R(SEXP handle,
SEXP cont_token = PROTECT(R_MakeUnwindCont()); SEXP cont_token = PROTECT(R_MakeUnwindCont());
R_API_BEGIN(); R_API_BEGIN();
_AssertBoosterHandleNotNull(handle); _AssertBoosterHandleNotNull(handle);
SEXP model_str;
int64_t out_len = 0; int64_t out_len = 0;
int64_t buf_len = 1024 * 1024; int64_t buf_len = 1024 * 1024;
int num_iter = Rf_asInteger(num_iteration); int num_iter = Rf_asInteger(num_iteration);
int importance_type = Rf_asInteger(feature_importance_type); int importance_type = Rf_asInteger(feature_importance_type);
std::vector<char> inner_char_buf(buf_len); std::vector<char> inner_char_buf(buf_len);
CHECK_CALL(LGBM_BoosterSaveModelToString(R_ExternalPtrAddr(handle), 0, num_iter, importance_type, buf_len, &out_len, inner_char_buf.data())); CHECK_CALL(LGBM_BoosterSaveModelToString(R_ExternalPtrAddr(handle), 0, num_iter, importance_type, buf_len, &out_len, inner_char_buf.data()));
// if the model string was larger than the initial buffer, allocate a bigger buffer and try again SEXP model_str = PROTECT(safe_R_raw(out_len, &cont_token));
// if the model string was larger than the initial buffer, call the function again, writing directly to the R object
if (out_len > buf_len) { if (out_len > buf_len) {
inner_char_buf.resize(out_len); CHECK_CALL(LGBM_BoosterSaveModelToString(R_ExternalPtrAddr(handle), 0, num_iter, importance_type, out_len, &out_len, reinterpret_cast<char*>(RAW(model_str))));
CHECK_CALL(LGBM_BoosterSaveModelToString(R_ExternalPtrAddr(handle), 0, num_iter, importance_type, out_len, &out_len, inner_char_buf.data())); } else {
std::copy(inner_char_buf.begin(), inner_char_buf.begin() + out_len, reinterpret_cast<char*>(RAW(model_str)));
} }
model_str = PROTECT(safe_R_string(static_cast<R_xlen_t>(1), &cont_token));
SET_STRING_ELT(model_str, 0, safe_R_mkChar(inner_char_buf.data(), &cont_token));
UNPROTECT(2); UNPROTECT(2);
return model_str; return model_str;
R_API_END(); R_API_END();
...@@ -936,6 +966,7 @@ static const R_CallMethodDef CallEntries[] = { ...@@ -936,6 +966,7 @@ static const R_CallMethodDef CallEntries[] = {
{"LGBM_BoosterSaveModel_R" , (DL_FUNC) &LGBM_BoosterSaveModel_R , 4}, {"LGBM_BoosterSaveModel_R" , (DL_FUNC) &LGBM_BoosterSaveModel_R , 4},
{"LGBM_BoosterSaveModelToString_R" , (DL_FUNC) &LGBM_BoosterSaveModelToString_R , 3}, {"LGBM_BoosterSaveModelToString_R" , (DL_FUNC) &LGBM_BoosterSaveModelToString_R , 3},
{"LGBM_BoosterDumpModel_R" , (DL_FUNC) &LGBM_BoosterDumpModel_R , 3}, {"LGBM_BoosterDumpModel_R" , (DL_FUNC) &LGBM_BoosterDumpModel_R , 3},
{"LGBM_NullBoosterHandleError_R" , (DL_FUNC) &LGBM_NullBoosterHandleError_R , 0},
{"LGBM_DumpParamAliases_R" , (DL_FUNC) &LGBM_DumpParamAliases_R , 0}, {"LGBM_DumpParamAliases_R" , (DL_FUNC) &LGBM_DumpParamAliases_R , 0},
{NULL, NULL, 0} {NULL, NULL, 0}
}; };
......
...@@ -20,6 +20,12 @@ LIGHTGBM_C_EXPORT SEXP LGBM_HandleIsNull_R( ...@@ -20,6 +20,12 @@ LIGHTGBM_C_EXPORT SEXP LGBM_HandleIsNull_R(
SEXP handle SEXP handle
); );
/*!
* \brief Throw a standardized error message when encountering a null Booster handle.
*        Takes no arguments; intended as the one shared place that raises this error
*        so that its wording stays consistent.
* \return No return, will throw an error (the SEXP return type is never observed
*         by the caller)
*/
LIGHTGBM_C_EXPORT SEXP LGBM_NullBoosterHandleError_R();
// --- start Dataset interface // --- start Dataset interface
/*! /*!
......
...@@ -293,6 +293,23 @@ test_that("lightgbm() performs evaluation on validation sets if they are provide ...@@ -293,6 +293,23 @@ test_that("lightgbm() performs evaluation on validation sets if they are provide
expect_true(abs(bst$record_evals[["valid2"]][["binary_error"]][["eval"]][[1L]] - 0.02226317) < TOLERANCE) expect_true(abs(bst$record_evals[["valid2"]][["binary_error"]][["eval"]][[1L]] - 0.02226317) < TOLERANCE)
}) })
test_that("lightgbm() does not write model to disk if save_name=NULL", {
    # snapshot the working directory so any file created during training is detected
    wd_files_pre <- list.files(getwd())
    fitted <- lightgbm(
        data = train$data
        , label = train$label
        , nrounds = 5L
        , params = list(objective = "binary")
        , verbose = 0L
        , save_name = NULL
    )
    wd_files_post <- list.files(getwd())
    # the directory listing must be unchanged after training
    expect_equal(wd_files_pre, wd_files_post)
})
context("training continuation") context("training continuation")
......
...@@ -158,7 +158,7 @@ test_that("lgb.load() gives the expected error messages given different incorrec ...@@ -158,7 +158,7 @@ test_that("lgb.load() gives the expected error messages given different incorrec
# if given, model_str should be a string # if given, model_str should be a string
expect_error({ expect_error({
lgb.load(model_str = c(4.0, 5.0, 6.0)) lgb.load(model_str = c(4.0, 5.0, 6.0))
}, regexp = "model_str should be character") }, regexp = "lgb.load: model_str should be a character/raw vector")
}) })
...@@ -841,6 +841,7 @@ test_that("Booster: method calls Booster with a null handle should raise an info ...@@ -841,6 +841,7 @@ test_that("Booster: method calls Booster with a null handle should raise an info
, valids = list( , valids = list(
train = dtrain train = dtrain
) )
, serializable = FALSE
) )
tmp_file <- tempfile(fileext = ".rds") tmp_file <- tempfile(fileext = ".rds")
saveRDS(bst, tmp_file) saveRDS(bst, tmp_file)
...@@ -875,7 +876,7 @@ test_that("Booster: method calls Booster with a null handle should raise an info ...@@ -875,7 +876,7 @@ test_that("Booster: method calls Booster with a null handle should raise an info
bst$rollback_one_iter() bst$rollback_one_iter()
}) })
.expect_booster_error({ .expect_booster_error({
bst$save() bst$save_raw()
}) })
.expect_booster_error({ .expect_booster_error({
bst$save_model(filename = tempfile(fileext = ".model")) bst$save_model(filename = tempfile(fileext = ".model"))
...@@ -991,9 +992,9 @@ test_that("params (including dataset params) should be stored in .rds file for B ...@@ -991,9 +992,9 @@ test_that("params (including dataset params) should be stored in .rds file for B
, train_set = dtrain , train_set = dtrain
) )
bst_file <- tempfile(fileext = ".rds") bst_file <- tempfile(fileext = ".rds")
saveRDS.lgb.Booster(bst, file = bst_file) expect_warning(saveRDS.lgb.Booster(bst, file = bst_file))
bst_from_file <- readRDS.lgb.Booster(file = bst_file) expect_warning(bst_from_file <- readRDS.lgb.Booster(file = bst_file))
expect_identical( expect_identical(
bst_from_file$params bst_from_file$params
, list( , list(
...@@ -1005,6 +1006,91 @@ test_that("params (including dataset params) should be stored in .rds file for B ...@@ -1005,6 +1006,91 @@ test_that("params (including dataset params) should be stored in .rds file for B
) )
}) })
context("saveRDS and readRDS work on Booster")
test_that("params (including dataset params) should be stored in .rds file for Booster", {
    data(agaricus.train, package = "lightgbm")
    # 'max_bin' is set on the Dataset, not on the Booster
    train_ds <- lgb.Dataset(
        agaricus.train$data
        , label = agaricus.train$label
        , params = list(max_bin = 17L)
    )
    booster_params <- list(
        objective = "binary"
        , max_depth = 4L
        , bagging_fraction = 0.8
    )
    booster <- Booster$new(params = booster_params, train_set = train_ds)

    # round-trip through base R serialization
    rds_path <- tempfile(fileext = ".rds")
    saveRDS(booster, file = rds_path)
    booster_restored <- readRDS(file = rds_path)

    # the restored params are expected to hold the Booster params followed by the Dataset param
    expect_identical(
        booster_restored$params
        , list(
            objective = "binary"
            , max_depth = 4L
            , bagging_fraction = 0.8
            , max_bin = 17L
        )
    )
})
test_that("Handle is automatically restored when calling predict", {
    data(agaricus.train, package = "lightgbm")
    booster <- lightgbm(agaricus.train$data, agaricus.train$label, nrounds = 5L, obj = "binary")

    # round-trip through base R serialization, with no explicit handle restoration afterwards
    rds_path <- tempfile(fileext = ".rds")
    saveRDS(booster, file = rds_path)
    booster_restored <- readRDS(file = rds_path)

    # predictions from the original and the deserialized Booster must agree
    preds_original <- predict(booster, agaricus.train$data)
    preds_restored <- predict(booster_restored, agaricus.train$data)
    expect_equal(preds_original, preds_restored)
})
test_that("boosters with linear models at leaves work with saveRDS.lgb.Booster and readRDS.lgb.Booster", {
    features <- matrix(rnorm(100L), ncol = 1L)
    target <- 2L * features + runif(nrow(features), 0L, 0.1)
    train_ds <- lgb.Dataset(data = features, label = target)

    # NOTE(review): these params do not set 'linear_tree', so this test may not
    # actually train linear models at the leaves -- confirm against the test's intent.
    train_params <- list(
        objective = "regression"
        , verbose = -1L
        , metric = "mse"
        , seed = 0L
        , num_leaves = 2L
    )
    booster <- lgb.train(data = train_ds, nrounds = 10L, params = train_params)
    expect_true(lgb.is.Booster(booster))

    # record predictions, persist the model with the deprecated helper
    # (which is expected to warn), then destroy the in-memory Booster
    preds_before <- predict(booster, features)
    rds_path <- tempfile(fileext = ".rds")
    expect_warning(saveRDS.lgb.Booster(booster, file = rds_path))
    booster$finalize()
    expect_null(booster$.__enclos_env__$private$handle)
    rm(booster)

    # reload with the deprecated reader (also expected to warn); predictions must match exactly
    expect_warning({booster_restored <- readRDS.lgb.Booster(file = rds_path)})
    preds_after <- predict(booster_restored, features)
    expect_identical(preds_before, preds_after)
})
test_that("boosters with linear models at leaves can be written to RDS and re-loaded successfully", { test_that("boosters with linear models at leaves can be written to RDS and re-loaded successfully", {
X <- matrix(rnorm(100L), ncol = 1L) X <- matrix(rnorm(100L), ncol = 1L)
labels <- 2L * X + runif(nrow(X), 0L, 0.1) labels <- 2L * X + runif(nrow(X), 0L, 0.1)
...@@ -1031,13 +1117,13 @@ test_that("boosters with linear models at leaves can be written to RDS and re-lo ...@@ -1031,13 +1117,13 @@ test_that("boosters with linear models at leaves can be written to RDS and re-lo
# save predictions, then write the model to a file and destroy it in R # save predictions, then write the model to a file and destroy it in R
preds <- predict(bst, X) preds <- predict(bst, X)
model_file <- tempfile(fileext = ".rds") model_file <- tempfile(fileext = ".rds")
saveRDS.lgb.Booster(bst, file = model_file) saveRDS(bst, file = model_file)
bst$finalize() bst$finalize()
expect_null(bst$.__enclos_env__$private$handle) expect_null(bst$.__enclos_env__$private$handle)
rm(bst) rm(bst)
# load the booster and make predictions...should be the same # load the booster and make predictions...should be the same
bst2 <- readRDS.lgb.Booster(file = model_file) bst2 <- readRDS(file = model_file)
preds2 <- predict(bst2, X) preds2 <- predict(bst2, X)
expect_identical(preds, preds2) expect_identical(preds, preds2)
}) })
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment