Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
e0ac6356
Unverified
Commit
e0ac6356
authored
May 16, 2024
by
Michael Mayer
Committed by
GitHub
May 15, 2024
Browse files
[R-package] expose start_iteration to dump/save/lgb.model.dt.tree (#6398)
parent
a70e8327
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
214 additions
and
39 deletions
+214
-39
R-package/R/lgb.Booster.R
R-package/R/lgb.Booster.R
+37
-9
R-package/R/lgb.model.dt.tree.R
R-package/R/lgb.model.dt.tree.R
+14
-7
R-package/man/lgb.dump.Rd
R-package/man/lgb.dump.Rd
+6
-2
R-package/man/lgb.model.dt.tree.Rd
R-package/man/lgb.model.dt.tree.Rd
+7
-4
R-package/man/lgb.save.Rd
R-package/man/lgb.save.Rd
+7
-3
R-package/src/lightgbm_R.cpp
R-package/src/lightgbm_R.cpp
+16
-11
R-package/src/lightgbm_R.h
R-package/src/lightgbm_R.h
+9
-3
R-package/tests/testthat/test_lgb.Booster.R
R-package/tests/testthat/test_lgb.Booster.R
+92
-0
R-package/tests/testthat/test_lgb.model.dt.tree.R
R-package/tests/testthat/test_lgb.model.dt.tree.R
+26
-0
No files found.
R-package/R/lgb.Booster.R
View file @
e0ac6356
...
...
@@ -416,7 +416,12 @@ Booster <- R6::R6Class(
},
# Save model
save_model
=
function
(
filename
,
num_iteration
=
NULL
,
feature_importance_type
=
0L
)
{
save_model
=
function
(
filename
,
num_iteration
=
NULL
,
feature_importance_type
=
0L
,
start_iteration
=
1L
)
{
self
$
restore_handle
()
...
...
@@ -432,12 +437,18 @@ Booster <- R6::R6Class(
,
as.integer
(
num_iteration
)
,
as.integer
(
feature_importance_type
)
,
filename
,
as.integer
(
start_iteration
)
-
1L
# Turn to 0-based
)
return
(
invisible
(
self
))
},
save_model_to_string
=
function
(
num_iteration
=
NULL
,
feature_importance_type
=
0L
,
as_char
=
TRUE
)
{
save_model_to_string
=
function
(
num_iteration
=
NULL
,
feature_importance_type
=
0L
,
as_char
=
TRUE
,
start_iteration
=
1L
)
{
self
$
restore_handle
()
...
...
@@ -450,6 +461,7 @@ Booster <- R6::R6Class(
,
private
$
handle
,
as.integer
(
num_iteration
)
,
as.integer
(
feature_importance_type
)
,
as.integer
(
start_iteration
)
-
1L
# Turn to 0-based
)
if
(
as_char
)
{
...
...
@@ -461,7 +473,9 @@ Booster <- R6::R6Class(
},
# Dump model in memory
dump_model
=
function
(
num_iteration
=
NULL
,
feature_importance_type
=
0L
)
{
dump_model
=
function
(
num_iteration
=
NULL
,
feature_importance_type
=
0L
,
start_iteration
=
1L
)
{
self
$
restore_handle
()
...
...
@@ -474,6 +488,7 @@ Booster <- R6::R6Class(
,
private
$
handle
,
as.integer
(
num_iteration
)
,
as.integer
(
feature_importance_type
)
,
as.integer
(
start_iteration
)
-
1L
# Turn to 0-based
)
return
(
model_str
)
...
...
@@ -1288,8 +1303,11 @@ lgb.load <- function(filename = NULL, model_str = NULL) {
#' @title Save LightGBM model
#' @description Save LightGBM model
#' @param booster Object of class \code{lgb.Booster}
#' @param filename saved filename
#' @param num_iteration number of iteration want to predict with, NULL or <= 0 means use best iteration
#' @param filename Saved filename
#' @param num_iteration Number of iterations to save, NULL or <= 0 means use best iteration
#' @param start_iteration Index (1-based) of the first boosting round to save.
#' For example, passing \code{start_iteration=5, num_iteration=3} for a regression model
#' means "save the fifth, sixth, and seventh tree"
#'
#' @return lgb.Booster
#'
...
...
@@ -1322,7 +1340,9 @@ lgb.load <- function(filename = NULL, model_str = NULL) {
#' lgb.save(model, tempfile(fileext = ".txt"))
#' }
#' @export
lgb.save
<-
function
(
booster
,
filename
,
num_iteration
=
NULL
)
{
lgb.save
<-
function
(
booster
,
filename
,
num_iteration
=
NULL
,
start_iteration
=
1L
)
{
if
(
!
.is_Booster
(
x
=
booster
))
{
stop
(
"lgb.save: booster should be an "
,
sQuote
(
"lgb.Booster"
))
...
...
@@ -1338,6 +1358,7 @@ lgb.save <- function(booster, filename, num_iteration = NULL) {
invisible
(
booster
$
save_model
(
filename
=
filename
,
num_iteration
=
num_iteration
,
start_iteration
=
start_iteration
))
)
...
...
@@ -1347,7 +1368,10 @@ lgb.save <- function(booster, filename, num_iteration = NULL) {
#' @title Dump LightGBM model to json
#' @description Dump LightGBM model to json
#' @param booster Object of class \code{lgb.Booster}
#' @param num_iteration number of iteration want to predict with, NULL or <= 0 means use best iteration
#' @param num_iteration Number of iterations to be dumped. NULL or <= 0 means use best iteration
#' @param start_iteration Index (1-based) of the first boosting round to dump.
#' For example, passing \code{start_iteration=5, num_iteration=3} for a regression model
#' means "dump the fifth, sixth, and seventh tree"
#'
#' @return json format of model
#'
...
...
@@ -1380,14 +1404,18 @@ lgb.save <- function(booster, filename, num_iteration = NULL) {
#' json_model <- lgb.dump(model)
#' }
#' @export
lgb.dump
<-
function
(
booster
,
num_iteration
=
NULL
)
{
lgb.dump
<-
function
(
booster
,
num_iteration
=
NULL
,
start_iteration
=
1L
)
{
if
(
!
.is_Booster
(
x
=
booster
))
{
stop
(
"lgb.dump: booster should be an "
,
sQuote
(
"lgb.Booster"
))
}
# Return booster at requested iteration
return
(
booster
$
dump_model
(
num_iteration
=
num_iteration
))
return
(
booster
$
dump_model
(
num_iteration
=
num_iteration
,
start_iteration
=
start_iteration
)
)
}
...
...
R-package/R/lgb.model.dt.tree.R
View file @
e0ac6356
#' @name lgb.model.dt.tree
#' @title Parse a LightGBM model json dump
#' @description Parse a LightGBM model json dump into a \code{data.table} structure.
#' @param model object of class \code{lgb.Booster}
#' @param num_iteration number of iterations you want to predict with. NULL or
#' <= 0 means use best iteration
#' @param model object of class \code{lgb.Booster}.
#' @param num_iteration Number of iterations to include. NULL or <= 0 means use best iteration.
#' @param start_iteration Index (1-based) of the first boosting round to include in the output.
#' For example, passing \code{start_iteration=5, num_iteration=3} for a regression model
#' means "return information about the fifth, sixth, and seventh trees".
#' @return
#' A \code{data.table} with detailed information about model trees' nodes and leafs.
#'
...
...
@@ -51,9 +53,15 @@
#' @importFrom data.table := rbindlist
#' @importFrom jsonlite fromJSON
#' @export
lgb.model.dt.tree
<-
function
(
model
,
num_iteration
=
NULL
)
{
json_model
<-
lgb.dump
(
booster
=
model
,
num_iteration
=
num_iteration
)
lgb.model.dt.tree
<-
function
(
model
,
num_iteration
=
NULL
,
start_iteration
=
1L
)
{
json_model
<-
lgb.dump
(
booster
=
model
,
num_iteration
=
num_iteration
,
start_iteration
=
start_iteration
)
parsed_json_model
<-
jsonlite
::
fromJSON
(
txt
=
json_model
...
...
@@ -84,7 +92,6 @@ lgb.model.dt.tree <- function(model, num_iteration = NULL) {
tree_dt
[,
split_feature
:=
feature_names
]
return
(
tree_dt
)
}
...
...
R-package/man/lgb.dump.Rd
View file @
e0ac6356
...
...
@@ -4,12 +4,16 @@
\
alias
{
lgb
.
dump
}
\
title
{
Dump
LightGBM
model
to
json
}
\
usage
{
lgb
.
dump
(
booster
,
num_iteration
=
NULL
)
lgb
.
dump
(
booster
,
num_iteration
=
NULL
,
start_iteration
=
1L
)
}
\
arguments
{
\
item
{
booster
}{
Object
of
class
\
code
{
lgb
.
Booster
}}
\
item
{
num_iteration
}{
number
of
iteration
want
to
predict
with
,
NULL
or
<=
0
means
use
best
iteration
}
\
item
{
num_iteration
}{
Number
of
iterations
to
be
dumped
.
NULL
or
<=
0
means
use
best
iteration
}
\
item
{
start_iteration
}{
Index
(
1
-
based
)
of
the
first
boosting
round
to
dump
.
For
example
,
passing
\
code
{
start_iteration
=
5
,
num_iteration
=
3
}
for
a
regression
model
means
"dump the fifth, sixth, and seventh tree"
}
}
\
value
{
json
format
of
model
...
...
R-package/man/lgb.model.dt.tree.Rd
View file @
e0ac6356
...
...
@@ -4,13 +4,16 @@
\
alias
{
lgb
.
model
.
dt
.
tree
}
\
title
{
Parse
a
LightGBM
model
json
dump
}
\
usage
{
lgb
.
model
.
dt
.
tree
(
model
,
num_iteration
=
NULL
)
lgb
.
model
.
dt
.
tree
(
model
,
num_iteration
=
NULL
,
start_iteration
=
1L
)
}
\
arguments
{
\
item
{
model
}{
object
of
class
\
code
{
lgb
.
Booster
}}
\
item
{
model
}{
object
of
class
\
code
{
lgb
.
Booster
}
.
}
\
item
{
num_iteration
}{
number
of
iterations
you
want
to
predict
with
.
NULL
or
<=
0
means
use
best
iteration
}
\
item
{
num_iteration
}{
Number
of
iterations
to
include
.
NULL
or
<=
0
means
use
best
iteration
.}
\
item
{
start_iteration
}{
Index
(
1
-
based
)
of
the
first
boosting
round
to
include
in
the
output
.
For
example
,
passing
\
code
{
start_iteration
=
5
,
num_iteration
=
3
}
for
a
regression
model
means
"return information about the fifth, sixth, and seventh trees"
.}
}
\
value
{
A
\
code
{
data
.
table
}
with
detailed
information
about
model
trees
' nodes and leafs.
...
...
R-package/man/lgb.save.Rd
View file @
e0ac6356
...
...
@@ -4,14 +4,18 @@
\
alias
{
lgb
.
save
}
\
title
{
Save
LightGBM
model
}
\
usage
{
lgb
.
save
(
booster
,
filename
,
num_iteration
=
NULL
)
lgb
.
save
(
booster
,
filename
,
num_iteration
=
NULL
,
start_iteration
=
1L
)
}
\
arguments
{
\
item
{
booster
}{
Object
of
class
\
code
{
lgb
.
Booster
}}
\
item
{
filename
}{
s
aved
filename
}
\
item
{
filename
}{
S
aved
filename
}
\
item
{
num_iteration
}{
number
of
iteration
want
to
predict
with
,
NULL
or
<=
0
means
use
best
iteration
}
\
item
{
num_iteration
}{
Number
of
iterations
to
save
,
NULL
or
<=
0
means
use
best
iteration
}
\
item
{
start_iteration
}{
Index
(
1
-
based
)
of
the
first
boosting
round
to
save
.
For
example
,
passing
\
code
{
start_iteration
=
5
,
num_iteration
=
3
}
for
a
regression
model
means
"save the fifth, sixth, and seventh tree"
}
}
\
value
{
lgb
.
Booster
...
...
R-package/src/lightgbm_R.cpp
View file @
e0ac6356
...
...
@@ -1093,11 +1093,12 @@ SEXP LGBM_BoosterPredictForMatSingleRowFast_R(SEXP handle_fastConfig,
SEXP
LGBM_BoosterSaveModel_R
(
SEXP
handle
,
SEXP
num_iteration
,
SEXP
feature_importance_type
,
SEXP
filename
)
{
SEXP
filename
,
SEXP
start_iteration
)
{
R_API_BEGIN
();
_AssertBoosterHandleNotNull
(
handle
);
const
char
*
filename_ptr
=
CHAR
(
PROTECT
(
Rf_asChar
(
filename
)));
CHECK_CALL
(
LGBM_BoosterSaveModel
(
R_ExternalPtrAddr
(
handle
),
0
,
Rf_asInteger
(
num_iteration
),
Rf_asInteger
(
feature_importance_type
),
filename_ptr
));
CHECK_CALL
(
LGBM_BoosterSaveModel
(
R_ExternalPtrAddr
(
handle
),
Rf_asInteger
(
start_iteration
)
,
Rf_asInteger
(
num_iteration
),
Rf_asInteger
(
feature_importance_type
),
filename_ptr
));
UNPROTECT
(
1
);
return
R_NilValue
;
R_API_END
();
...
...
@@ -1105,20 +1106,22 @@ SEXP LGBM_BoosterSaveModel_R(SEXP handle,
SEXP
LGBM_BoosterSaveModelToString_R
(
SEXP
handle
,
SEXP
num_iteration
,
SEXP
feature_importance_type
)
{
SEXP
feature_importance_type
,
SEXP
start_iteration
)
{
SEXP
cont_token
=
PROTECT
(
R_MakeUnwindCont
());
R_API_BEGIN
();
_AssertBoosterHandleNotNull
(
handle
);
int64_t
out_len
=
0
;
int64_t
buf_len
=
1024
*
1024
;
int
num_iter
=
Rf_asInteger
(
num_iteration
);
int
start_iter
=
Rf_asInteger
(
start_iteration
);
int
importance_type
=
Rf_asInteger
(
feature_importance_type
);
std
::
vector
<
char
>
inner_char_buf
(
buf_len
);
CHECK_CALL
(
LGBM_BoosterSaveModelToString
(
R_ExternalPtrAddr
(
handle
),
0
,
num_iter
,
importance_type
,
buf_len
,
&
out_len
,
inner_char_buf
.
data
()));
CHECK_CALL
(
LGBM_BoosterSaveModelToString
(
R_ExternalPtrAddr
(
handle
),
start_iter
,
num_iter
,
importance_type
,
buf_len
,
&
out_len
,
inner_char_buf
.
data
()));
SEXP
model_str
=
PROTECT
(
safe_R_raw
(
out_len
,
&
cont_token
));
// if the model string was larger than the initial buffer, call the function again, writing directly to the R object
if
(
out_len
>
buf_len
)
{
CHECK_CALL
(
LGBM_BoosterSaveModelToString
(
R_ExternalPtrAddr
(
handle
),
0
,
num_iter
,
importance_type
,
out_len
,
&
out_len
,
reinterpret_cast
<
char
*>
(
RAW
(
model_str
))));
CHECK_CALL
(
LGBM_BoosterSaveModelToString
(
R_ExternalPtrAddr
(
handle
),
start_iter
,
num_iter
,
importance_type
,
out_len
,
&
out_len
,
reinterpret_cast
<
char
*>
(
RAW
(
model_str
))));
}
else
{
std
::
copy
(
inner_char_buf
.
begin
(),
inner_char_buf
.
begin
()
+
out_len
,
reinterpret_cast
<
char
*>
(
RAW
(
model_str
)));
}
...
...
@@ -1129,7 +1132,8 @@ SEXP LGBM_BoosterSaveModelToString_R(SEXP handle,
SEXP
LGBM_BoosterDumpModel_R
(
SEXP
handle
,
SEXP
num_iteration
,
SEXP
feature_importance_type
)
{
SEXP
feature_importance_type
,
SEXP
start_iteration
)
{
SEXP
cont_token
=
PROTECT
(
R_MakeUnwindCont
());
R_API_BEGIN
();
_AssertBoosterHandleNotNull
(
handle
);
...
...
@@ -1137,13 +1141,14 @@ SEXP LGBM_BoosterDumpModel_R(SEXP handle,
int64_t
out_len
=
0
;
int64_t
buf_len
=
1024
*
1024
;
int
num_iter
=
Rf_asInteger
(
num_iteration
);
int
start_iter
=
Rf_asInteger
(
start_iteration
);
int
importance_type
=
Rf_asInteger
(
feature_importance_type
);
std
::
vector
<
char
>
inner_char_buf
(
buf_len
);
CHECK_CALL
(
LGBM_BoosterDumpModel
(
R_ExternalPtrAddr
(
handle
),
0
,
num_iter
,
importance_type
,
buf_len
,
&
out_len
,
inner_char_buf
.
data
()));
CHECK_CALL
(
LGBM_BoosterDumpModel
(
R_ExternalPtrAddr
(
handle
),
start_iter
,
num_iter
,
importance_type
,
buf_len
,
&
out_len
,
inner_char_buf
.
data
()));
// if the model string was larger than the initial buffer, allocate a bigger buffer and try again
if
(
out_len
>
buf_len
)
{
inner_char_buf
.
resize
(
out_len
);
CHECK_CALL
(
LGBM_BoosterDumpModel
(
R_ExternalPtrAddr
(
handle
),
0
,
num_iter
,
importance_type
,
out_len
,
&
out_len
,
inner_char_buf
.
data
()));
CHECK_CALL
(
LGBM_BoosterDumpModel
(
R_ExternalPtrAddr
(
handle
),
start_iter
,
num_iter
,
importance_type
,
out_len
,
&
out_len
,
inner_char_buf
.
data
()));
}
model_str
=
PROTECT
(
safe_R_string
(
static_cast
<
R_xlen_t
>
(
1
),
&
cont_token
));
SET_STRING_ELT
(
model_str
,
0
,
safe_R_mkChar
(
inner_char_buf
.
data
(),
&
cont_token
));
...
...
@@ -1261,9 +1266,9 @@ static const R_CallMethodDef CallEntries[] = {
{
"LGBM_BoosterPredictForMatSingleRow_R"
,
(
DL_FUNC
)
&
LGBM_BoosterPredictForMatSingleRow_R
,
9
},
{
"LGBM_BoosterPredictForMatSingleRowFastInit_R"
,
(
DL_FUNC
)
&
LGBM_BoosterPredictForMatSingleRowFastInit_R
,
8
},
{
"LGBM_BoosterPredictForMatSingleRowFast_R"
,
(
DL_FUNC
)
&
LGBM_BoosterPredictForMatSingleRowFast_R
,
3
},
{
"LGBM_BoosterSaveModel_R"
,
(
DL_FUNC
)
&
LGBM_BoosterSaveModel_R
,
4
},
{
"LGBM_BoosterSaveModelToString_R"
,
(
DL_FUNC
)
&
LGBM_BoosterSaveModelToString_R
,
3
},
{
"LGBM_BoosterDumpModel_R"
,
(
DL_FUNC
)
&
LGBM_BoosterDumpModel_R
,
3
},
{
"LGBM_BoosterSaveModel_R"
,
(
DL_FUNC
)
&
LGBM_BoosterSaveModel_R
,
5
},
{
"LGBM_BoosterSaveModelToString_R"
,
(
DL_FUNC
)
&
LGBM_BoosterSaveModelToString_R
,
4
},
{
"LGBM_BoosterDumpModel_R"
,
(
DL_FUNC
)
&
LGBM_BoosterDumpModel_R
,
4
},
{
"LGBM_NullBoosterHandleError_R"
,
(
DL_FUNC
)
&
LGBM_NullBoosterHandleError_R
,
0
},
{
"LGBM_DumpParamAliases_R"
,
(
DL_FUNC
)
&
LGBM_DumpParamAliases_R
,
0
},
{
"LGBM_GetMaxThreads_R"
,
(
DL_FUNC
)
&
LGBM_GetMaxThreads_R
,
1
},
...
...
R-package/src/lightgbm_R.h
View file @
e0ac6356
...
...
@@ -809,13 +809,15 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterPredictForMatSingleRowFast_R(
* \param num_iteration, <= 0 means save all
* \param feature_importance_type type of feature importance, 0: split, 1: gain
* \param filename file name
* \param start_iteration Starting iteration (0 based)
* \return R NULL value
*/
LIGHTGBM_C_EXPORT
SEXP
LGBM_BoosterSaveModel_R
(
SEXP
handle
,
SEXP
num_iteration
,
SEXP
feature_importance_type
,
SEXP
filename
SEXP
filename
,
SEXP
start_iteration
);
/*!
...
...
@@ -823,12 +825,14 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterSaveModel_R(
* \param handle Booster handle
* \param num_iteration, <= 0 means save all
* \param feature_importance_type type of feature importance, 0: split, 1: gain
* \param start_iteration Starting iteration (0 based)
* \return R character vector (length=1) with model string
*/
LIGHTGBM_C_EXPORT
SEXP
LGBM_BoosterSaveModelToString_R
(
SEXP
handle
,
SEXP
num_iteration
,
SEXP
feature_importance_type
SEXP
feature_importance_type
,
SEXP
start_iteration
);
/*!
...
...
@@ -836,12 +840,14 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterSaveModelToString_R(
* \param handle Booster handle
* \param num_iteration, <= 0 means save all
* \param feature_importance_type type of feature importance, 0: split, 1: gain
* \param start_iteration Index of starting iteration (0 based)
* \return R character vector (length=1) with model JSON
*/
LIGHTGBM_C_EXPORT
SEXP
LGBM_BoosterDumpModel_R
(
SEXP
handle
,
SEXP
num_iteration
,
SEXP
feature_importance_type
SEXP
feature_importance_type
,
SEXP
start_iteration
);
/*!
...
...
R-package/tests/testthat/test_lgb.Booster.R
View file @
e0ac6356
...
...
@@ -1519,3 +1519,95 @@ test_that("LGBM_BoosterGetNumFeature_R returns correct outputs", {
ncols
<-
.Call
(
LGBM_BoosterGetNumFeature_R
,
model
$
.__enclos_env__
$
private
$
handle
)
expect_equal
(
ncols
,
ncol
(
iris
)
-
1L
)
})
# Helper function that creates a fitted model with nrounds boosting rounds
.get_test_model
<-
function
(
nrounds
)
{
set.seed
(
1L
)
data
(
agaricus.train
,
package
=
"lightgbm"
)
train
<-
agaricus.train
bst
<-
lightgbm
(
data
=
as.matrix
(
train
$
data
)
,
label
=
train
$
label
,
params
=
list
(
objective
=
"binary"
,
num_threads
=
.LGB_MAX_THREADS
)
,
nrounds
=
nrounds
,
verbose
=
.LGB_VERBOSITY
)
return
(
bst
)
}
# Simplified version of lgb.model.dt.tree()
.get_trees_from_dump
<-
function
(
x
)
{
parsed
<-
jsonlite
::
fromJSON
(
txt
=
x
,
simplifyVector
=
TRUE
,
simplifyDataFrame
=
FALSE
,
simplifyMatrix
=
FALSE
,
flatten
=
FALSE
)
return
(
lapply
(
parsed
$
tree_info
,
FUN
=
.single_tree_parse
))
}
test_that
(
"num_iteration and start_iteration work for lgb.dump()"
,
{
bst
<-
.get_test_model
(
5L
)
first2
<-
.get_trees_from_dump
(
lgb.dump
(
bst
,
num_iteration
=
2L
))
last3
<-
.get_trees_from_dump
(
lgb.dump
(
bst
,
num_iteration
=
3L
,
start_iteration
=
3L
)
)
all5
<-
.get_trees_from_dump
(
lgb.dump
(
bst
))
too_many
<-
.get_trees_from_dump
(
lgb.dump
(
bst
,
num_iteration
=
10L
))
expect_equal
(
data.table
::
rbindlist
(
c
(
first2
,
last3
)),
data.table
::
rbindlist
(
all5
)
)
expect_equal
(
too_many
,
all5
)
})
test_that
(
"num_iteration and start_iteration work for lgb.save()"
,
{
.get_n_trees
<-
function
(
x
)
{
return
(
length
(
.get_trees_from_dump
(
lgb.dump
(
x
))))
}
.save_and_load
<-
function
(
bst
,
...
)
{
model_file
<-
tempfile
(
fileext
=
".model"
)
lgb.save
(
bst
,
model_file
,
...
)
return
(
lgb.load
(
model_file
))
}
bst
<-
.get_test_model
(
5L
)
n_first2
<-
.get_n_trees
(
.save_and_load
(
bst
,
num_iteration
=
2L
))
n_last3
<-
.get_n_trees
(
.save_and_load
(
bst
,
num_iteration
=
3L
,
start_iteration
=
3L
)
)
n_all5
<-
.get_n_trees
(
.save_and_load
(
bst
))
n_too_many
<-
.get_n_trees
(
.save_and_load
(
bst
,
num_iteration
=
10L
))
expect_equal
(
n_first2
,
2L
)
expect_equal
(
n_last3
,
3L
)
expect_equal
(
n_all5
,
5L
)
expect_equal
(
n_too_many
,
5L
)
})
test_that
(
"num_iteration and start_iteration work for save_model_to_string()"
,
{
.get_n_trees_from_string
<-
function
(
x
)
{
return
(
sum
(
gregexpr
(
"Tree="
,
x
,
fixed
=
TRUE
)[[
1L
]]
>
0L
))
}
bst
<-
.get_test_model
(
5L
)
n_first2
<-
.get_n_trees_from_string
(
bst
$
save_model_to_string
(
num_iteration
=
2L
)
)
n_last3
<-
.get_n_trees_from_string
(
bst
$
save_model_to_string
(
num_iteration
=
3L
,
start_iteration
=
3L
)
)
n_all5
<-
.get_n_trees_from_string
(
bst
$
save_model_to_string
())
n_too_many
<-
.get_n_trees_from_string
(
bst
$
save_model_to_string
(
num_iteration
=
10L
)
)
expect_equal
(
n_first2
,
2L
)
expect_equal
(
n_last3
,
3L
)
expect_equal
(
n_all5
,
5L
)
expect_equal
(
n_too_many
,
5L
)
})
R-package/tests/testthat/test_lgb.model.dt.tree.R
View file @
e0ac6356
...
...
@@ -156,3 +156,29 @@ for (model_name in names(models)) {
expect_true
(
all
(
counts
>
1L
&
counts
<=
N
))
})
}
test_that
(
"num_iteration and start_iteration work as expected"
,
{
set.seed
(
1L
)
data
(
agaricus.train
,
package
=
"lightgbm"
)
train
<-
agaricus.train
bst
<-
lightgbm
(
data
=
as.matrix
(
train
$
data
)
,
label
=
train
$
label
,
params
=
list
(
objective
=
"binary"
,
num_threads
=
.LGB_MAX_THREADS
)
,
nrounds
=
5L
,
verbose
=
.LGB_VERBOSITY
)
first2
<-
lgb.model.dt.tree
(
bst
,
num_iteration
=
2L
)
last3
<-
lgb.model.dt.tree
(
bst
,
num_iteration
=
3L
,
start_iteration
=
3L
)
all5
<-
lgb.model.dt.tree
(
bst
)
too_many
<-
lgb.model.dt.tree
(
bst
,
num_iteration
=
10L
)
expect_equal
(
data.table
::
rbindlist
(
list
(
first2
,
last3
)),
all5
)
expect_equal
(
too_many
,
all5
)
# Check tree indices
expect_equal
(
unique
(
first2
[[
"tree_index"
]]),
0L
:
1L
)
expect_equal
(
unique
(
last3
[[
"tree_index"
]]),
2L
:
4L
)
expect_equal
(
unique
(
all5
[[
"tree_index"
]]),
0L
:
4L
)
})
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment