Unverified commit 12b15273, authored by david-cortes and committed by GitHub
Browse files

[R-package] enable saving Booster with saveRDS() and loading it with readRDS()...


[R-package] enable saving Booster with saveRDS() and loading it with readRDS() (fixes #4296) (#4685)

* idiomatic serialization

* linter

* linter, namespace

* comments, linter, fix failing test

* standardize error messages for null handles

* auto-restore handle in more functions

* linter

* missing declaration

* correct wrong signature

* fix docs

* Update R-package/R/lgb.train.R
Co-authored-by: James Lamb <jaylamb20@gmail.com>

* Update R-package/R/lgb.drop_serialized.R
Co-authored-by: James Lamb <jaylamb20@gmail.com>

* Update R-package/R/lgb.restore_handle.R
Co-authored-by: James Lamb <jaylamb20@gmail.com>

* Update R-package/R/lgb.restore_handle.R
Co-authored-by: James Lamb <jaylamb20@gmail.com>

* Update R-package/R/lgb.make_serializable.R
Co-authored-by: James Lamb <jaylamb20@gmail.com>

* move 'restore_handle' from feature importance to dump method

* missing header

* move arguments order, update docs

* linter

* avoid leaving files in working directory

* add test for save_name=NULL

* missing comma

* Update R-package/R/lgb.restore_handle.R
Co-authored-by: Nikita Titov <nekit94-08@mail.ru>

* Update R-package/src/lightgbm_R.cpp
Co-authored-by: Nikita Titov <nekit94-08@mail.ru>

* change name of error function

* update comment

* restore old serialization functions but set as deprecated

* Update R-package/R/readRDS.lgb.Booster.R
Co-authored-by: Nikita Titov <nekit94-08@mail.ru>

* Update R-package/R/saveRDS.lgb.Booster.R
Co-authored-by: Nikita Titov <nekit94-08@mail.ru>

* update docs

* Update R-package/R/readRDS.lgb.Booster.R
Co-authored-by: James Lamb <jaylamb20@gmail.com>

* Update R-package/R/saveRDS.lgb.Booster.R
Co-authored-by: James Lamb <jaylamb20@gmail.com>

* Update R-package/tests/testthat/test_basic.R
Co-authored-by: James Lamb <jaylamb20@gmail.com>

* Update R-package/R/readRDS.lgb.Booster.R
Co-authored-by: James Lamb <jaylamb20@gmail.com>

* comments

* fix variable name

* restore serialization test for linear models

* Update R-package/R/lightgbm.R
Co-authored-by: James Lamb <jaylamb20@gmail.com>

* update docs

* fix issues with null terminator
Co-authored-by: James Lamb <jaylamb20@gmail.com>
Co-authored-by: Nikita Titov <nekit94-08@mail.ru>
parent f54e32f7
......@@ -2,7 +2,7 @@
% Please edit documentation in R/saveRDS.lgb.Booster.R
\name{saveRDS.lgb.Booster}
\alias{saveRDS.lgb.Booster}
\title{saveRDS for \code{lgb.Booster} models}
\title{saveRDS for \code{lgb.Booster} models (DEPRECATED)}
\usage{
saveRDS.lgb.Booster(
object,
......@@ -38,8 +38,10 @@ compression to be used. Ignored if file is a connection.}
NULL invisibly.
}
\description{
Attempts to save a model using RDS. Has an additional parameter (\code{raw})
which decides whether to save the raw model or not.
Calls \code{saveRDS} on an \code{lgb.Booster} object, making it serializable before the call if
it isn't already.
\bold{This function throws a warning and will be removed in future versions.}
}
\examples{
\donttest{
......
......@@ -22,6 +22,7 @@
#include <memory>
#include <utility>
#include <vector>
#include <algorithm>
#define COL_MAJOR (0)
......@@ -60,6 +61,10 @@ SEXP wrapped_R_string(void *len) {
return Rf_allocVector(STRSXP, *(reinterpret_cast<R_xlen_t*>(len)));
}
// Allocates an R raw vector. `len` is an opaque pointer to an R_xlen_t that
// holds the number of bytes to allocate; the void* signature is what lets
// safe_R_raw() hand this function to R_UnwindProtect().
SEXP wrapped_R_raw(void *len) {
  const R_xlen_t n_bytes = *(reinterpret_cast<R_xlen_t*>(len));
  return Rf_allocVector(RAWSXP, n_bytes);
}
// Builds an R CHARSXP from a C string. `txt` is an opaque pointer to the
// character data; the void* signature is what lets safe_R_mkChar() hand this
// function to R_UnwindProtect().
SEXP wrapped_Rf_mkChar(void *txt) {
  char* c_str = reinterpret_cast<char*>(txt);
  return Rf_mkChar(c_str);
}
......@@ -75,6 +80,10 @@ SEXP safe_R_string(R_xlen_t len, SEXP *cont_token) {
return R_UnwindProtect(wrapped_R_string, reinterpret_cast<void*>(&len), throw_R_memerr, cont_token, *cont_token);
}
// Allocates an R raw vector of `len` bytes by running wrapped_R_raw() under
// R_UnwindProtect(). throw_R_memerr is installed as the cleanup callback and
// receives `cont_token`; `*cont_token` is the continuation token.
// NOTE(review): throw_R_memerr presumably converts an R allocation failure
// into a C++ exception -- confirm against its definition.
SEXP safe_R_raw(R_xlen_t len, SEXP *cont_token) {
  void* len_arg = reinterpret_cast<void*>(&len);
  return R_UnwindProtect(wrapped_R_raw, len_arg, throw_R_memerr, cont_token, *cont_token);
}
// Creates an R CHARSXP from `txt` by running wrapped_Rf_mkChar() under
// R_UnwindProtect(). throw_R_memerr is installed as the cleanup callback and
// receives `cont_token`; `*cont_token` is the continuation token.
// NOTE(review): throw_R_memerr presumably converts an R allocation failure
// into a C++ exception -- confirm against its definition.
SEXP safe_R_mkChar(char *txt, SEXP *cont_token) {
  void* txt_arg = reinterpret_cast<void*>(txt);
  return R_UnwindProtect(wrapped_Rf_mkChar, txt_arg, throw_R_memerr, cont_token, *cont_token);
}
......@@ -90,12 +99,17 @@ void _DatasetFinalizer(SEXP handle) {
LGBM_DatasetFree_R(handle);
}
// Raises the standardized R error for a Booster whose handle is null.
// Takes no arguments. Rf_error() does not return, so the trailing
// `return R_NilValue` only satisfies the SEXP return type.
SEXP LGBM_NullBoosterHandleError_R() {
  static const char* const kNullBoosterMessage =
    "Attempting to use a Booster which no longer exists and/or cannot be restored. "
    "This can happen if you have called Booster$finalize() "
    "or if this Booster was saved through saveRDS() using 'serializable=FALSE'.";
  // pass the text as an argument so that it is never interpreted as a format string
  Rf_error("%s", kNullBoosterMessage);
  return R_NilValue;
}
// Verifies that `handle` refers to a live Booster.
// If `handle` is R NULL, or is an external pointer whose address is null,
// the standardized error from LGBM_NullBoosterHandleError_R() is raised and
// this function does not return. Otherwise it is a no-op.
//
// The previous inline Rf_error() call was removed: Rf_error() never returns,
// so it made the standardized error below unreachable, and its message
// recommended saveRDS.lgb.Booster(), which is now deprecated.
void _AssertBoosterHandleNotNull(SEXP handle) {
  if (Rf_isNull(handle) || !R_ExternalPtrAddr(handle)) {
    LGBM_NullBoosterHandleError_R();
  }
}
......@@ -462,13 +476,30 @@ SEXP LGBM_BoosterCreateFromModelfile_R(SEXP filename) {
// Creates a Booster from a serialized model held in an R object.
//
// model_str: the model text, given as one of
//   - a raw vector (RAWSXP), whose bytes are passed on as a C string
//     NOTE(review): the bytes presumably must include a terminating NUL --
//     confirm against the code that produces the raw vector;
//   - a single CHARSXP;
//   - a character vector (STRSXP), of which only the first element is used.
//   Any other type, or an empty character vector, raises an R error instead
//   of handing a null / out-of-bounds pointer to the LightGBM C API.
//
// Returns an external pointer wrapping the new Booster handle, with
// _BoosterFinalizer registered on it.
SEXP LGBM_BoosterLoadModelFromString_R(SEXP model_str) {
  R_API_BEGIN();
  SEXP ret = PROTECT(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue));
  // number of PROTECT calls made so far; must match the single UNPROTECT below
  int n_protected = 1;
  int out_num_iterations = 0;
  const char* model_str_ptr = nullptr;
  switch (TYPEOF(model_str)) {
    case RAWSXP: {
      model_str_ptr = reinterpret_cast<const char*>(RAW(model_str));
      break;
    }
    case CHARSXP: {
      model_str_ptr = CHAR(model_str);
      break;
    }
    case STRSXP: {
      if (Rf_xlength(model_str) < 1) {
        Rf_error("LGBM_BoosterLoadModelFromString_R: model_str should not be an empty character vector");
      }
      SEXP first_elt = PROTECT(STRING_ELT(model_str, 0));
      n_protected++;
      model_str_ptr = CHAR(first_elt);
      break;
    }
    default: {
      Rf_error("LGBM_BoosterLoadModelFromString_R: model_str should be a character/raw vector");
    }
  }
  BoosterHandle handle = nullptr;
  CHECK_CALL(LGBM_BoosterLoadModelFromString(model_str_ptr, &out_num_iterations, &handle));
  R_SetExternalPtrAddr(ret, handle);
  R_RegisterCFinalizerEx(ret, _BoosterFinalizer, TRUE);
  UNPROTECT(n_protected);
  return ret;
  R_API_END();
}
......@@ -828,20 +859,19 @@ SEXP LGBM_BoosterSaveModelToString_R(SEXP handle,
SEXP cont_token = PROTECT(R_MakeUnwindCont());
R_API_BEGIN();
_AssertBoosterHandleNotNull(handle);
SEXP model_str;
int64_t out_len = 0;
int64_t buf_len = 1024 * 1024;
int num_iter = Rf_asInteger(num_iteration);
int importance_type = Rf_asInteger(feature_importance_type);
std::vector<char> inner_char_buf(buf_len);
CHECK_CALL(LGBM_BoosterSaveModelToString(R_ExternalPtrAddr(handle), 0, num_iter, importance_type, buf_len, &out_len, inner_char_buf.data()));
// if the model string was larger than the initial buffer, allocate a bigger buffer and try again
SEXP model_str = PROTECT(safe_R_raw(out_len, &cont_token));
// if the model string was larger than the initial buffer, call the function again, writing directly to the R object
if (out_len > buf_len) {
inner_char_buf.resize(out_len);
CHECK_CALL(LGBM_BoosterSaveModelToString(R_ExternalPtrAddr(handle), 0, num_iter, importance_type, out_len, &out_len, inner_char_buf.data()));
CHECK_CALL(LGBM_BoosterSaveModelToString(R_ExternalPtrAddr(handle), 0, num_iter, importance_type, out_len, &out_len, reinterpret_cast<char*>(RAW(model_str))));
} else {
std::copy(inner_char_buf.begin(), inner_char_buf.begin() + out_len, reinterpret_cast<char*>(RAW(model_str)));
}
model_str = PROTECT(safe_R_string(static_cast<R_xlen_t>(1), &cont_token));
SET_STRING_ELT(model_str, 0, safe_R_mkChar(inner_char_buf.data(), &cont_token));
UNPROTECT(2);
return model_str;
R_API_END();
......@@ -936,6 +966,7 @@ static const R_CallMethodDef CallEntries[] = {
{"LGBM_BoosterSaveModel_R" , (DL_FUNC) &LGBM_BoosterSaveModel_R , 4},
{"LGBM_BoosterSaveModelToString_R" , (DL_FUNC) &LGBM_BoosterSaveModelToString_R , 3},
{"LGBM_BoosterDumpModel_R" , (DL_FUNC) &LGBM_BoosterDumpModel_R , 3},
{"LGBM_NullBoosterHandleError_R" , (DL_FUNC) &LGBM_NullBoosterHandleError_R , 0},
{"LGBM_DumpParamAliases_R" , (DL_FUNC) &LGBM_DumpParamAliases_R , 0},
{NULL, NULL, 0}
};
......
......@@ -20,6 +20,12 @@ LIGHTGBM_C_EXPORT SEXP LGBM_HandleIsNull_R(
SEXP handle
);
/*!
* \brief Throw the standardized R error used when a null Booster handle is encountered
*        (for example after Booster$finalize(), or after saveRDS() with 'serializable=FALSE').
*        Takes no arguments.
* \return No return, will throw an error; the SEXP return type is nominal
*/
LIGHTGBM_C_EXPORT SEXP LGBM_NullBoosterHandleError_R();
// --- start Dataset interface
/*!
......
......@@ -293,6 +293,23 @@ test_that("lightgbm() performs evaluation on validation sets if they are provide
expect_true(abs(bst$record_evals[["valid2"]][["binary_error"]][["eval"]][[1L]] - 0.02226317) < TOLERANCE)
})
test_that("lightgbm() does not write model to disk if save_name=NULL", {
    # snapshot the working directory so any file created during training is detected
    wd_contents_pre <- list.files(getwd())

    model <- lightgbm(
        data = train$data
        , label = train$label
        , nrounds = 5L
        , params = list(objective = "binary")
        , verbose = 0L
        , save_name = NULL
    )

    # the directory listing must be unchanged after training
    wd_contents_post <- list.files(getwd())
    expect_equal(wd_contents_pre, wd_contents_post)
})
context("training continuation")
......
......@@ -158,7 +158,7 @@ test_that("lgb.load() gives the expected error messages given different incorrec
# if given, model_str should be a string
expect_error({
lgb.load(model_str = c(4.0, 5.0, 6.0))
}, regexp = "model_str should be character")
}, regexp = "lgb.load: model_str should be a character/raw vector")
})
......@@ -841,6 +841,7 @@ test_that("Booster: method calls Booster with a null handle should raise an info
, valids = list(
train = dtrain
)
, serializable = FALSE
)
tmp_file <- tempfile(fileext = ".rds")
saveRDS(bst, tmp_file)
......@@ -875,7 +876,7 @@ test_that("Booster: method calls Booster with a null handle should raise an info
bst$rollback_one_iter()
})
.expect_booster_error({
bst$save()
bst$save_raw()
})
.expect_booster_error({
bst$save_model(filename = tempfile(fileext = ".model"))
......@@ -991,9 +992,9 @@ test_that("params (including dataset params) should be stored in .rds file for B
, train_set = dtrain
)
bst_file <- tempfile(fileext = ".rds")
saveRDS.lgb.Booster(bst, file = bst_file)
expect_warning(saveRDS.lgb.Booster(bst, file = bst_file))
bst_from_file <- readRDS.lgb.Booster(file = bst_file)
expect_warning(bst_from_file <- readRDS.lgb.Booster(file = bst_file))
expect_identical(
bst_from_file$params
, list(
......@@ -1005,6 +1006,91 @@ test_that("params (including dataset params) should be stored in .rds file for B
)
})
context("saveRDS and readRDS work on Booster")
test_that("params (including dataset params) should be stored in .rds file for Booster", {
    data(agaricus.train, package = "lightgbm")

    # max_bin is a Dataset-level parameter; it is expected to show up in the Booster's params
    dtrain <- lgb.Dataset(
        agaricus.train$data
        , label = agaricus.train$label
        , params = list(max_bin = 17L)
    )
    bst <- Booster$new(
        params = list(
            objective = "binary"
            , max_depth = 4L
            , bagging_fraction = 0.8
        )
        , train_set = dtrain
    )

    # round-trip through base R serialization
    rds_path <- tempfile(fileext = ".rds")
    saveRDS(bst, file = rds_path)
    restored_bst <- readRDS(file = rds_path)

    expected_params <- list(
        objective = "binary"
        , max_depth = 4L
        , bagging_fraction = 0.8
        , max_bin = 17L
    )
    expect_identical(restored_bst$params, expected_params)
})
test_that("Handle is automatically restored when calling predict", {
    data(agaricus.train, package = "lightgbm")
    bst <- lightgbm(agaricus.train$data, agaricus.train$label, nrounds = 5L, obj = "binary")

    # round-trip the Booster through base R serialization
    rds_path <- tempfile(fileext = ".rds")
    saveRDS(bst, file = rds_path)
    restored_bst <- readRDS(file = rds_path)

    # predict() on the re-loaded Booster must work and match the original
    preds_original <- predict(bst, agaricus.train$data)
    preds_restored <- predict(restored_bst, agaricus.train$data)
    expect_equal(preds_original, preds_restored)
})
test_that("boosters with linear models at leaves work with saveRDS.lgb.Booster and readRDS.lgb.Booster", {
    # single-feature regression problem: labels are roughly 2 * feature
    feature_mat <- matrix(rnorm(100L), ncol = 1L)
    y <- 2L * feature_mat + runif(nrow(feature_mat), 0L, 0.1)

    bst <- lgb.train(
        data = lgb.Dataset(
            data = feature_mat
            , label = y
        )
        , nrounds = 10L
        , params = list(
            objective = "regression"
            , verbose = -1L
            , metric = "mse"
            , seed = 0L
            , num_leaves = 2L
        )
    )
    expect_true(lgb.is.Booster(bst))

    # record predictions, serialize with the deprecated helper (which warns),
    # then destroy the in-memory Booster
    preds_before <- predict(bst, feature_mat)
    rds_path <- tempfile(fileext = ".rds")
    expect_warning(saveRDS.lgb.Booster(bst, file = rds_path))
    bst$finalize()
    expect_null(bst$.__enclos_env__$private$handle)
    rm(bst)

    # re-load with the deprecated helper (which warns); predictions must be identical
    expect_warning({
        restored_bst <- readRDS.lgb.Booster(file = rds_path)
    })
    preds_after <- predict(restored_bst, feature_mat)
    expect_identical(preds_before, preds_after)
})
test_that("boosters with linear models at leaves can be written to RDS and re-loaded successfully", {
X <- matrix(rnorm(100L), ncol = 1L)
labels <- 2L * X + runif(nrow(X), 0L, 0.1)
......@@ -1031,13 +1117,13 @@ test_that("boosters with linear models at leaves can be written to RDS and re-lo
# save predictions, then write the model to a file and destroy it in R
preds <- predict(bst, X)
model_file <- tempfile(fileext = ".rds")
saveRDS.lgb.Booster(bst, file = model_file)
saveRDS(bst, file = model_file)
bst$finalize()
expect_null(bst$.__enclos_env__$private$handle)
rm(bst)
# load the booster and make predictions...should be the same
bst2 <- readRDS.lgb.Booster(file = model_file)
bst2 <- readRDS(file = model_file)
preds2 <- predict(bst2, X)
expect_identical(preds, preds2)
})
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment