Unverified Commit ba15a16a authored by Nikita Titov's avatar Nikita Titov Committed by GitHub
Browse files

[python] save all param values into model file (#2589)

* save all param values into model file

* revert storing predict params

* do not save params for predict and convert tasks

* fixed test: 10 is found successfully for default 100

* specify more params as no-save
parent 2051223b
...@@ -315,7 +315,7 @@ def gen_parameter_code(config_hpp, config_out_cpp): ...@@ -315,7 +315,7 @@ def gen_parameter_code(config_hpp, config_out_cpp):
str_to_write += " std::stringstream str_buf;\n" str_to_write += " std::stringstream str_buf;\n"
for x in infos: for x in infos:
for y in x: for y in x:
if "[doc-only]" in y: if "[doc-only]" in y or "[no-save]" in y:
continue continue
param_type = y["inner_type"][0] param_type = y["inner_type"][0]
name = y["name"][0] name = y["name"][0]
......
...@@ -3,8 +3,10 @@ ...@@ -3,8 +3,10 @@
* Licensed under the MIT License. See LICENSE file in the project root for license information. * Licensed under the MIT License. See LICENSE file in the project root for license information.
* *
* \note * \note
* desc and descl2 fields must be written in reStructuredText format; * - desc and descl2 fields must be written in reStructuredText format;
* nested sections can be placed only at the bottom of parent's section * - nested sections can be placed only at the bottom of parent's section;
* - [doc-only] tag indicates that only documentation for this param should be generated and all other actions are performed manually;
* - [no-save] tag indicates that this param should not be saved into a model text representation.
*/ */
#ifndef LIGHTGBM_CONFIG_H_ #ifndef LIGHTGBM_CONFIG_H_
#define LIGHTGBM_CONFIG_H_ #define LIGHTGBM_CONFIG_H_
...@@ -83,12 +85,14 @@ struct Config { ...@@ -83,12 +85,14 @@ struct Config {
#pragma region Core Parameters #pragma region Core Parameters
// [no-save]
// [doc-only] // [doc-only]
// alias = config_file // alias = config_file
// desc = path of config file // desc = path of config file
// desc = **Note**: can be used only in CLI version // desc = **Note**: can be used only in CLI version
std::string config = ""; std::string config = "";
// [no-save]
// [doc-only] // [doc-only]
// type = enum // type = enum
// default = train // default = train
...@@ -482,6 +486,7 @@ struct Config { ...@@ -482,6 +486,7 @@ struct Config {
// desc = ``< 0``: Fatal, ``= 0``: Error (Warning), ``= 1``: Info, ``> 1``: Debug // desc = ``< 0``: Fatal, ``= 0``: Error (Warning), ``= 1``: Info, ``> 1``: Debug
int verbosity = 1; int verbosity = 1;
// [no-save]
// alias = model_input, model_in // alias = model_input, model_in
// desc = filename of input model // desc = filename of input model
// desc = for ``prediction`` task, this model will be applied to prediction data // desc = for ``prediction`` task, this model will be applied to prediction data
...@@ -489,11 +494,13 @@ struct Config { ...@@ -489,11 +494,13 @@ struct Config {
// desc = **Note**: can be used only in CLI version // desc = **Note**: can be used only in CLI version
std::string input_model = ""; std::string input_model = "";
// [no-save]
// alias = model_output, model_out // alias = model_output, model_out
// desc = filename of output model in training // desc = filename of output model in training
// desc = **Note**: can be used only in CLI version // desc = **Note**: can be used only in CLI version
std::string output_model = "LightGBM_model.txt"; std::string output_model = "LightGBM_model.txt";
// [no-save]
// alias = save_period // alias = save_period
// desc = frequency of saving model file snapshot // desc = frequency of saving model file snapshot
// desc = set this to positive value to enable this function. For example, the model file will be snapshotted at each iteration if ``snapshot_freq=1`` // desc = set this to positive value to enable this function. For example, the model file will be snapshotted at each iteration if ``snapshot_freq=1``
...@@ -626,6 +633,7 @@ struct Config { ...@@ -626,6 +633,7 @@ struct Config {
// desc = see `this file <https://github.com/microsoft/LightGBM/tree/master/examples/regression/forced_bins.json>`__ as an example // desc = see `this file <https://github.com/microsoft/LightGBM/tree/master/examples/regression/forced_bins.json>`__ as an example
std::string forcedbins_filename = ""; std::string forcedbins_filename = "";
// [no-save]
// alias = is_save_binary, is_save_binary_file // alias = is_save_binary, is_save_binary_file
// desc = if ``true``, LightGBM will save the dataset (including validation data) to a binary file. This speed ups the data loading for the next time // desc = if ``true``, LightGBM will save the dataset (including validation data) to a binary file. This speed ups the data loading for the next time
// desc = **Note**: ``init_score`` is not saved in binary file // desc = **Note**: ``init_score`` is not saved in binary file
...@@ -636,22 +644,26 @@ struct Config { ...@@ -636,22 +644,26 @@ struct Config {
#pragma region Predict Parameters #pragma region Predict Parameters
// [no-save]
// desc = used only in ``prediction`` task // desc = used only in ``prediction`` task
// desc = used to specify how many trained iterations will be used in prediction // desc = used to specify how many trained iterations will be used in prediction
// desc = ``<= 0`` means no limit // desc = ``<= 0`` means no limit
int num_iteration_predict = -1; int num_iteration_predict = -1;
// [no-save]
// alias = is_predict_raw_score, predict_rawscore, raw_score // alias = is_predict_raw_score, predict_rawscore, raw_score
// desc = used only in ``prediction`` task // desc = used only in ``prediction`` task
// desc = set this to ``true`` to predict only the raw scores // desc = set this to ``true`` to predict only the raw scores
// desc = set this to ``false`` to predict transformed scores // desc = set this to ``false`` to predict transformed scores
bool predict_raw_score = false; bool predict_raw_score = false;
// [no-save]
// alias = is_predict_leaf_index, leaf_index // alias = is_predict_leaf_index, leaf_index
// desc = used only in ``prediction`` task // desc = used only in ``prediction`` task
// desc = set this to ``true`` to predict with leaf index of all trees // desc = set this to ``true`` to predict with leaf index of all trees
bool predict_leaf_index = false; bool predict_leaf_index = false;
// [no-save]
// alias = is_predict_contrib, contrib // alias = is_predict_contrib, contrib
// desc = used only in ``prediction`` task // desc = used only in ``prediction`` task
// desc = set this to ``true`` to estimate `SHAP values <https://arxiv.org/abs/1706.06060>`__, which represent how each feature contributes to each prediction // desc = set this to ``true`` to estimate `SHAP values <https://arxiv.org/abs/1706.06060>`__, which represent how each feature contributes to each prediction
...@@ -660,6 +672,7 @@ struct Config { ...@@ -660,6 +672,7 @@ struct Config {
// desc = **Note**: unlike the shap package, with ``predict_contrib`` we return a matrix with an extra column, where the last column is the expected value // desc = **Note**: unlike the shap package, with ``predict_contrib`` we return a matrix with an extra column, where the last column is the expected value
bool predict_contrib = false; bool predict_contrib = false;
// [no-save]
// desc = used only in ``prediction`` task // desc = used only in ``prediction`` task
// desc = control whether or not LightGBM raises an error when you try to predict on data with a different number of features than the training data // desc = control whether or not LightGBM raises an error when you try to predict on data with a different number of features than the training data
// desc = if ``false`` (the default), a fatal error will be raised if the number of features in the dataset you predict on differs from the number seen during training // desc = if ``false`` (the default), a fatal error will be raised if the number of features in the dataset you predict on differs from the number seen during training
...@@ -667,18 +680,22 @@ struct Config { ...@@ -667,18 +680,22 @@ struct Config {
// desc = **Note**: be very careful setting this parameter to ``true`` // desc = **Note**: be very careful setting this parameter to ``true``
bool predict_disable_shape_check = false; bool predict_disable_shape_check = false;
// [no-save]
// desc = used only in ``prediction`` task // desc = used only in ``prediction`` task
// desc = if ``true``, will use early-stopping to speed up the prediction. May affect the accuracy // desc = if ``true``, will use early-stopping to speed up the prediction. May affect the accuracy
bool pred_early_stop = false; bool pred_early_stop = false;
// [no-save]
// desc = used only in ``prediction`` task // desc = used only in ``prediction`` task
// desc = the frequency of checking early-stopping prediction // desc = the frequency of checking early-stopping prediction
int pred_early_stop_freq = 10; int pred_early_stop_freq = 10;
// [no-save]
// desc = used only in ``prediction`` task // desc = used only in ``prediction`` task
// desc = the threshold of margin in early-stopping prediction // desc = the threshold of margin in early-stopping prediction
double pred_early_stop_margin = 10.0; double pred_early_stop_margin = 10.0;
// [no-save]
// alias = predict_result, prediction_result, predict_name, prediction_name, pred_name, name_pred // alias = predict_result, prediction_result, predict_name, prediction_name, pred_name, name_pred
// desc = used only in ``prediction`` task // desc = used only in ``prediction`` task
// desc = filename of prediction result // desc = filename of prediction result
...@@ -689,12 +706,14 @@ struct Config { ...@@ -689,12 +706,14 @@ struct Config {
#pragma region Convert Parameters #pragma region Convert Parameters
// [no-save]
// desc = used only in ``convert_model`` task // desc = used only in ``convert_model`` task
// desc = only ``cpp`` is supported yet; for conversion model to other languages consider using `m2cgen <https://github.com/BayesWitnesses/m2cgen>`__ utility // desc = only ``cpp`` is supported yet; for conversion model to other languages consider using `m2cgen <https://github.com/BayesWitnesses/m2cgen>`__ utility
// desc = if ``convert_model_language`` is set and ``task=train``, the model will be also converted // desc = if ``convert_model_language`` is set and ``task=train``, the model will be also converted
// desc = **Note**: can be used only in CLI version // desc = **Note**: can be used only in CLI version
std::string convert_model_language = ""; std::string convert_model_language = "";
// [no-save]
// alias = convert_model_file // alias = convert_model_file
// desc = used only in ``convert_model`` task // desc = used only in ``convert_model`` task
// desc = output filename of converted model // desc = output filename of converted model
...@@ -820,12 +839,14 @@ struct Config { ...@@ -820,12 +839,14 @@ struct Config {
// desc = support multiple metrics, separated by ``,`` // desc = support multiple metrics, separated by ``,``
std::vector<std::string> metric; std::vector<std::string> metric;
// [no-save]
// check = >0 // check = >0
// alias = output_freq // alias = output_freq
// desc = frequency for metric output // desc = frequency for metric output
// desc = **Note**: can be used only in CLI version // desc = **Note**: can be used only in CLI version
int metric_freq = 1; int metric_freq = 1;
// [no-save]
// alias = training_metric, is_training_metric, train_metric // alias = training_metric, is_training_metric, train_metric
// desc = set this to ``true`` to output metric result over training dataset // desc = set this to ``true`` to output metric result over training dataset
// desc = **Note**: can be used only in CLI version // desc = **Note**: can be used only in CLI version
......
...@@ -1751,7 +1751,7 @@ class Booster(object): ...@@ -1751,7 +1751,7 @@ class Booster(object):
self.set_network(machines, self.set_network(machines,
local_listen_port=params.get("local_listen_port", 12400), local_listen_port=params.get("local_listen_port", 12400),
listen_time_out=params.get("listen_time_out", 120), listen_time_out=params.get("listen_time_out", 120),
num_machines=params.get("num_machines", num_machines)) num_machines=params.setdefault("num_machines", num_machines))
break break
# construct booster object # construct booster object
train_set.construct() train_set.construct()
...@@ -2641,7 +2641,7 @@ class Booster(object): ...@@ -2641,7 +2641,7 @@ class Booster(object):
train_set = Dataset(data, label, silent=True) train_set = Dataset(data, label, silent=True)
new_params = copy.deepcopy(self.params) new_params = copy.deepcopy(self.params)
new_params['refit_decay_rate'] = decay_rate new_params['refit_decay_rate'] = decay_rate
new_booster = Booster(new_params, train_set, silent=True) new_booster = Booster(new_params, train_set)
# Copy models # Copy models
_safe_call(_LIB.LGBM_BoosterMerge( _safe_call(_LIB.LGBM_BoosterMerge(
new_booster.handle, new_booster.handle,
......
...@@ -146,13 +146,13 @@ def train(params, train_set, num_boost_round=100, ...@@ -146,13 +146,13 @@ def train(params, train_set, num_boost_round=100,
if alias in params: if alias in params:
num_boost_round = params.pop(alias) num_boost_round = params.pop(alias)
warnings.warn("Found `{}` in params. Will use it instead of argument".format(alias)) warnings.warn("Found `{}` in params. Will use it instead of argument".format(alias))
break params["num_iterations"] = num_boost_round
for alias in _ConfigAliases.get("early_stopping_round"): for alias in _ConfigAliases.get("early_stopping_round"):
if alias in params: if alias in params:
early_stopping_rounds = params.pop(alias) early_stopping_rounds = params.pop(alias)
warnings.warn("Found `{}` in params. Will use it instead of argument".format(alias)) warnings.warn("Found `{}` in params. Will use it instead of argument".format(alias))
break params["early_stopping_round"] = early_stopping_rounds
first_metric_only = params.pop('first_metric_only', False) first_metric_only = params.get('first_metric_only', False)
if num_boost_round <= 0: if num_boost_round <= 0:
raise ValueError("num_boost_round should be greater than zero.") raise ValueError("num_boost_round should be greater than zero.")
...@@ -504,13 +504,13 @@ def cv(params, train_set, num_boost_round=100, ...@@ -504,13 +504,13 @@ def cv(params, train_set, num_boost_round=100,
if alias in params: if alias in params:
warnings.warn("Found `{}` in params. Will use it instead of argument".format(alias)) warnings.warn("Found `{}` in params. Will use it instead of argument".format(alias))
num_boost_round = params.pop(alias) num_boost_round = params.pop(alias)
break params["num_iterations"] = num_boost_round
for alias in _ConfigAliases.get("early_stopping_round"): for alias in _ConfigAliases.get("early_stopping_round"):
if alias in params: if alias in params:
warnings.warn("Found `{}` in params. Will use it instead of argument".format(alias)) warnings.warn("Found `{}` in params. Will use it instead of argument".format(alias))
early_stopping_rounds = params.pop(alias) early_stopping_rounds = params.pop(alias)
break params["early_stopping_round"] = early_stopping_rounds
first_metric_only = params.pop('first_metric_only', False) first_metric_only = params.get('first_metric_only', False)
if num_boost_round <= 0: if num_boost_round <= 0:
raise ValueError("num_boost_round should be greater than zero.") raise ValueError("num_boost_round should be greater than zero.")
......
...@@ -641,9 +641,6 @@ std::string Config::SaveMembersToString() const { ...@@ -641,9 +641,6 @@ std::string Config::SaveMembersToString() const {
str_buf << "[cegb_penalty_feature_lazy: " << Common::Join(cegb_penalty_feature_lazy, ",") << "]\n"; str_buf << "[cegb_penalty_feature_lazy: " << Common::Join(cegb_penalty_feature_lazy, ",") << "]\n";
str_buf << "[cegb_penalty_feature_coupled: " << Common::Join(cegb_penalty_feature_coupled, ",") << "]\n"; str_buf << "[cegb_penalty_feature_coupled: " << Common::Join(cegb_penalty_feature_coupled, ",") << "]\n";
str_buf << "[verbosity: " << verbosity << "]\n"; str_buf << "[verbosity: " << verbosity << "]\n";
str_buf << "[input_model: " << input_model << "]\n";
str_buf << "[output_model: " << output_model << "]\n";
str_buf << "[snapshot_freq: " << snapshot_freq << "]\n";
str_buf << "[max_bin: " << max_bin << "]\n"; str_buf << "[max_bin: " << max_bin << "]\n";
str_buf << "[max_bin_by_feature: " << Common::Join(max_bin_by_feature, ",") << "]\n"; str_buf << "[max_bin_by_feature: " << Common::Join(max_bin_by_feature, ",") << "]\n";
str_buf << "[min_data_in_bin: " << min_data_in_bin << "]\n"; str_buf << "[min_data_in_bin: " << min_data_in_bin << "]\n";
...@@ -663,18 +660,6 @@ std::string Config::SaveMembersToString() const { ...@@ -663,18 +660,6 @@ std::string Config::SaveMembersToString() const {
str_buf << "[ignore_column: " << ignore_column << "]\n"; str_buf << "[ignore_column: " << ignore_column << "]\n";
str_buf << "[categorical_feature: " << categorical_feature << "]\n"; str_buf << "[categorical_feature: " << categorical_feature << "]\n";
str_buf << "[forcedbins_filename: " << forcedbins_filename << "]\n"; str_buf << "[forcedbins_filename: " << forcedbins_filename << "]\n";
str_buf << "[save_binary: " << save_binary << "]\n";
str_buf << "[num_iteration_predict: " << num_iteration_predict << "]\n";
str_buf << "[predict_raw_score: " << predict_raw_score << "]\n";
str_buf << "[predict_leaf_index: " << predict_leaf_index << "]\n";
str_buf << "[predict_contrib: " << predict_contrib << "]\n";
str_buf << "[predict_disable_shape_check: " << predict_disable_shape_check << "]\n";
str_buf << "[pred_early_stop: " << pred_early_stop << "]\n";
str_buf << "[pred_early_stop_freq: " << pred_early_stop_freq << "]\n";
str_buf << "[pred_early_stop_margin: " << pred_early_stop_margin << "]\n";
str_buf << "[output_result: " << output_result << "]\n";
str_buf << "[convert_model_language: " << convert_model_language << "]\n";
str_buf << "[convert_model: " << convert_model << "]\n";
str_buf << "[objective_seed: " << objective_seed << "]\n"; str_buf << "[objective_seed: " << objective_seed << "]\n";
str_buf << "[num_class: " << num_class << "]\n"; str_buf << "[num_class: " << num_class << "]\n";
str_buf << "[is_unbalance: " << is_unbalance << "]\n"; str_buf << "[is_unbalance: " << is_unbalance << "]\n";
...@@ -689,8 +674,6 @@ std::string Config::SaveMembersToString() const { ...@@ -689,8 +674,6 @@ std::string Config::SaveMembersToString() const {
str_buf << "[lambdarank_truncation_level: " << lambdarank_truncation_level << "]\n"; str_buf << "[lambdarank_truncation_level: " << lambdarank_truncation_level << "]\n";
str_buf << "[lambdarank_norm: " << lambdarank_norm << "]\n"; str_buf << "[lambdarank_norm: " << lambdarank_norm << "]\n";
str_buf << "[label_gain: " << Common::Join(label_gain, ",") << "]\n"; str_buf << "[label_gain: " << Common::Join(label_gain, ",") << "]\n";
str_buf << "[metric_freq: " << metric_freq << "]\n";
str_buf << "[is_provide_training_metric: " << is_provide_training_metric << "]\n";
str_buf << "[eval_at: " << Common::Join(eval_at, ",") << "]\n"; str_buf << "[eval_at: " << Common::Join(eval_at, ",") << "]\n";
str_buf << "[multi_error_top_k: " << multi_error_top_k << "]\n"; str_buf << "[multi_error_top_k: " << multi_error_top_k << "]\n";
str_buf << "[auc_mu_weights: " << Common::Join(auc_mu_weights, ",") << "]\n"; str_buf << "[auc_mu_weights: " << Common::Join(auc_mu_weights, ",") << "]\n";
......
...@@ -747,6 +747,8 @@ class TestEngine(unittest.TestCase): ...@@ -747,6 +747,8 @@ class TestEngine(unittest.TestCase):
ret_origin = train_and_predict(init_model=gbm) ret_origin = train_and_predict(init_model=gbm)
other_ret = [] other_ret = []
gbm.save_model('lgb.model') gbm.save_model('lgb.model')
with open('lgb.model') as f: # check all params are logged into model file correctly
self.assertNotEqual(f.read().find("[num_iterations: 10]"), -1)
other_ret.append(train_and_predict(init_model='lgb.model')) other_ret.append(train_and_predict(init_model='lgb.model'))
gbm_load = lgb.Booster(model_file='lgb.model') gbm_load = lgb.Booster(model_file='lgb.model')
other_ret.append(train_and_predict(init_model=gbm_load)) other_ret.append(train_and_predict(init_model=gbm_load))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment