"tests/git@developer.sourcefind.cn:tianlh/lightgbm-dcu.git" did not exist on "7825084f160f3f8093cef3fdd13d1e49351e9e13"
Unverified Commit e0ac6356 authored by Michael Mayer's avatar Michael Mayer Committed by GitHub
Browse files

[R-package] expose start_iteration to dump/save/lgb.model.dt.tree (#6398)

parent a70e8327
...@@ -416,7 +416,12 @@ Booster <- R6::R6Class( ...@@ -416,7 +416,12 @@ Booster <- R6::R6Class(
}, },
# Save model # Save model
save_model = function(filename, num_iteration = NULL, feature_importance_type = 0L) { save_model = function(
filename
, num_iteration = NULL
, feature_importance_type = 0L
, start_iteration = 1L
) {
self$restore_handle() self$restore_handle()
...@@ -432,12 +437,18 @@ Booster <- R6::R6Class( ...@@ -432,12 +437,18 @@ Booster <- R6::R6Class(
, as.integer(num_iteration) , as.integer(num_iteration)
, as.integer(feature_importance_type) , as.integer(feature_importance_type)
, filename , filename
, as.integer(start_iteration) - 1L # Turn to 0-based
) )
return(invisible(self)) return(invisible(self))
}, },
save_model_to_string = function(num_iteration = NULL, feature_importance_type = 0L, as_char = TRUE) { save_model_to_string = function(
num_iteration = NULL
, feature_importance_type = 0L
, as_char = TRUE
, start_iteration = 1L
) {
self$restore_handle() self$restore_handle()
...@@ -450,6 +461,7 @@ Booster <- R6::R6Class( ...@@ -450,6 +461,7 @@ Booster <- R6::R6Class(
, private$handle , private$handle
, as.integer(num_iteration) , as.integer(num_iteration)
, as.integer(feature_importance_type) , as.integer(feature_importance_type)
, as.integer(start_iteration) - 1L # Turn to 0-based
) )
if (as_char) { if (as_char) {
...@@ -461,7 +473,9 @@ Booster <- R6::R6Class( ...@@ -461,7 +473,9 @@ Booster <- R6::R6Class(
}, },
# Dump model in memory # Dump model in memory
dump_model = function(num_iteration = NULL, feature_importance_type = 0L) { dump_model = function(
num_iteration = NULL, feature_importance_type = 0L, start_iteration = 1L
) {
self$restore_handle() self$restore_handle()
...@@ -474,6 +488,7 @@ Booster <- R6::R6Class( ...@@ -474,6 +488,7 @@ Booster <- R6::R6Class(
, private$handle , private$handle
, as.integer(num_iteration) , as.integer(num_iteration)
, as.integer(feature_importance_type) , as.integer(feature_importance_type)
, as.integer(start_iteration) - 1L # Turn to 0-based
) )
return(model_str) return(model_str)
...@@ -1288,8 +1303,11 @@ lgb.load <- function(filename = NULL, model_str = NULL) { ...@@ -1288,8 +1303,11 @@ lgb.load <- function(filename = NULL, model_str = NULL) {
#' @title Save LightGBM model #' @title Save LightGBM model
#' @description Save LightGBM model #' @description Save LightGBM model
#' @param booster Object of class \code{lgb.Booster} #' @param booster Object of class \code{lgb.Booster}
#' @param filename saved filename #' @param filename Saved filename
#' @param num_iteration number of iteration want to predict with, NULL or <= 0 means use best iteration #' @param num_iteration Number of iterations to save, NULL or <= 0 means use best iteration
#' @param start_iteration Index (1-based) of the first boosting round to save.
#' For example, passing \code{start_iteration=5, num_iteration=3} for a regression model
#' means "save the fifth, sixth, and seventh tree"
#' #'
#' @return lgb.Booster #' @return lgb.Booster
#' #'
...@@ -1322,7 +1340,9 @@ lgb.load <- function(filename = NULL, model_str = NULL) { ...@@ -1322,7 +1340,9 @@ lgb.load <- function(filename = NULL, model_str = NULL) {
#' lgb.save(model, tempfile(fileext = ".txt")) #' lgb.save(model, tempfile(fileext = ".txt"))
#' } #' }
#' @export #' @export
lgb.save <- function(booster, filename, num_iteration = NULL) { lgb.save <- function(
booster, filename, num_iteration = NULL, start_iteration = 1L
) {
if (!.is_Booster(x = booster)) { if (!.is_Booster(x = booster)) {
stop("lgb.save: booster should be an ", sQuote("lgb.Booster")) stop("lgb.save: booster should be an ", sQuote("lgb.Booster"))
...@@ -1338,6 +1358,7 @@ lgb.save <- function(booster, filename, num_iteration = NULL) { ...@@ -1338,6 +1358,7 @@ lgb.save <- function(booster, filename, num_iteration = NULL) {
invisible(booster$save_model( invisible(booster$save_model(
filename = filename filename = filename
, num_iteration = num_iteration , num_iteration = num_iteration
, start_iteration = start_iteration
)) ))
) )
...@@ -1347,7 +1368,10 @@ lgb.save <- function(booster, filename, num_iteration = NULL) { ...@@ -1347,7 +1368,10 @@ lgb.save <- function(booster, filename, num_iteration = NULL) {
#' @title Dump LightGBM model to json #' @title Dump LightGBM model to json
#' @description Dump LightGBM model to json #' @description Dump LightGBM model to json
#' @param booster Object of class \code{lgb.Booster} #' @param booster Object of class \code{lgb.Booster}
#' @param num_iteration number of iteration want to predict with, NULL or <= 0 means use best iteration #' @param num_iteration Number of iterations to be dumped. NULL or <= 0 means use best iteration
#' @param start_iteration Index (1-based) of the first boosting round to dump.
#' For example, passing \code{start_iteration=5, num_iteration=3} for a regression model
#' means "dump the fifth, sixth, and seventh tree"
#' #'
#' @return json format of model #' @return json format of model
#' #'
...@@ -1380,14 +1404,18 @@ lgb.save <- function(booster, filename, num_iteration = NULL) { ...@@ -1380,14 +1404,18 @@ lgb.save <- function(booster, filename, num_iteration = NULL) {
#' json_model <- lgb.dump(model) #' json_model <- lgb.dump(model)
#' } #' }
#' @export #' @export
lgb.dump <- function(booster, num_iteration = NULL) { lgb.dump <- function(booster, num_iteration = NULL, start_iteration = 1L) {
if (!.is_Booster(x = booster)) { if (!.is_Booster(x = booster)) {
stop("lgb.dump: booster should be an ", sQuote("lgb.Booster")) stop("lgb.dump: booster should be an ", sQuote("lgb.Booster"))
} }
# Return booster at requested iteration # Return booster at requested iteration
return(booster$dump_model(num_iteration = num_iteration)) return(
booster$dump_model(
num_iteration = num_iteration, start_iteration = start_iteration
)
)
} }
......
#' @name lgb.model.dt.tree #' @name lgb.model.dt.tree
#' @title Parse a LightGBM model json dump #' @title Parse a LightGBM model json dump
#' @description Parse a LightGBM model json dump into a \code{data.table} structure. #' @description Parse a LightGBM model json dump into a \code{data.table} structure.
#' @param model object of class \code{lgb.Booster} #' @param model object of class \code{lgb.Booster}.
#' @param num_iteration number of iterations you want to predict with. NULL or #' @param num_iteration Number of iterations to include. NULL or <= 0 means use best iteration.
#' <= 0 means use best iteration #' @param start_iteration Index (1-based) of the first boosting round to include in the output.
#' For example, passing \code{start_iteration=5, num_iteration=3} for a regression model
#' means "return information about the fifth, sixth, and seventh trees".
#' @return #' @return
#' A \code{data.table} with detailed information about model trees' nodes and leafs. #' A \code{data.table} with detailed information about model trees' nodes and leafs.
#' #'
...@@ -51,9 +53,15 @@ ...@@ -51,9 +53,15 @@
#' @importFrom data.table := rbindlist #' @importFrom data.table := rbindlist
#' @importFrom jsonlite fromJSON #' @importFrom jsonlite fromJSON
#' @export #' @export
lgb.model.dt.tree <- function(model, num_iteration = NULL) { lgb.model.dt.tree <- function(
model, num_iteration = NULL, start_iteration = 1L
json_model <- lgb.dump(booster = model, num_iteration = num_iteration) ) {
json_model <- lgb.dump(
booster = model
, num_iteration = num_iteration
, start_iteration = start_iteration
)
parsed_json_model <- jsonlite::fromJSON( parsed_json_model <- jsonlite::fromJSON(
txt = json_model txt = json_model
...@@ -84,7 +92,6 @@ lgb.model.dt.tree <- function(model, num_iteration = NULL) { ...@@ -84,7 +92,6 @@ lgb.model.dt.tree <- function(model, num_iteration = NULL) {
tree_dt[, split_feature := feature_names] tree_dt[, split_feature := feature_names]
return(tree_dt) return(tree_dt)
} }
......
...@@ -4,12 +4,16 @@ ...@@ -4,12 +4,16 @@
\alias{lgb.dump} \alias{lgb.dump}
\title{Dump LightGBM model to json} \title{Dump LightGBM model to json}
\usage{ \usage{
lgb.dump(booster, num_iteration = NULL) lgb.dump(booster, num_iteration = NULL, start_iteration = 1L)
} }
\arguments{ \arguments{
\item{booster}{Object of class \code{lgb.Booster}} \item{booster}{Object of class \code{lgb.Booster}}
\item{num_iteration}{number of iteration want to predict with, NULL or <= 0 means use best iteration} \item{num_iteration}{Number of iterations to be dumped. NULL or <= 0 means use best iteration}
\item{start_iteration}{Index (1-based) of the first boosting round to dump.
For example, passing \code{start_iteration=5, num_iteration=3} for a regression model
means "dump the fifth, sixth, and seventh tree"}
} }
\value{ \value{
json format of model json format of model
......
...@@ -4,13 +4,16 @@ ...@@ -4,13 +4,16 @@
\alias{lgb.model.dt.tree} \alias{lgb.model.dt.tree}
\title{Parse a LightGBM model json dump} \title{Parse a LightGBM model json dump}
\usage{ \usage{
lgb.model.dt.tree(model, num_iteration = NULL) lgb.model.dt.tree(model, num_iteration = NULL, start_iteration = 1L)
} }
\arguments{ \arguments{
\item{model}{object of class \code{lgb.Booster}} \item{model}{object of class \code{lgb.Booster}.}
\item{num_iteration}{number of iterations you want to predict with. NULL or \item{num_iteration}{Number of iterations to include. NULL or <= 0 means use best iteration.}
<= 0 means use best iteration}
\item{start_iteration}{Index (1-based) of the first boosting round to include in the output.
For example, passing \code{start_iteration=5, num_iteration=3} for a regression model
means "return information about the fifth, sixth, and seventh trees".}
} }
\value{ \value{
A \code{data.table} with detailed information about model trees' nodes and leafs. A \code{data.table} with detailed information about model trees' nodes and leafs.
......
...@@ -4,14 +4,18 @@ ...@@ -4,14 +4,18 @@
\alias{lgb.save} \alias{lgb.save}
\title{Save LightGBM model} \title{Save LightGBM model}
\usage{ \usage{
lgb.save(booster, filename, num_iteration = NULL) lgb.save(booster, filename, num_iteration = NULL, start_iteration = 1L)
} }
\arguments{ \arguments{
\item{booster}{Object of class \code{lgb.Booster}} \item{booster}{Object of class \code{lgb.Booster}}
\item{filename}{saved filename} \item{filename}{Saved filename}
\item{num_iteration}{number of iteration want to predict with, NULL or <= 0 means use best iteration} \item{num_iteration}{Number of iterations to save, NULL or <= 0 means use best iteration}
\item{start_iteration}{Index (1-based) of the first boosting round to save.
For example, passing \code{start_iteration=5, num_iteration=3} for a regression model
means "save the fifth, sixth, and seventh tree"}
} }
\value{ \value{
lgb.Booster lgb.Booster
......
...@@ -1093,11 +1093,12 @@ SEXP LGBM_BoosterPredictForMatSingleRowFast_R(SEXP handle_fastConfig, ...@@ -1093,11 +1093,12 @@ SEXP LGBM_BoosterPredictForMatSingleRowFast_R(SEXP handle_fastConfig,
SEXP LGBM_BoosterSaveModel_R(SEXP handle, SEXP LGBM_BoosterSaveModel_R(SEXP handle,
SEXP num_iteration, SEXP num_iteration,
SEXP feature_importance_type, SEXP feature_importance_type,
SEXP filename) { SEXP filename,
SEXP start_iteration) {
R_API_BEGIN(); R_API_BEGIN();
_AssertBoosterHandleNotNull(handle); _AssertBoosterHandleNotNull(handle);
const char* filename_ptr = CHAR(PROTECT(Rf_asChar(filename))); const char* filename_ptr = CHAR(PROTECT(Rf_asChar(filename)));
CHECK_CALL(LGBM_BoosterSaveModel(R_ExternalPtrAddr(handle), 0, Rf_asInteger(num_iteration), Rf_asInteger(feature_importance_type), filename_ptr)); CHECK_CALL(LGBM_BoosterSaveModel(R_ExternalPtrAddr(handle), Rf_asInteger(start_iteration), Rf_asInteger(num_iteration), Rf_asInteger(feature_importance_type), filename_ptr));
UNPROTECT(1); UNPROTECT(1);
return R_NilValue; return R_NilValue;
R_API_END(); R_API_END();
...@@ -1105,20 +1106,22 @@ SEXP LGBM_BoosterSaveModel_R(SEXP handle, ...@@ -1105,20 +1106,22 @@ SEXP LGBM_BoosterSaveModel_R(SEXP handle,
SEXP LGBM_BoosterSaveModelToString_R(SEXP handle, SEXP LGBM_BoosterSaveModelToString_R(SEXP handle,
SEXP num_iteration, SEXP num_iteration,
SEXP feature_importance_type) { SEXP feature_importance_type,
SEXP start_iteration) {
SEXP cont_token = PROTECT(R_MakeUnwindCont()); SEXP cont_token = PROTECT(R_MakeUnwindCont());
R_API_BEGIN(); R_API_BEGIN();
_AssertBoosterHandleNotNull(handle); _AssertBoosterHandleNotNull(handle);
int64_t out_len = 0; int64_t out_len = 0;
int64_t buf_len = 1024 * 1024; int64_t buf_len = 1024 * 1024;
int num_iter = Rf_asInteger(num_iteration); int num_iter = Rf_asInteger(num_iteration);
int start_iter = Rf_asInteger(start_iteration);
int importance_type = Rf_asInteger(feature_importance_type); int importance_type = Rf_asInteger(feature_importance_type);
std::vector<char> inner_char_buf(buf_len); std::vector<char> inner_char_buf(buf_len);
CHECK_CALL(LGBM_BoosterSaveModelToString(R_ExternalPtrAddr(handle), 0, num_iter, importance_type, buf_len, &out_len, inner_char_buf.data())); CHECK_CALL(LGBM_BoosterSaveModelToString(R_ExternalPtrAddr(handle), start_iter, num_iter, importance_type, buf_len, &out_len, inner_char_buf.data()));
SEXP model_str = PROTECT(safe_R_raw(out_len, &cont_token)); SEXP model_str = PROTECT(safe_R_raw(out_len, &cont_token));
// if the model string was larger than the initial buffer, call the function again, writing directly to the R object // if the model string was larger than the initial buffer, call the function again, writing directly to the R object
if (out_len > buf_len) { if (out_len > buf_len) {
CHECK_CALL(LGBM_BoosterSaveModelToString(R_ExternalPtrAddr(handle), 0, num_iter, importance_type, out_len, &out_len, reinterpret_cast<char*>(RAW(model_str)))); CHECK_CALL(LGBM_BoosterSaveModelToString(R_ExternalPtrAddr(handle), start_iter, num_iter, importance_type, out_len, &out_len, reinterpret_cast<char*>(RAW(model_str))));
} else { } else {
std::copy(inner_char_buf.begin(), inner_char_buf.begin() + out_len, reinterpret_cast<char*>(RAW(model_str))); std::copy(inner_char_buf.begin(), inner_char_buf.begin() + out_len, reinterpret_cast<char*>(RAW(model_str)));
} }
...@@ -1129,7 +1132,8 @@ SEXP LGBM_BoosterSaveModelToString_R(SEXP handle, ...@@ -1129,7 +1132,8 @@ SEXP LGBM_BoosterSaveModelToString_R(SEXP handle,
SEXP LGBM_BoosterDumpModel_R(SEXP handle, SEXP LGBM_BoosterDumpModel_R(SEXP handle,
SEXP num_iteration, SEXP num_iteration,
SEXP feature_importance_type) { SEXP feature_importance_type,
SEXP start_iteration) {
SEXP cont_token = PROTECT(R_MakeUnwindCont()); SEXP cont_token = PROTECT(R_MakeUnwindCont());
R_API_BEGIN(); R_API_BEGIN();
_AssertBoosterHandleNotNull(handle); _AssertBoosterHandleNotNull(handle);
...@@ -1137,13 +1141,14 @@ SEXP LGBM_BoosterDumpModel_R(SEXP handle, ...@@ -1137,13 +1141,14 @@ SEXP LGBM_BoosterDumpModel_R(SEXP handle,
int64_t out_len = 0; int64_t out_len = 0;
int64_t buf_len = 1024 * 1024; int64_t buf_len = 1024 * 1024;
int num_iter = Rf_asInteger(num_iteration); int num_iter = Rf_asInteger(num_iteration);
int start_iter = Rf_asInteger(start_iteration);
int importance_type = Rf_asInteger(feature_importance_type); int importance_type = Rf_asInteger(feature_importance_type);
std::vector<char> inner_char_buf(buf_len); std::vector<char> inner_char_buf(buf_len);
CHECK_CALL(LGBM_BoosterDumpModel(R_ExternalPtrAddr(handle), 0, num_iter, importance_type, buf_len, &out_len, inner_char_buf.data())); CHECK_CALL(LGBM_BoosterDumpModel(R_ExternalPtrAddr(handle), start_iter, num_iter, importance_type, buf_len, &out_len, inner_char_buf.data()));
// if the model string was larger than the initial buffer, allocate a bigger buffer and try again // if the model string was larger than the initial buffer, allocate a bigger buffer and try again
if (out_len > buf_len) { if (out_len > buf_len) {
inner_char_buf.resize(out_len); inner_char_buf.resize(out_len);
CHECK_CALL(LGBM_BoosterDumpModel(R_ExternalPtrAddr(handle), 0, num_iter, importance_type, out_len, &out_len, inner_char_buf.data())); CHECK_CALL(LGBM_BoosterDumpModel(R_ExternalPtrAddr(handle), start_iter, num_iter, importance_type, out_len, &out_len, inner_char_buf.data()));
} }
model_str = PROTECT(safe_R_string(static_cast<R_xlen_t>(1), &cont_token)); model_str = PROTECT(safe_R_string(static_cast<R_xlen_t>(1), &cont_token));
SET_STRING_ELT(model_str, 0, safe_R_mkChar(inner_char_buf.data(), &cont_token)); SET_STRING_ELT(model_str, 0, safe_R_mkChar(inner_char_buf.data(), &cont_token));
...@@ -1261,9 +1266,9 @@ static const R_CallMethodDef CallEntries[] = { ...@@ -1261,9 +1266,9 @@ static const R_CallMethodDef CallEntries[] = {
{"LGBM_BoosterPredictForMatSingleRow_R" , (DL_FUNC) &LGBM_BoosterPredictForMatSingleRow_R , 9}, {"LGBM_BoosterPredictForMatSingleRow_R" , (DL_FUNC) &LGBM_BoosterPredictForMatSingleRow_R , 9},
{"LGBM_BoosterPredictForMatSingleRowFastInit_R", (DL_FUNC) &LGBM_BoosterPredictForMatSingleRowFastInit_R, 8}, {"LGBM_BoosterPredictForMatSingleRowFastInit_R", (DL_FUNC) &LGBM_BoosterPredictForMatSingleRowFastInit_R, 8},
{"LGBM_BoosterPredictForMatSingleRowFast_R" , (DL_FUNC) &LGBM_BoosterPredictForMatSingleRowFast_R , 3}, {"LGBM_BoosterPredictForMatSingleRowFast_R" , (DL_FUNC) &LGBM_BoosterPredictForMatSingleRowFast_R , 3},
{"LGBM_BoosterSaveModel_R" , (DL_FUNC) &LGBM_BoosterSaveModel_R , 4}, {"LGBM_BoosterSaveModel_R" , (DL_FUNC) &LGBM_BoosterSaveModel_R , 5},
{"LGBM_BoosterSaveModelToString_R" , (DL_FUNC) &LGBM_BoosterSaveModelToString_R , 3}, {"LGBM_BoosterSaveModelToString_R" , (DL_FUNC) &LGBM_BoosterSaveModelToString_R , 4},
{"LGBM_BoosterDumpModel_R" , (DL_FUNC) &LGBM_BoosterDumpModel_R , 3}, {"LGBM_BoosterDumpModel_R" , (DL_FUNC) &LGBM_BoosterDumpModel_R , 4},
{"LGBM_NullBoosterHandleError_R" , (DL_FUNC) &LGBM_NullBoosterHandleError_R , 0}, {"LGBM_NullBoosterHandleError_R" , (DL_FUNC) &LGBM_NullBoosterHandleError_R , 0},
{"LGBM_DumpParamAliases_R" , (DL_FUNC) &LGBM_DumpParamAliases_R , 0}, {"LGBM_DumpParamAliases_R" , (DL_FUNC) &LGBM_DumpParamAliases_R , 0},
{"LGBM_GetMaxThreads_R" , (DL_FUNC) &LGBM_GetMaxThreads_R , 1}, {"LGBM_GetMaxThreads_R" , (DL_FUNC) &LGBM_GetMaxThreads_R , 1},
......
...@@ -809,13 +809,15 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterPredictForMatSingleRowFast_R( ...@@ -809,13 +809,15 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterPredictForMatSingleRowFast_R(
* \param num_iteration, <= 0 means save all * \param num_iteration, <= 0 means save all
* \param feature_importance_type type of feature importance, 0: split, 1: gain * \param feature_importance_type type of feature importance, 0: split, 1: gain
* \param filename file name * \param filename file name
* \param start_iteration Starting iteration (0 based)
* \return R NULL value * \return R NULL value
*/ */
LIGHTGBM_C_EXPORT SEXP LGBM_BoosterSaveModel_R( LIGHTGBM_C_EXPORT SEXP LGBM_BoosterSaveModel_R(
SEXP handle, SEXP handle,
SEXP num_iteration, SEXP num_iteration,
SEXP feature_importance_type, SEXP feature_importance_type,
SEXP filename SEXP filename,
SEXP start_iteration
); );
/*! /*!
...@@ -823,12 +825,14 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterSaveModel_R( ...@@ -823,12 +825,14 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterSaveModel_R(
* \param handle Booster handle * \param handle Booster handle
* \param num_iteration, <= 0 means save all * \param num_iteration, <= 0 means save all
* \param feature_importance_type type of feature importance, 0: split, 1: gain * \param feature_importance_type type of feature importance, 0: split, 1: gain
* \param start_iteration Starting iteration (0 based)
* \return R character vector (length=1) with model string * \return R character vector (length=1) with model string
*/ */
LIGHTGBM_C_EXPORT SEXP LGBM_BoosterSaveModelToString_R( LIGHTGBM_C_EXPORT SEXP LGBM_BoosterSaveModelToString_R(
SEXP handle, SEXP handle,
SEXP num_iteration, SEXP num_iteration,
SEXP feature_importance_type SEXP feature_importance_type,
SEXP start_iteration
); );
/*! /*!
...@@ -836,12 +840,14 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterSaveModelToString_R( ...@@ -836,12 +840,14 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterSaveModelToString_R(
* \param handle Booster handle * \param handle Booster handle
* \param num_iteration, <= 0 means save all * \param num_iteration, <= 0 means save all
* \param feature_importance_type type of feature importance, 0: split, 1: gain * \param feature_importance_type type of feature importance, 0: split, 1: gain
* \param start_iteration Index of starting iteration (0 based)
* \return R character vector (length=1) with model JSON * \return R character vector (length=1) with model JSON
*/ */
LIGHTGBM_C_EXPORT SEXP LGBM_BoosterDumpModel_R( LIGHTGBM_C_EXPORT SEXP LGBM_BoosterDumpModel_R(
SEXP handle, SEXP handle,
SEXP num_iteration, SEXP num_iteration,
SEXP feature_importance_type SEXP feature_importance_type,
SEXP start_iteration
); );
/*! /*!
......
...@@ -1519,3 +1519,95 @@ test_that("LGBM_BoosterGetNumFeature_R returns correct outputs", { ...@@ -1519,3 +1519,95 @@ test_that("LGBM_BoosterGetNumFeature_R returns correct outputs", {
ncols <- .Call(LGBM_BoosterGetNumFeature_R, model$.__enclos_env__$private$handle) ncols <- .Call(LGBM_BoosterGetNumFeature_R, model$.__enclos_env__$private$handle)
expect_equal(ncols, ncol(iris) - 1L) expect_equal(ncols, ncol(iris) - 1L)
}) })
# Helper function that creates a fitted model with nrounds boosting rounds
.get_test_model <- function(nrounds) {
set.seed(1L)
data(agaricus.train, package = "lightgbm")
train <- agaricus.train
bst <- lightgbm(
data = as.matrix(train$data)
, label = train$label
, params = list(objective = "binary", num_threads = .LGB_MAX_THREADS)
, nrounds = nrounds
, verbose = .LGB_VERBOSITY
)
return(bst)
}
# Simplified version of lgb.model.dt.tree()
.get_trees_from_dump <- function(x) {
parsed <- jsonlite::fromJSON(
txt = x
, simplifyVector = TRUE
, simplifyDataFrame = FALSE
, simplifyMatrix = FALSE
, flatten = FALSE
)
return(lapply(parsed$tree_info, FUN = .single_tree_parse))
}
test_that("num_iteration and start_iteration work for lgb.dump()", {
bst <- .get_test_model(5L)
first2 <- .get_trees_from_dump(lgb.dump(bst, num_iteration = 2L))
last3 <- .get_trees_from_dump(
lgb.dump(bst, num_iteration = 3L, start_iteration = 3L)
)
all5 <- .get_trees_from_dump(lgb.dump(bst))
too_many <- .get_trees_from_dump(lgb.dump(bst, num_iteration = 10L))
expect_equal(
data.table::rbindlist(c(first2, last3)), data.table::rbindlist(all5)
)
expect_equal(too_many, all5)
})
test_that("num_iteration and start_iteration work for lgb.save()", {
.get_n_trees <- function(x) {
return(length(.get_trees_from_dump(lgb.dump(x))))
}
.save_and_load <- function(bst, ...) {
model_file <- tempfile(fileext = ".model")
lgb.save(bst, model_file, ...)
return(lgb.load(model_file))
}
bst <- .get_test_model(5L)
n_first2 <- .get_n_trees(.save_and_load(bst, num_iteration = 2L))
n_last3 <- .get_n_trees(
.save_and_load(bst, num_iteration = 3L, start_iteration = 3L)
)
n_all5 <- .get_n_trees(.save_and_load(bst))
n_too_many <- .get_n_trees(.save_and_load(bst, num_iteration = 10L))
expect_equal(n_first2, 2L)
expect_equal(n_last3, 3L)
expect_equal(n_all5, 5L)
expect_equal(n_too_many, 5L)
})
test_that("num_iteration and start_iteration work for save_model_to_string()", {
.get_n_trees_from_string <- function(x) {
return(sum(gregexpr("Tree=", x, fixed = TRUE)[[1L]] > 0L))
}
bst <- .get_test_model(5L)
n_first2 <- .get_n_trees_from_string(
bst$save_model_to_string(num_iteration = 2L)
)
n_last3 <- .get_n_trees_from_string(
bst$save_model_to_string(num_iteration = 3L, start_iteration = 3L)
)
n_all5 <- .get_n_trees_from_string(bst$save_model_to_string())
n_too_many <- .get_n_trees_from_string(
bst$save_model_to_string(num_iteration = 10L)
)
expect_equal(n_first2, 2L)
expect_equal(n_last3, 3L)
expect_equal(n_all5, 5L)
expect_equal(n_too_many, 5L)
})
...@@ -156,3 +156,29 @@ for (model_name in names(models)) { ...@@ -156,3 +156,29 @@ for (model_name in names(models)) {
expect_true(all(counts > 1L & counts <= N)) expect_true(all(counts > 1L & counts <= N))
}) })
} }
test_that("num_iteration and start_iteration work as expected", {
set.seed(1L)
data(agaricus.train, package = "lightgbm")
train <- agaricus.train
bst <- lightgbm(
data = as.matrix(train$data)
, label = train$label
, params = list(objective = "binary", num_threads = .LGB_MAX_THREADS)
, nrounds = 5L
, verbose = .LGB_VERBOSITY
)
first2 <- lgb.model.dt.tree(bst, num_iteration = 2L)
last3 <- lgb.model.dt.tree(bst, num_iteration = 3L, start_iteration = 3L)
all5 <- lgb.model.dt.tree(bst)
too_many <- lgb.model.dt.tree(bst, num_iteration = 10L)
expect_equal(data.table::rbindlist(list(first2, last3)), all5)
expect_equal(too_many, all5)
# Check tree indices
expect_equal(unique(first2[["tree_index"]]), 0L:1L)
expect_equal(unique(last3[["tree_index"]]), 2L:4L)
expect_equal(unique(all5[["tree_index"]]), 0L:4L)
})
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment