Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
8b720844
Unverified
Commit
8b720844
authored
Oct 11, 2022
by
José Morales
Committed by
GitHub
Oct 11, 2022
Browse files
[python-package][R-package] load parameters from model file (fixes #2613) (#5424)
parent
c134d3d9
Changes
13
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
368 additions
and
6 deletions
+368
-6
R-package/R/lgb.Booster.R
R-package/R/lgb.Booster.R
+15
-0
R-package/src/lightgbm_R.cpp
R-package/src/lightgbm_R.cpp
+22
-0
R-package/src/lightgbm_R.h
R-package/src/lightgbm_R.h
+9
-0
R-package/tests/testthat/test_lgb.Booster.R
R-package/tests/testthat/test_lgb.Booster.R
+18
-6
helpers/parameter_generator.py
helpers/parameter_generator.py
+27
-0
include/LightGBM/boosting.h
include/LightGBM/boosting.h
+2
-0
include/LightGBM/c_api.h
include/LightGBM/c_api.h
+14
-0
include/LightGBM/config.h
include/LightGBM/config.h
+1
-0
python-package/lightgbm/basic.py
python-package/lightgbm/basic.py
+25
-0
src/boosting/gbdt.h
src/boosting/gbdt.h
+54
-0
src/c_api.cpp
src/c_api.cpp
+15
-0
src/io/config_auto.cpp
src/io/config_auto.cpp
+137
-0
tests/python_package_test/test_engine.py
tests/python_package_test/test_engine.py
+29
-0
No files found.
R-package/R/lgb.Booster.R
View file @
8b720844
...
...
@@ -77,6 +77,7 @@ Booster <- R6::R6Class(
LGBM_BoosterCreateFromModelfile_R
,
modelfile
)
params
<-
private
$
get_loaded_param
(
handle
)
}
else
if
(
!
is.null
(
model_str
))
{
...
...
@@ -727,6 +728,20 @@ Booster <- R6::R6Class(
},
# Read the parameters stored in a loaded model (handed over as JSON by the
# C++ side) and return them as an R list.
get_loaded_param = function(handle) {
  loaded_params <- jsonlite::fromJSON(
    .Call(
      LGBM_BoosterGetLoadedParam_R
      , handle
    )
  )
  has_constraints <- "interaction_constraints" %in% names(loaded_params)
  if (has_constraints) {
    # every index in every constraint group is shifted up by one
    # (presumably 0-based on the C++ side vs. 1-based in R -- confirm)
    loaded_params[["interaction_constraints"]] <- lapply(
      X = loaded_params[["interaction_constraints"]]
      , FUN = function(idx) idx + 1L
    )
  }
  return(loaded_params)
},
inner_eval
=
function
(
data_name
,
data_idx
,
feval
=
NULL
)
{
# Check for unknown dataset (over the maximum provided range)
...
...
R-package/src/lightgbm_R.cpp
View file @
8b720844
...
...
@@ -1183,6 +1183,27 @@ SEXP LGBM_DumpParamAliases_R() {
R_API_END
();
}
/*!
 * \brief R entry point returning the parameters of a Booster as a JSON string.
 * \param handle Booster handle (external pointer)
 * \return R character vector (length=1) holding the JSON text
 */
SEXP LGBM_BoosterGetLoadedParam_R(SEXP handle) {
  SEXP cont_token = PROTECT(R_MakeUnwindCont());
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  SEXP params_str;
  int64_t required_len = 0;
  int64_t capacity = 1024 * 1024;
  std::vector<char> json_buf(capacity);
  CHECK_CALL(LGBM_BoosterGetLoadedParam(
    R_ExternalPtrAddr(handle),
    capacity,
    &required_len,
    json_buf.data()));
  // the parameters string did not fit into the initial buffer:
  // grow the buffer to the size reported by the first call and ask again
  if (required_len > capacity) {
    json_buf.resize(required_len);
    CHECK_CALL(LGBM_BoosterGetLoadedParam(
      R_ExternalPtrAddr(handle),
      required_len,
      &required_len,
      json_buf.data()));
  }
  params_str = PROTECT(safe_R_string(static_cast<R_xlen_t>(1), &cont_token));
  SET_STRING_ELT(params_str, 0, safe_R_mkChar(json_buf.data(), &cont_token));
  // releases cont_token and params_str
  UNPROTECT(2);
  return params_str;
  R_API_END();
}
// .Call() calls
static
const
R_CallMethodDef
CallEntries
[]
=
{
{
"LGBM_HandleIsNull_R"
,
(
DL_FUNC
)
&
LGBM_HandleIsNull_R
,
1
},
...
...
@@ -1211,6 +1232,7 @@ static const R_CallMethodDef CallEntries[] = {
{
"LGBM_BoosterResetParameter_R"
,
(
DL_FUNC
)
&
LGBM_BoosterResetParameter_R
,
2
},
{
"LGBM_BoosterGetNumClasses_R"
,
(
DL_FUNC
)
&
LGBM_BoosterGetNumClasses_R
,
2
},
{
"LGBM_BoosterGetNumFeature_R"
,
(
DL_FUNC
)
&
LGBM_BoosterGetNumFeature_R
,
1
},
{
"LGBM_BoosterGetLoadedParam_R"
,
(
DL_FUNC
)
&
LGBM_BoosterGetLoadedParam_R
,
1
},
{
"LGBM_BoosterUpdateOneIter_R"
,
(
DL_FUNC
)
&
LGBM_BoosterUpdateOneIter_R
,
1
},
{
"LGBM_BoosterUpdateOneIterCustom_R"
,
(
DL_FUNC
)
&
LGBM_BoosterUpdateOneIterCustom_R
,
4
},
{
"LGBM_BoosterRollbackOneIter_R"
,
(
DL_FUNC
)
&
LGBM_BoosterRollbackOneIter_R
,
1
},
...
...
R-package/src/lightgbm_R.h
View file @
8b720844
...
...
@@ -266,6 +266,15 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterLoadModelFromString_R(
SEXP
model_str
);
/*!
 * \brief Get the parameters of a Booster as a JSON string.
 * \param handle Booster handle
 * \return R character vector (length=1) with parameters in JSON format
 */
LIGHTGBM_C_EXPORT
SEXP
LGBM_BoosterGetLoadedParam_R
(
SEXP
handle
);
/*!
* \brief Merge model in two Boosters to first handle
* \param handle handle primary Booster handle, will merge other handle to this
...
...
R-package/tests/testthat/test_lgb.Booster.R
View file @
8b720844
...
...
@@ -172,15 +172,24 @@ test_that("Loading a Booster from a text file works", {
data
(
agaricus.test
,
package
=
"lightgbm"
)
train
<-
agaricus.train
test
<-
agaricus.test
params
<-
list
(
num_leaves
=
4L
,
boosting
=
"rf"
,
bagging_fraction
=
0.8
,
bagging_freq
=
1L
,
boost_from_average
=
FALSE
,
categorical_feature
=
c
(
1L
,
2L
)
,
interaction_constraints
=
list
(
c
(
1L
,
2L
),
1L
)
,
feature_contri
=
rep
(
0.5
,
ncol
(
train
$
data
))
,
metric
=
c
(
"mape"
,
"average_precision"
)
,
learning_rate
=
1.0
,
objective
=
"binary"
,
verbosity
=
VERBOSITY
)
bst
<-
lightgbm
(
data
=
as.matrix
(
train
$
data
)
,
label
=
train
$
label
,
params
=
list
(
num_leaves
=
4L
,
learning_rate
=
1.0
,
objective
=
"binary"
,
verbose
=
VERBOSITY
)
,
params
=
params
,
nrounds
=
2L
)
expect_true
(
lgb.is.Booster
(
bst
))
...
...
@@ -199,6 +208,9 @@ test_that("Loading a Booster from a text file works", {
)
pred2
<-
predict
(
bst2
,
test
$
data
)
expect_identical
(
pred
,
pred2
)
# check that the parameters are loaded correctly
expect_equal
(
bst2
$
params
[
names
(
params
)],
params
)
})
test_that
(
"boosters with linear models at leaves can be written to text file and re-loaded successfully"
,
{
...
...
helpers/parameter_generator.py
View file @
8b720844
...
...
@@ -6,6 +6,7 @@ with list of all parameters, aliases table and other routines
along with parameters description in LightGBM/docs/Parameters.rst file
from the information in LightGBM/include/LightGBM/config.h file.
"""
import
re
from
collections
import
defaultdict
from
pathlib
import
Path
from
typing
import
Dict
,
List
,
Tuple
...
...
@@ -373,6 +374,32 @@ def gen_parameter_code(
}
"""
str_to_write
+=
"""const std::unordered_map<std::string, std::string>& Config::ParameterTypes() {
static std::unordered_map<std::string, std::string> map({"""
int_t_pat
=
re
.
compile
(
r
'int\d+_t'
)
# the following are stored as comma separated strings but are arrays in the wrappers
overrides
=
{
'categorical_feature'
:
'vector<int>'
,
'ignore_column'
:
'vector<int>'
,
'interaction_constraints'
:
'vector<vector<int>>'
,
}
for
x
in
infos
:
for
y
in
x
:
name
=
y
[
"name"
][
0
]
if
name
==
'task'
:
continue
if
name
in
overrides
:
param_type
=
overrides
[
name
]
else
:
param_type
=
int_t_pat
.
sub
(
'int'
,
y
[
"inner_type"
][
0
]).
replace
(
'std::'
,
''
)
str_to_write
+=
'
\n
{"'
+
name
+
'", "'
+
param_type
+
'"},'
str_to_write
+=
"""
});
return map;
}
"""
str_to_write
+=
"} // namespace LightGBM
\n
"
with
open
(
config_out_cpp
,
"w"
)
as
config_out_cpp_file
:
config_out_cpp_file
.
write
(
str_to_write
)
...
...
include/LightGBM/boosting.h
View file @
8b720844
...
...
@@ -313,6 +313,8 @@ class LIGHTGBM_EXPORT Boosting {
*/
static
Boosting
*
CreateBoosting
(
const
std
::
string
&
type
,
const
char
*
filename
);
virtual
std
::
string
GetLoadedParam
()
const
=
0
;
virtual
bool
IsLinear
()
const
{
return
false
;
}
virtual
std
::
string
ParserConfigStr
()
const
=
0
;
...
...
include/LightGBM/c_api.h
View file @
8b720844
...
...
@@ -595,6 +595,20 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterLoadModelFromString(const char* model_str,
int
*
out_num_iterations
,
BoosterHandle
*
out
);
/*!
 * \brief Get parameters as JSON string.
 * \note ``out_len`` is always set to the size required for the string, including the terminating null character.
 *       ``out_str`` is only written when that size does not exceed ``buffer_len``; otherwise call again with a buffer of at least ``out_len`` bytes.
 * \param handle Handle of booster.
 * \param buffer_len Allocated space for string.
 * \param[out] out_len Actual size of string.
 * \param[out] out_str JSON string containing parameters.
 * \return 0 when succeed, -1 when failure happens
 */
LIGHTGBM_C_EXPORT
int
LGBM_BoosterGetLoadedParam
(
BoosterHandle
handle
,
int64_t
buffer_len
,
int64_t
*
out_len
,
char
*
out_str
);
/*!
* \brief Free space for booster.
* \param handle Handle of booster to be freed
...
...
include/LightGBM/config.h
View file @
8b720844
...
...
@@ -1077,6 +1077,7 @@ struct Config {
static
const
std
::
unordered_set
<
std
::
string
>&
parameter_set
();
std
::
vector
<
std
::
vector
<
double
>>
auc_mu_weights_matrix
;
std
::
vector
<
std
::
vector
<
int
>>
interaction_constraints_vector
;
static
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
ParameterTypes
();
static
const
std
::
string
DumpAliases
();
private:
...
...
python-package/lightgbm/basic.py
View file @
8b720844
...
...
@@ -2816,6 +2816,9 @@ class Booster:
ctypes
.
byref
(
out_num_class
)))
self
.
__num_class
=
out_num_class
.
value
self
.
pandas_categorical
=
_load_pandas_categorical
(
file_name
=
model_file
)
if
params
:
_log_warning
(
'Ignoring params argument, using parameters from model file.'
)
params
=
self
.
_get_loaded_param
()
elif
model_str
is
not
None
:
self
.
model_from_string
(
model_str
)
else
:
...
...
@@ -2864,6 +2867,28 @@ class Booster:
state
[
'handle'
]
=
handle
self
.
__dict__
.
update
(
state
)
def _get_loaded_param(self) -> Dict[str, Any]:
    """Return the parameters of the loaded model, parsed from the JSON text produced by the library."""
    required_len = ctypes.c_int64(0)

    def _fetch(capacity: int):
        # one library call into a freshly allocated buffer of ``capacity`` bytes;
        # the library reports the size it needs through ``required_len``
        buf = ctypes.create_string_buffer(capacity)
        _safe_call(_LIB.LGBM_BoosterGetLoadedParam(
            self.handle,
            ctypes.c_int64(capacity),
            ctypes.byref(required_len),
            ctypes.c_char_p(ctypes.addressof(buf))))
        return buf

    initial_capacity = 1 << 20
    raw_json = _fetch(initial_capacity)
    needed = required_len.value
    # the first buffer was too small: retry once with the size the library asked for
    if needed > initial_capacity:
        raw_json = _fetch(needed)
    return json.loads(raw_json.value.decode('utf-8'))
def
free_dataset
(
self
)
->
"Booster"
:
"""Free Booster's Datasets.
...
...
src/boosting/gbdt.h
View file @
8b720844
...
...
@@ -157,6 +157,60 @@ class GBDT : public GBDTBase {
*/
int
GetCurrentIteration
()
const
override
{
return
static_cast
<
int
>
(
models_
.
size
())
/
num_tree_per_iteration_
;
}
/*!
 * \brief Get parameters as a JSON string
 * \return JSON object text mapping each recognized parameter name to its typed value;
 *         "{}" when no parameters were loaded with the model.
 *         Entries that are malformed, have an empty value, or name a parameter
 *         that is not listed in Config::ParameterTypes() are left out.
 */
std::string GetLoadedParam() const override {
  if (loaded_parameter_.empty()) {
    return std::string("{}");
  }
  // bind by reference: ParameterTypes() hands out a reference to a static table,
  // copying it on every call is wasted work
  const auto& param_types = Config::ParameterTypes();
  const auto lines = Common::Split(loaded_parameter_.c_str(), "\n");
  // escape the two characters that would otherwise end or corrupt a JSON string literal
  // (e.g. backslashes in Windows file paths)
  const auto json_escape = [](const std::string& raw) {
    std::string escaped;
    escaped.reserve(raw.size());
    for (const char c : raw) {
      if (c == '"' || c == '\\') {
        escaped.push_back('\\');
      }
      escaped.push_back(c);
    }
    return escaped;
  };
  bool first = true;
  std::stringstream str_buf;
  str_buf << "{";
  for (const auto& line : lines) {
    // every entry is expected to look like "[name: value]".
    // Split on the FIRST ':' only, so values that contain ':' themselves
    // (for example "ip:port" machine lists or "C:\..." paths) are kept whole.
    const auto sep = line.find(':');
    if (sep == std::string::npos || sep == 0) {
      continue;  // not a "[name: value]" entry
    }
    const std::string rest = line.substr(sep + 1);
    // need at least the leading space and the closing bracket; " ]" is an empty value
    if (rest.size() < 2 || rest == " ]") {
      continue;
    }
    const std::string param = line.substr(1, sep - 1);                // drop the leading '['
    const std::string value_str = rest.substr(1, rest.size() - 2);    // drop the leading ' ' and trailing ']'
    // a model written by another version may carry parameters unknown to this build:
    // leave them out instead of throwing std::out_of_range and failing the whole call
    const auto type_it = param_types.find(param);
    if (type_it == param_types.end()) {
      continue;
    }
    const std::string& param_type = type_it->second;
    // the separator is written only once the entry is known to be emitted,
    // otherwise a skipped entry would leave a dangling comma in the output
    if (first) {
      first = false;
      str_buf << "\"";
    } else {
      str_buf << ",\"";
    }
    str_buf << param << "\": ";
    if (param_type == "string") {
      str_buf << "\"" << json_escape(value_str) << "\"";
    } else if (param_type == "int") {
      int value = 0;
      Common::Atoi(value_str.c_str(), &value);
      str_buf << value;
    } else if (param_type == "double") {
      double value = 0.0;
      Common::Atof(value_str.c_str(), &value);
      str_buf << value;
    } else if (param_type == "bool") {
      // booleans are stored as "1"/"0" in the parameter text
      str_buf << (value_str == "1" ? "true" : "false");
    } else if (param_type.rfind("vector", 0) == 0) {
      str_buf << "[";
      if (param_type.rfind("vector<string", 0) == 0) {
        // comma separated text becomes an array of quoted strings
        const auto parts = Common::Split(value_str.c_str(), ",");
        for (size_t i = 0; i < parts.size(); ++i) {
          if (i > 0) {
            str_buf << ",";
          }
          str_buf << "\"" << json_escape(parts[i]) << "\"";
        }
      } else {
        // numeric vectors (and nested ones) are already written in a JSON-compatible form
        str_buf << value_str;
      }
      str_buf << "]";
    } else {
      // type without a dedicated branch: emit the raw text as a string so the
      // document stays valid JSON instead of leaving the key without a value
      str_buf << "\"" << json_escape(value_str) << "\"";
    }
  }
  str_buf << "}";
  return str_buf.str();
}
/*!
* \brief Can use early stopping for prediction or not
* \return True if cannot use early stopping for prediction
...
...
src/c_api.cpp
View file @
8b720844
...
...
@@ -1748,6 +1748,21 @@ int LGBM_BoosterLoadModelFromString(
API_END
();
}
// Copies the Booster's loaded parameters (JSON text) into the caller's buffer.
// *out_len always receives the size needed, including the terminating '\0';
// out_str is written only when that size fits into buffer_len.
int LGBM_BoosterGetLoadedParam(BoosterHandle handle,
                               int64_t buffer_len,
                               int64_t* out_len,
                               char* out_str) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const std::string json_text = booster->GetBoosting()->GetLoadedParam();
  const int64_t needed = static_cast<int64_t>(json_text.size()) + 1;
  *out_len = needed;
  if (needed <= buffer_len) {
    // c_str() is null-terminated, so copying `needed` bytes includes the '\0'
    std::memcpy(out_str, json_text.c_str(), needed);
  }
  API_END();
}
#ifdef _MSC_VER
#pragma warning(disable : 4702)
#endif
...
...
src/io/config_auto.cpp
View file @
8b720844
...
...
@@ -894,4 +894,141 @@ const std::unordered_map<std::string, std::vector<std::string>>& Config::paramet
return
map
;
}
// Lookup table from parameter name to the name of its type
// ("string", "int", "double", "bool" or a "vector<...>" form).
const std::unordered_map<std::string, std::string>& Config::ParameterTypes() {
  static const std::unordered_map<std::string, std::string> param_types = {
    {"config", "string"},
    {"objective", "string"},
    {"boosting", "string"},
    {"data", "string"},
    {"valid", "vector<string>"},
    {"num_iterations", "int"},
    {"learning_rate", "double"},
    {"num_leaves", "int"},
    {"tree_learner", "string"},
    {"num_threads", "int"},
    {"device_type", "string"},
    {"seed", "int"},
    {"deterministic", "bool"},
    {"force_col_wise", "bool"},
    {"force_row_wise", "bool"},
    {"histogram_pool_size", "double"},
    {"max_depth", "int"},
    {"min_data_in_leaf", "int"},
    {"min_sum_hessian_in_leaf", "double"},
    {"bagging_fraction", "double"},
    {"pos_bagging_fraction", "double"},
    {"neg_bagging_fraction", "double"},
    {"bagging_freq", "int"},
    {"bagging_seed", "int"},
    {"feature_fraction", "double"},
    {"feature_fraction_bynode", "double"},
    {"feature_fraction_seed", "int"},
    {"extra_trees", "bool"},
    {"extra_seed", "int"},
    {"early_stopping_round", "int"},
    {"first_metric_only", "bool"},
    {"max_delta_step", "double"},
    {"lambda_l1", "double"},
    {"lambda_l2", "double"},
    {"linear_lambda", "double"},
    {"min_gain_to_split", "double"},
    {"drop_rate", "double"},
    {"max_drop", "int"},
    {"skip_drop", "double"},
    {"xgboost_dart_mode", "bool"},
    {"uniform_drop", "bool"},
    {"drop_seed", "int"},
    {"top_rate", "double"},
    {"other_rate", "double"},
    {"min_data_per_group", "int"},
    {"max_cat_threshold", "int"},
    {"cat_l2", "double"},
    {"cat_smooth", "double"},
    {"max_cat_to_onehot", "int"},
    {"top_k", "int"},
    {"monotone_constraints", "vector<int>"},
    {"monotone_constraints_method", "string"},
    {"monotone_penalty", "double"},
    {"feature_contri", "vector<double>"},
    {"forcedsplits_filename", "string"},
    {"refit_decay_rate", "double"},
    {"cegb_tradeoff", "double"},
    {"cegb_penalty_split", "double"},
    {"cegb_penalty_feature_lazy", "vector<double>"},
    {"cegb_penalty_feature_coupled", "vector<double>"},
    {"path_smooth", "double"},
    {"interaction_constraints", "vector<vector<int>>"},
    {"verbosity", "int"},
    {"input_model", "string"},
    {"output_model", "string"},
    {"saved_feature_importance_type", "int"},
    {"snapshot_freq", "int"},
    {"linear_tree", "bool"},
    {"max_bin", "int"},
    {"max_bin_by_feature", "vector<int>"},
    {"min_data_in_bin", "int"},
    {"bin_construct_sample_cnt", "int"},
    {"data_random_seed", "int"},
    {"is_enable_sparse", "bool"},
    {"enable_bundle", "bool"},
    {"use_missing", "bool"},
    {"zero_as_missing", "bool"},
    {"feature_pre_filter", "bool"},
    {"pre_partition", "bool"},
    {"two_round", "bool"},
    {"header", "bool"},
    {"label_column", "string"},
    {"weight_column", "string"},
    {"group_column", "string"},
    {"ignore_column", "vector<int>"},
    {"categorical_feature", "vector<int>"},
    {"forcedbins_filename", "string"},
    {"save_binary", "bool"},
    {"precise_float_parser", "bool"},
    {"parser_config_file", "string"},
    {"start_iteration_predict", "int"},
    {"num_iteration_predict", "int"},
    {"predict_raw_score", "bool"},
    {"predict_leaf_index", "bool"},
    {"predict_contrib", "bool"},
    {"predict_disable_shape_check", "bool"},
    {"pred_early_stop", "bool"},
    {"pred_early_stop_freq", "int"},
    {"pred_early_stop_margin", "double"},
    {"output_result", "string"},
    {"convert_model_language", "string"},
    {"convert_model", "string"},
    {"objective_seed", "int"},
    {"num_class", "int"},
    {"is_unbalance", "bool"},
    {"scale_pos_weight", "double"},
    {"sigmoid", "double"},
    {"boost_from_average", "bool"},
    {"reg_sqrt", "bool"},
    {"alpha", "double"},
    {"fair_c", "double"},
    {"poisson_max_delta_step", "double"},
    {"tweedie_variance_power", "double"},
    {"lambdarank_truncation_level", "int"},
    {"lambdarank_norm", "bool"},
    {"label_gain", "vector<double>"},
    {"metric", "vector<string>"},
    {"metric_freq", "int"},
    {"is_provide_training_metric", "bool"},
    {"eval_at", "vector<int>"},
    {"multi_error_top_k", "int"},
    {"auc_mu_weights", "vector<double>"},
    {"num_machines", "int"},
    {"local_listen_port", "int"},
    {"time_out", "int"},
    {"machine_list_filename", "string"},
    {"machines", "string"},
    {"gpu_platform_id", "int"},
    {"gpu_device_id", "int"},
    {"gpu_use_dp", "bool"},
    {"num_gpu", "int"},
  };
  return param_types;
}
}
// namespace LightGBM
tests/python_package_test/test_engine.py
View file @
8b720844
...
...
@@ -1211,6 +1211,35 @@ def test_feature_name_with_non_ascii():
assert
feature_names
==
gbm2
.
feature_name
()
def test_parameters_are_loaded_from_model_file(tmp_path):
    """Parameters used for training must be readable from a Booster created from the saved model file."""
    # one continuous column followed by two integer-coded columns
    continuous_part = np.random.rand(100, 1)
    integer_part = np.random.randint(0, 5, (100, 2))
    X = np.hstack([continuous_part, integer_part])
    y = np.random.rand(100)
    ds = lgb.Dataset(X, y)
    params = {
        'bagging_fraction': 0.8,
        'bagging_freq': 2,
        'boosting': 'rf',
        'feature_contri': [0.5, 0.5, 0.5],
        'feature_fraction': 0.7,
        'boost_from_average': False,
        'interaction_constraints': [[0, 1], [0]],
        'metric': ['l2', 'rmse'],
        'num_leaves': 5,
        'num_threads': 1,
    }
    model_file = tmp_path / 'model.txt'
    trained = lgb.train(params, ds, num_boost_round=1, categorical_feature=[1, 2])
    trained.save_model(model_file)

    bst = lgb.Booster(model_file=model_file)
    # every parameter passed to training comes back with the same value
    reloaded_subset = {name: bst.params[name] for name in params}
    assert reloaded_subset == params
    assert bst.params['categorical_feature'] == [1, 2]

    # check that passing parameters to the constructor raises warning and ignores them
    with pytest.warns(UserWarning, match='Ignoring params argument'):
        bst2 = lgb.Booster(params={'num_leaves': 7}, model_file=model_file)
    assert bst.params == bst2.params
def
test_save_load_copy_pickle
():
def
train_and_predict
(
init_model
=
None
,
return_model
=
False
):
X
,
y
=
make_synthetic_regression
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment