Commit beb5fc5e authored by JesseLimtiaco's avatar JesseLimtiaco Committed by Guolin Ke
Browse files

Adding ability to load model from string in R (#472)

* Update lgb.Booster.R

Added a branch that calls LGBM_BoosterLoadModelFromString_R when model_str is provided at initialization; added an option to load from model_str in lgb.load.

* Update lightgbm_R.cpp

Adding LGBM_BoosterLoadModelFromString_R

* Update lightgbm_R.cpp

Added LGBM_BoosterSaveModelToString_R

* Update lightgbm_R.cpp

* Update lgb.Booster.R

Added save_model_to_string method

* Update lgb.Booster.R

Implemented @Laurae2's review comments

* Update lgb.Booster.R

* Update lightgbm_R.h

Added load and save model from/to string exports
parent 2a64bfee
......@@ -24,6 +24,7 @@ Booster <- R6Class(
initialize = function(params = list(),
train_set = NULL,
modelfile = NULL,
model_str = NULL,
...) {
# Create parameters and handle
......@@ -73,10 +74,22 @@ Booster <- R6Class(
ret = handle,
lgb.c_str(modelfile))
} else if (!is.null(model_str)) {
# Do we have a model_str as character?
if (!is.character(model_str)) {
stop("lgb.Booster: Can only use a string as model_str")
}
# Create booster from model
handle <- lgb.call("LGBM_BoosterLoadModelFromString_R",
ret = handle,
lgb.c_str(model_str))
} else {
# Booster non existent
stop("lgb.Booster: Need at least either training dataset or model file to create booster instance")
stop("lgb.Booster: Need at least either training dataset, model file, or model_str to create booster instance")
}
......@@ -343,6 +356,21 @@ Booster <- R6Class(
return(self)
},
# Save model to string
save_model_to_string = function(num_iteration = NULL) {
  # Serialise this booster's model into a string.
  #
  # num_iteration: iteration count forwarded (as integer) to the native
  #   "LGBM_BoosterSaveModelToString_R" routine. When NULL, the booster's
  #   `best_iter` field is used instead.
  # Returns whatever lgb.call.return.str() yields for that native call.
  iteration <- if (is.null(num_iteration)) self$best_iter else num_iteration
  model_string <- lgb.call.return.str("LGBM_BoosterSaveModelToString_R",
                                      private$handle,
                                      as.integer(iteration))
  return(model_string)
},
# Dump model in memory
dump_model = function(num_iteration = NULL) {
......@@ -645,9 +673,12 @@ predict.lgb.Booster <- function(object, data,
#' Load LightGBM model
#'
#' Load LightGBM model from saved model file or string.
#' \code{lgb.load} takes either a file path or a model string.
#' If both are provided, the model is loaded from the file.
#'
#' @param filename path of model file
#' @param model_str a character string containing the model
#'
#' @return booster
#'
......@@ -671,19 +702,32 @@ predict.lgb.Booster <- function(object, data,
#' early_stopping_rounds = 10)
#' lgb.save(model, "model.txt")
#' load_booster <- lgb.load("model.txt")
#' load_booster_from_str <- lgb.load(model_str = model$save_model_to_string())
#' }
#'
#' @rdname lgb.load
#' @export
lgb.load <- function(filename = NULL, model_str = NULL) {
  # Load a booster from a model file or from a model string.
  #
  # filename:  path of a saved model file (character), or NULL.
  # model_str: character string containing a model, or NULL.
  # When both are supplied, the file takes precedence.
  # Stops with an error if neither argument is given, if the supplied argument
  # is not character, or if the model file does not exist.

  # At least one source is required
  if (is.null(filename) && is.null(model_str)) {
    stop("lgb.load: either filename or model_str must be given")
  }

  # Load from filename (wins over model_str when both are present)
  if (!is.null(filename)) {
    if (!is.character(filename)) {
      stop("lgb.load: filename should be character")
    }
    # Check existence here so a missing file gives a clear R-level error
    if (!file.exists(filename)) {
      stop("lgb.load: file does not exist for supplied filename")
    }
    return(Booster$new(modelfile = filename))
  }

  # Load from model_str (guaranteed non-NULL at this point)
  if (!is.character(model_str)) {
    stop("lgb.load: model_str should be character")
  }
  Booster$new(model_str = model_str)
}
......
......@@ -324,6 +324,18 @@ LGBM_SE LGBM_BoosterCreateFromModelfile_R(LGBM_SE filename,
R_API_END();
}
/*!
 * \brief R-facing wrapper that creates a booster from a model string.
 * \param model_str R object holding the model text (read via R_CHAR_PTR)
 * \param out R object that receives the created booster handle (via R_SET_PTR)
 * \param call_state not referenced directly in this body; presumably consumed
 *        by the R_API_BEGIN / CHECK_CALL / R_API_END macros - confirm
 */
LGBM_SE LGBM_BoosterLoadModelFromString_R(LGBM_SE model_str,
                                          LGBM_SE out,
                                          LGBM_SE call_state) {
  R_API_BEGIN();
  // The C API reports an iteration count through this out-parameter;
  // this wrapper does not pass it on to R.
  int loaded_iterations = 0;
  BoosterHandle booster;
  auto serialized_model = R_CHAR_PTR(model_str);
  CHECK_CALL(LGBM_BoosterLoadModelFromString(serialized_model, &loaded_iterations, &booster));
  R_SET_PTR(out, booster);
  R_API_END();
}
LGBM_SE LGBM_BoosterMerge_R(LGBM_SE handle,
LGBM_SE other_handle,
LGBM_SE call_state) {
......@@ -579,6 +591,25 @@ LGBM_SE LGBM_BoosterSaveModel_R(LGBM_SE handle,
R_API_END();
}
/*!
 * \brief R-facing wrapper that writes a booster's model into a string.
 * \param handle R object holding the booster handle (read via R_GET_PTR)
 * \param num_iteration iteration count forwarded to the C API
 * \param buffer_len size of the intermediate buffer handed to the C API
 * \param actual_len receives the length for the model string: set by
 *        EncodeChar, then overwritten with the C API's reported length when
 *        that length is not smaller than buffer_len
 * \param out_str R object that receives the encoded model string
 * \param call_state not referenced directly in this body; presumably consumed
 *        by the R_API_BEGIN / CHECK_CALL / R_API_END macros - confirm
 */
LGBM_SE LGBM_BoosterSaveModelToString_R(LGBM_SE handle,
                                        LGBM_SE num_iteration,
                                        LGBM_SE buffer_len,
                                        LGBM_SE actual_len,
                                        LGBM_SE out_str,
                                        LGBM_SE call_state) {
  R_API_BEGIN();
  const int buf_len = R_AS_INT(buffer_len);
  int out_len = 0;
  std::vector<char> inner_char_buf(buf_len);
  CHECK_CALL(LGBM_BoosterSaveModelToString(R_GET_PTR(handle), R_AS_INT(num_iteration), buf_len, &out_len, inner_char_buf.data()));
  // Encode exactly once. The previous code repeated this identical call inside
  // the `out_len < buffer_len` branch, copying the whole model string twice.
  // NOTE(review): assumes EncodeChar is a plain copy, so that repeating it
  // with unchanged arguments had no further effect - confirm.
  EncodeChar(out_str, inner_char_buf.data(), buffer_len, actual_len);
  if (out_len >= buf_len) {
    // Report the length the C API asked for, replacing the value EncodeChar
    // stored, exactly as the original else-branch did.
    R_INT_PTR(actual_len)[0] = out_len;
  }
  R_API_END();
}
LGBM_SE LGBM_BoosterDumpModel_R(LGBM_SE handle,
LGBM_SE num_iteration,
LGBM_SE buffer_len,
......
......@@ -224,6 +224,16 @@ LIGHTGBM_C_EXPORT LGBM_SE LGBM_BoosterCreateFromModelfile_R(LGBM_SE filename,
LGBM_SE out,
LGBM_SE call_state);
/*!
* \brief load an existing booster from a model string
* \param model_str string containing the model
* \param out handle of created Booster
* \param call_state NOTE(review): undocumented here; presumably reports success/failure back to the R caller - confirm
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT LGBM_SE LGBM_BoosterLoadModelFromString_R(LGBM_SE model_str,
LGBM_SE out,
LGBM_SE call_state);
/*!
* \brief Merge model in two boosters to first handle
* \param handle handle, will merge other handle to this
......@@ -468,6 +478,20 @@ LIGHTGBM_C_EXPORT LGBM_SE LGBM_BoosterSaveModel_R(LGBM_SE handle,
LGBM_SE filename,
LGBM_SE call_state);
/*!
* \brief create string containing model
* \param handle handle
* \param num_iteration number of iterations to save, <= 0 means save all
* \param buffer_len size of the buffer used to receive the model string
* \param actual_len receives the length of the model string; when the length reported
*        by the underlying call is not smaller than buffer_len, that reported length
*        is stored here (presumably so the caller can retry with a larger buffer - confirm)
* \param out_str string of model
* \param call_state NOTE(review): undocumented here; presumably reports success/failure back to the R caller - confirm
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT LGBM_SE LGBM_BoosterSaveModelToString_R(LGBM_SE handle,
LGBM_SE num_iteration,
LGBM_SE buffer_len,
LGBM_SE actual_len,
LGBM_SE out_str,
LGBM_SE call_state);
/*!
* \brief dump model to json
* \param handle handle
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment