Unverified Commit e0ac6356 authored by Michael Mayer's avatar Michael Mayer Committed by GitHub
Browse files

[R-package] expose start_iteration to dump/save/lgb.model.dt.tree (#6398)

parent a70e8327
......@@ -416,7 +416,12 @@ Booster <- R6::R6Class(
},
# Save model
save_model = function(filename, num_iteration = NULL, feature_importance_type = 0L) {
save_model = function(
filename
, num_iteration = NULL
, feature_importance_type = 0L
, start_iteration = 1L
) {
self$restore_handle()
......@@ -432,12 +437,18 @@ Booster <- R6::R6Class(
, as.integer(num_iteration)
, as.integer(feature_importance_type)
, filename
, as.integer(start_iteration) - 1L # Turn to 0-based
)
return(invisible(self))
},
save_model_to_string = function(num_iteration = NULL, feature_importance_type = 0L, as_char = TRUE) {
save_model_to_string = function(
num_iteration = NULL
, feature_importance_type = 0L
, as_char = TRUE
, start_iteration = 1L
) {
self$restore_handle()
......@@ -450,6 +461,7 @@ Booster <- R6::R6Class(
, private$handle
, as.integer(num_iteration)
, as.integer(feature_importance_type)
, as.integer(start_iteration) - 1L # Turn to 0-based
)
if (as_char) {
......@@ -461,7 +473,9 @@ Booster <- R6::R6Class(
},
# Dump model in memory
dump_model = function(num_iteration = NULL, feature_importance_type = 0L) {
dump_model = function(
num_iteration = NULL, feature_importance_type = 0L, start_iteration = 1L
) {
self$restore_handle()
......@@ -474,6 +488,7 @@ Booster <- R6::R6Class(
, private$handle
, as.integer(num_iteration)
, as.integer(feature_importance_type)
, as.integer(start_iteration) - 1L # Turn to 0-based
)
return(model_str)
......@@ -1288,8 +1303,11 @@ lgb.load <- function(filename = NULL, model_str = NULL) {
#' @title Save LightGBM model
#' @description Save LightGBM model
#' @param booster Object of class \code{lgb.Booster}
#' @param filename saved filename
#' @param num_iteration number of iteration want to predict with, NULL or <= 0 means use best iteration
#' @param filename Saved filename
#' @param num_iteration Number of iterations to save, NULL or <= 0 means use best iteration
#' @param start_iteration Index (1-based) of the first boosting round to save.
#' For example, passing \code{start_iteration=5, num_iteration=3} for a regression model
#' means "save the fifth, sixth, and seventh tree"
#'
#' @return lgb.Booster
#'
......@@ -1322,7 +1340,9 @@ lgb.load <- function(filename = NULL, model_str = NULL) {
#' lgb.save(model, tempfile(fileext = ".txt"))
#' }
#' @export
lgb.save <- function(booster, filename, num_iteration = NULL) {
lgb.save <- function(
booster, filename, num_iteration = NULL, start_iteration = 1L
) {
if (!.is_Booster(x = booster)) {
stop("lgb.save: booster should be an ", sQuote("lgb.Booster"))
......@@ -1338,6 +1358,7 @@ lgb.save <- function(booster, filename, num_iteration = NULL) {
invisible(booster$save_model(
filename = filename
, num_iteration = num_iteration
, start_iteration = start_iteration
))
)
......@@ -1347,7 +1368,10 @@ lgb.save <- function(booster, filename, num_iteration = NULL) {
#' @title Dump LightGBM model to json
#' @description Dump LightGBM model to json
#' @param booster Object of class \code{lgb.Booster}
#' @param num_iteration number of iteration want to predict with, NULL or <= 0 means use best iteration
#' @param num_iteration Number of iterations to be dumped. NULL or <= 0 means use best iteration
#' @param start_iteration Index (1-based) of the first boosting round to dump.
#' For example, passing \code{start_iteration=5, num_iteration=3} for a regression model
#' means "dump the fifth, sixth, and seventh tree"
#'
#' @return json format of model
#'
......@@ -1380,14 +1404,18 @@ lgb.save <- function(booster, filename, num_iteration = NULL) {
#' json_model <- lgb.dump(model)
#' }
#' @export
lgb.dump <- function(booster, num_iteration = NULL) {
lgb.dump <- function(booster, num_iteration = NULL, start_iteration = 1L) {
if (!.is_Booster(x = booster)) {
stop("lgb.dump: booster should be an ", sQuote("lgb.Booster"))
}
# Return booster at requested iteration
return(booster$dump_model(num_iteration = num_iteration))
return(
booster$dump_model(
num_iteration = num_iteration, start_iteration = start_iteration
)
)
}
......
#' @name lgb.model.dt.tree
#' @title Parse a LightGBM model json dump
#' @description Parse a LightGBM model json dump into a \code{data.table} structure.
#' @param model object of class \code{lgb.Booster}
#' @param num_iteration number of iterations you want to predict with. NULL or
#' <= 0 means use best iteration
#' @param model object of class \code{lgb.Booster}.
#' @param num_iteration Number of iterations to include. NULL or <= 0 means use best iteration.
#' @param start_iteration Index (1-based) of the first boosting round to include in the output.
#' For example, passing \code{start_iteration=5, num_iteration=3} for a regression model
#' means "return information about the fifth, sixth, and seventh trees".
#' @return
#' A \code{data.table} with detailed information about model trees' nodes and leafs.
#'
......@@ -51,9 +53,15 @@
#' @importFrom data.table := rbindlist
#' @importFrom jsonlite fromJSON
#' @export
lgb.model.dt.tree <- function(model, num_iteration = NULL) {
json_model <- lgb.dump(booster = model, num_iteration = num_iteration)
lgb.model.dt.tree <- function(
model, num_iteration = NULL, start_iteration = 1L
) {
json_model <- lgb.dump(
booster = model
, num_iteration = num_iteration
, start_iteration = start_iteration
)
parsed_json_model <- jsonlite::fromJSON(
txt = json_model
......@@ -84,7 +92,6 @@ lgb.model.dt.tree <- function(model, num_iteration = NULL) {
tree_dt[, split_feature := feature_names]
return(tree_dt)
}
......
......@@ -4,12 +4,16 @@
\alias{lgb.dump}
\title{Dump LightGBM model to json}
\usage{
lgb.dump(booster, num_iteration = NULL)
lgb.dump(booster, num_iteration = NULL, start_iteration = 1L)
}
\arguments{
\item{booster}{Object of class \code{lgb.Booster}}
\item{num_iteration}{number of iteration want to predict with, NULL or <= 0 means use best iteration}
\item{num_iteration}{Number of iterations to be dumped. NULL or <= 0 means use best iteration}
\item{start_iteration}{Index (1-based) of the first boosting round to dump.
For example, passing \code{start_iteration=5, num_iteration=3} for a regression model
means "dump the fifth, sixth, and seventh tree"}
}
\value{
json format of model
......
......@@ -4,13 +4,16 @@
\alias{lgb.model.dt.tree}
\title{Parse a LightGBM model json dump}
\usage{
lgb.model.dt.tree(model, num_iteration = NULL)
lgb.model.dt.tree(model, num_iteration = NULL, start_iteration = 1L)
}
\arguments{
\item{model}{object of class \code{lgb.Booster}}
\item{model}{object of class \code{lgb.Booster}.}
\item{num_iteration}{number of iterations you want to predict with. NULL or
<= 0 means use best iteration}
\item{num_iteration}{Number of iterations to include. NULL or <= 0 means use best iteration.}
\item{start_iteration}{Index (1-based) of the first boosting round to include in the output.
For example, passing \code{start_iteration=5, num_iteration=3} for a regression model
means "return information about the fifth, sixth, and seventh trees".}
}
\value{
A \code{data.table} with detailed information about model trees' nodes and leafs.
......
......@@ -4,14 +4,18 @@
\alias{lgb.save}
\title{Save LightGBM model}
\usage{
lgb.save(booster, filename, num_iteration = NULL)
lgb.save(booster, filename, num_iteration = NULL, start_iteration = 1L)
}
\arguments{
\item{booster}{Object of class \code{lgb.Booster}}
\item{filename}{saved filename}
\item{filename}{Saved filename}
\item{num_iteration}{number of iteration want to predict with, NULL or <= 0 means use best iteration}
\item{num_iteration}{Number of iterations to save, NULL or <= 0 means use best iteration}
\item{start_iteration}{Index (1-based) of the first boosting round to save.
For example, passing \code{start_iteration=5, num_iteration=3} for a regression model
means "save the fifth, sixth, and seventh tree"}
}
\value{
lgb.Booster
......
......@@ -1093,11 +1093,12 @@ SEXP LGBM_BoosterPredictForMatSingleRowFast_R(SEXP handle_fastConfig,
SEXP LGBM_BoosterSaveModel_R(SEXP handle,
SEXP num_iteration,
SEXP feature_importance_type,
SEXP filename) {
SEXP filename,
SEXP start_iteration) {
R_API_BEGIN();
_AssertBoosterHandleNotNull(handle);
const char* filename_ptr = CHAR(PROTECT(Rf_asChar(filename)));
CHECK_CALL(LGBM_BoosterSaveModel(R_ExternalPtrAddr(handle), 0, Rf_asInteger(num_iteration), Rf_asInteger(feature_importance_type), filename_ptr));
CHECK_CALL(LGBM_BoosterSaveModel(R_ExternalPtrAddr(handle), Rf_asInteger(start_iteration), Rf_asInteger(num_iteration), Rf_asInteger(feature_importance_type), filename_ptr));
UNPROTECT(1);
return R_NilValue;
R_API_END();
......@@ -1105,20 +1106,22 @@ SEXP LGBM_BoosterSaveModel_R(SEXP handle,
SEXP LGBM_BoosterSaveModelToString_R(SEXP handle,
SEXP num_iteration,
SEXP feature_importance_type) {
SEXP feature_importance_type,
SEXP start_iteration) {
SEXP cont_token = PROTECT(R_MakeUnwindCont());
R_API_BEGIN();
_AssertBoosterHandleNotNull(handle);
int64_t out_len = 0;
int64_t buf_len = 1024 * 1024;
int num_iter = Rf_asInteger(num_iteration);
int start_iter = Rf_asInteger(start_iteration);
int importance_type = Rf_asInteger(feature_importance_type);
std::vector<char> inner_char_buf(buf_len);
CHECK_CALL(LGBM_BoosterSaveModelToString(R_ExternalPtrAddr(handle), 0, num_iter, importance_type, buf_len, &out_len, inner_char_buf.data()));
CHECK_CALL(LGBM_BoosterSaveModelToString(R_ExternalPtrAddr(handle), start_iter, num_iter, importance_type, buf_len, &out_len, inner_char_buf.data()));
SEXP model_str = PROTECT(safe_R_raw(out_len, &cont_token));
// if the model string was larger than the initial buffer, call the function again, writing directly to the R object
if (out_len > buf_len) {
CHECK_CALL(LGBM_BoosterSaveModelToString(R_ExternalPtrAddr(handle), 0, num_iter, importance_type, out_len, &out_len, reinterpret_cast<char*>(RAW(model_str))));
CHECK_CALL(LGBM_BoosterSaveModelToString(R_ExternalPtrAddr(handle), start_iter, num_iter, importance_type, out_len, &out_len, reinterpret_cast<char*>(RAW(model_str))));
} else {
std::copy(inner_char_buf.begin(), inner_char_buf.begin() + out_len, reinterpret_cast<char*>(RAW(model_str)));
}
......@@ -1129,7 +1132,8 @@ SEXP LGBM_BoosterSaveModelToString_R(SEXP handle,
SEXP LGBM_BoosterDumpModel_R(SEXP handle,
SEXP num_iteration,
SEXP feature_importance_type) {
SEXP feature_importance_type,
SEXP start_iteration) {
SEXP cont_token = PROTECT(R_MakeUnwindCont());
R_API_BEGIN();
_AssertBoosterHandleNotNull(handle);
......@@ -1137,13 +1141,14 @@ SEXP LGBM_BoosterDumpModel_R(SEXP handle,
int64_t out_len = 0;
int64_t buf_len = 1024 * 1024;
int num_iter = Rf_asInteger(num_iteration);
int start_iter = Rf_asInteger(start_iteration);
int importance_type = Rf_asInteger(feature_importance_type);
std::vector<char> inner_char_buf(buf_len);
CHECK_CALL(LGBM_BoosterDumpModel(R_ExternalPtrAddr(handle), 0, num_iter, importance_type, buf_len, &out_len, inner_char_buf.data()));
CHECK_CALL(LGBM_BoosterDumpModel(R_ExternalPtrAddr(handle), start_iter, num_iter, importance_type, buf_len, &out_len, inner_char_buf.data()));
// if the model string was larger than the initial buffer, allocate a bigger buffer and try again
if (out_len > buf_len) {
inner_char_buf.resize(out_len);
CHECK_CALL(LGBM_BoosterDumpModel(R_ExternalPtrAddr(handle), 0, num_iter, importance_type, out_len, &out_len, inner_char_buf.data()));
CHECK_CALL(LGBM_BoosterDumpModel(R_ExternalPtrAddr(handle), start_iter, num_iter, importance_type, out_len, &out_len, inner_char_buf.data()));
}
model_str = PROTECT(safe_R_string(static_cast<R_xlen_t>(1), &cont_token));
SET_STRING_ELT(model_str, 0, safe_R_mkChar(inner_char_buf.data(), &cont_token));
......@@ -1261,9 +1266,9 @@ static const R_CallMethodDef CallEntries[] = {
{"LGBM_BoosterPredictForMatSingleRow_R" , (DL_FUNC) &LGBM_BoosterPredictForMatSingleRow_R , 9},
{"LGBM_BoosterPredictForMatSingleRowFastInit_R", (DL_FUNC) &LGBM_BoosterPredictForMatSingleRowFastInit_R, 8},
{"LGBM_BoosterPredictForMatSingleRowFast_R" , (DL_FUNC) &LGBM_BoosterPredictForMatSingleRowFast_R , 3},
{"LGBM_BoosterSaveModel_R" , (DL_FUNC) &LGBM_BoosterSaveModel_R , 4},
{"LGBM_BoosterSaveModelToString_R" , (DL_FUNC) &LGBM_BoosterSaveModelToString_R , 3},
{"LGBM_BoosterDumpModel_R" , (DL_FUNC) &LGBM_BoosterDumpModel_R , 3},
{"LGBM_BoosterSaveModel_R" , (DL_FUNC) &LGBM_BoosterSaveModel_R , 5},
{"LGBM_BoosterSaveModelToString_R" , (DL_FUNC) &LGBM_BoosterSaveModelToString_R , 4},
{"LGBM_BoosterDumpModel_R" , (DL_FUNC) &LGBM_BoosterDumpModel_R , 4},
{"LGBM_NullBoosterHandleError_R" , (DL_FUNC) &LGBM_NullBoosterHandleError_R , 0},
{"LGBM_DumpParamAliases_R" , (DL_FUNC) &LGBM_DumpParamAliases_R , 0},
{"LGBM_GetMaxThreads_R" , (DL_FUNC) &LGBM_GetMaxThreads_R , 1},
......
......@@ -809,13 +809,15 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterPredictForMatSingleRowFast_R(
* \param num_iteration, <= 0 means save all
* \param feature_importance_type type of feature importance, 0: split, 1: gain
* \param filename file name
* \param start_iteration Starting iteration (0 based)
* \return R NULL value
*/
LIGHTGBM_C_EXPORT SEXP LGBM_BoosterSaveModel_R(
SEXP handle,
SEXP num_iteration,
SEXP feature_importance_type,
SEXP filename
SEXP filename,
SEXP start_iteration
);
/*!
......@@ -823,12 +825,14 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterSaveModel_R(
* \param handle Booster handle
* \param num_iteration, <= 0 means save all
* \param feature_importance_type type of feature importance, 0: split, 1: gain
* \param start_iteration Starting iteration (0 based)
* \return R character vector (length=1) with model string
*/
LIGHTGBM_C_EXPORT SEXP LGBM_BoosterSaveModelToString_R(
SEXP handle,
SEXP num_iteration,
SEXP feature_importance_type
SEXP feature_importance_type,
SEXP start_iteration
);
/*!
......@@ -836,12 +840,14 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterSaveModelToString_R(
* \param handle Booster handle
* \param num_iteration, <= 0 means save all
* \param feature_importance_type type of feature importance, 0: split, 1: gain
* \param start_iteration Index of starting iteration (0 based)
* \return R character vector (length=1) with model JSON
*/
LIGHTGBM_C_EXPORT SEXP LGBM_BoosterDumpModel_R(
SEXP handle,
SEXP num_iteration,
SEXP feature_importance_type
SEXP feature_importance_type,
SEXP start_iteration
);
/*!
......
......@@ -1519,3 +1519,95 @@ test_that("LGBM_BoosterGetNumFeature_R returns correct outputs", {
ncols <- .Call(LGBM_BoosterGetNumFeature_R, model$.__enclos_env__$private$handle)
expect_equal(ncols, ncol(iris) - 1L)
})
# Builds a binary-classification booster trained for `nrounds` boosting
# rounds on the agaricus data; shared fixture for the save/dump tests below.
.get_test_model <- function(nrounds) {
  set.seed(1L)
  data(agaricus.train, package = "lightgbm")
  training_data <- agaricus.train
  fitted_booster <- lightgbm(
    data = as.matrix(training_data$data)
    , label = training_data$label
    , params = list(objective = "binary", num_threads = .LGB_MAX_THREADS)
    , nrounds = nrounds
    , verbose = .LGB_VERBOSITY
  )
  return(fitted_booster)
}
# Simplified version of lgb.model.dt.tree(): parses a JSON model dump and
# returns one parsed tree (via .single_tree_parse) per boosting round.
.get_trees_from_dump <- function(x) {
  model_json <- jsonlite::fromJSON(
    txt = x
    , simplifyVector = TRUE
    , simplifyDataFrame = FALSE
    , simplifyMatrix = FALSE
    , flatten = FALSE
  )
  trees <- lapply(model_json[["tree_info"]], FUN = .single_tree_parse)
  return(trees)
}
test_that("num_iteration and start_iteration work for lgb.dump()", {
  model <- .get_test_model(5L)
  # Dump two disjoint slices of the model, the full model, and an
  # over-long request
  head_trees <- .get_trees_from_dump(lgb.dump(model, num_iteration = 2L))
  tail_trees <- .get_trees_from_dump(
    lgb.dump(model, num_iteration = 3L, start_iteration = 3L)
  )
  full_trees <- .get_trees_from_dump(lgb.dump(model))
  overshoot <- .get_trees_from_dump(lgb.dump(model, num_iteration = 10L))
  # The first 2 trees plus the last 3 trees must reassemble the full
  # 5-tree model
  expect_equal(
    data.table::rbindlist(c(head_trees, tail_trees))
    , data.table::rbindlist(full_trees)
  )
  # Requesting more iterations than exist falls back to the full model
  expect_equal(overshoot, full_trees)
})
test_that("num_iteration and start_iteration work for lgb.save()", {
  # Count boosting rounds in a booster by dumping it to JSON
  .count_trees <- function(booster) {
    return(length(.get_trees_from_dump(lgb.dump(booster))))
  }
  # Round-trip a booster through lgb.save() / lgb.load(), forwarding any
  # extra arguments (num_iteration, start_iteration) to lgb.save()
  .roundtrip <- function(booster, ...) {
    path <- tempfile(fileext = ".model")
    lgb.save(booster, path, ...)
    return(lgb.load(path))
  }
  model <- .get_test_model(5L)
  n_head <- .count_trees(.roundtrip(model, num_iteration = 2L))
  n_tail <- .count_trees(
    .roundtrip(model, num_iteration = 3L, start_iteration = 3L)
  )
  n_full <- .count_trees(.roundtrip(model))
  n_overshoot <- .count_trees(.roundtrip(model, num_iteration = 10L))
  expect_equal(n_head, 2L)
  expect_equal(n_tail, 3L)
  expect_equal(n_full, 5L)
  # More iterations than exist saves the whole model
  expect_equal(n_overshoot, 5L)
})
test_that("num_iteration and start_iteration work for save_model_to_string()", {
  # Count trees by counting "Tree=" headers in the text model format.
  # gregexpr() returns -1 when there is no match, so only positive match
  # positions are counted.
  .count_tree_headers <- function(model_str) {
    match_starts <- gregexpr("Tree=", model_str, fixed = TRUE)[[1L]]
    return(sum(match_starts > 0L))
  }
  model <- .get_test_model(5L)
  n_head <- .count_tree_headers(
    model$save_model_to_string(num_iteration = 2L)
  )
  n_tail <- .count_tree_headers(
    model$save_model_to_string(num_iteration = 3L, start_iteration = 3L)
  )
  n_full <- .count_tree_headers(model$save_model_to_string())
  n_overshoot <- .count_tree_headers(
    model$save_model_to_string(num_iteration = 10L)
  )
  expect_equal(n_head, 2L)
  expect_equal(n_tail, 3L)
  expect_equal(n_full, 5L)
  # More iterations than exist serializes the whole model
  expect_equal(n_overshoot, 5L)
})
......@@ -156,3 +156,29 @@ for (model_name in names(models)) {
expect_true(all(counts > 1L & counts <= N))
})
}
test_that("num_iteration and start_iteration work as expected", {
  set.seed(1L)
  data(agaricus.train, package = "lightgbm")
  training_data <- agaricus.train
  model <- lightgbm(
    data = as.matrix(training_data$data)
    , label = training_data$label
    , params = list(objective = "binary", num_threads = .LGB_MAX_THREADS)
    , nrounds = 5L
    , verbose = .LGB_VERBOSITY
  )
  head_trees <- lgb.model.dt.tree(model, num_iteration = 2L)
  tail_trees <- lgb.model.dt.tree(
    model, num_iteration = 3L, start_iteration = 3L
  )
  full_trees <- lgb.model.dt.tree(model)
  overshoot <- lgb.model.dt.tree(model, num_iteration = 10L)
  # The first 2 and last 3 trees together must equal all 5 trees
  expect_equal(data.table::rbindlist(list(head_trees, tail_trees)), full_trees)
  # Requesting more iterations than exist returns the full model
  expect_equal(overshoot, full_trees)
  # Tree indices in the output are 0-based and reflect the requested slice
  expect_equal(unique(head_trees[["tree_index"]]), 0L:1L)
  expect_equal(unique(tail_trees[["tree_index"]]), 2L:4L)
  expect_equal(unique(full_trees[["tree_index"]]), 0L:4L)
})
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment