Unverified commit ba15a16a, authored by Nikita Titov and committed by GitHub
Browse files

[python] save all param values into model file (#2589)

* save all param values into model file

* revert storing predict params

* do not save params for predict and convert tasks

* fixed test: 10 is found successfully for default 100

* specify more params as no-save
parent 2051223b
......@@ -315,7 +315,7 @@ def gen_parameter_code(config_hpp, config_out_cpp):
str_to_write += " std::stringstream str_buf;\n"
for x in infos:
for y in x:
if "[doc-only]" in y:
if "[doc-only]" in y or "[no-save]" in y:
continue
param_type = y["inner_type"][0]
name = y["name"][0]
......
......@@ -3,8 +3,10 @@
* Licensed under the MIT License. See LICENSE file in the project root for license information.
*
* \note
* desc and descl2 fields must be written in reStructuredText format;
* nested sections can be placed only at the bottom of parent's section
* - desc and descl2 fields must be written in reStructuredText format;
* - nested sections can be placed only at the bottom of parent's section;
* - [doc-only] tag indicates that only documentation for this param should be generated and all other actions are performed manually;
* - [no-save] tag indicates that this param should not be saved into a model text representation.
*/
#ifndef LIGHTGBM_CONFIG_H_
#define LIGHTGBM_CONFIG_H_
......@@ -83,12 +85,14 @@ struct Config {
#pragma region Core Parameters
// [no-save]
// [doc-only]
// alias = config_file
// desc = path of config file
// desc = **Note**: can be used only in CLI version
std::string config = "";
// [no-save]
// [doc-only]
// type = enum
// default = train
......@@ -482,6 +486,7 @@ struct Config {
// desc = ``< 0``: Fatal, ``= 0``: Error (Warning), ``= 1``: Info, ``> 1``: Debug
int verbosity = 1;
// [no-save]
// alias = model_input, model_in
// desc = filename of input model
// desc = for ``prediction`` task, this model will be applied to prediction data
......@@ -489,11 +494,13 @@ struct Config {
// desc = **Note**: can be used only in CLI version
std::string input_model = "";
// [no-save]
// alias = model_output, model_out
// desc = filename of output model in training
// desc = **Note**: can be used only in CLI version
std::string output_model = "LightGBM_model.txt";
// [no-save]
// alias = save_period
// desc = frequency of saving model file snapshot
// desc = set this to positive value to enable this function. For example, the model file will be snapshotted at each iteration if ``snapshot_freq=1``
......@@ -626,6 +633,7 @@ struct Config {
// desc = see `this file <https://github.com/microsoft/LightGBM/tree/master/examples/regression/forced_bins.json>`__ as an example
std::string forcedbins_filename = "";
// [no-save]
// alias = is_save_binary, is_save_binary_file
// desc = if ``true``, LightGBM will save the dataset (including validation data) to a binary file. This speeds up the data loading for the next time
// desc = **Note**: ``init_score`` is not saved in binary file
......@@ -636,22 +644,26 @@ struct Config {
#pragma region Predict Parameters
// [no-save]
// desc = used only in ``prediction`` task
// desc = used to specify how many trained iterations will be used in prediction
// desc = ``<= 0`` means no limit
int num_iteration_predict = -1;
// [no-save]
// alias = is_predict_raw_score, predict_rawscore, raw_score
// desc = used only in ``prediction`` task
// desc = set this to ``true`` to predict only the raw scores
// desc = set this to ``false`` to predict transformed scores
bool predict_raw_score = false;
// [no-save]
// alias = is_predict_leaf_index, leaf_index
// desc = used only in ``prediction`` task
// desc = set this to ``true`` to predict with leaf index of all trees
bool predict_leaf_index = false;
// [no-save]
// alias = is_predict_contrib, contrib
// desc = used only in ``prediction`` task
// desc = set this to ``true`` to estimate `SHAP values <https://arxiv.org/abs/1706.06060>`__, which represent how each feature contributes to each prediction
......@@ -660,6 +672,7 @@ struct Config {
// desc = **Note**: unlike the shap package, with ``predict_contrib`` we return a matrix with an extra column, where the last column is the expected value
bool predict_contrib = false;
// [no-save]
// desc = used only in ``prediction`` task
// desc = control whether or not LightGBM raises an error when you try to predict on data with a different number of features than the training data
// desc = if ``false`` (the default), a fatal error will be raised if the number of features in the dataset you predict on differs from the number seen during training
......@@ -667,18 +680,22 @@ struct Config {
// desc = **Note**: be very careful setting this parameter to ``true``
bool predict_disable_shape_check = false;
// [no-save]
// desc = used only in ``prediction`` task
// desc = if ``true``, will use early-stopping to speed up the prediction. May affect the accuracy
bool pred_early_stop = false;
// [no-save]
// desc = used only in ``prediction`` task
// desc = the frequency of checking early-stopping prediction
int pred_early_stop_freq = 10;
// [no-save]
// desc = used only in ``prediction`` task
// desc = the threshold of margin in early-stopping prediction
double pred_early_stop_margin = 10.0;
// [no-save]
// alias = predict_result, prediction_result, predict_name, prediction_name, pred_name, name_pred
// desc = used only in ``prediction`` task
// desc = filename of prediction result
......@@ -689,12 +706,14 @@ struct Config {
#pragma region Convert Parameters
// [no-save]
// desc = used only in ``convert_model`` task
// desc = only ``cpp`` is supported so far; to convert a model to other languages, consider using the `m2cgen <https://github.com/BayesWitnesses/m2cgen>`__ utility
// desc = if ``convert_model_language`` is set and ``task=train``, the model will also be converted
// desc = **Note**: can be used only in CLI version
std::string convert_model_language = "";
// [no-save]
// alias = convert_model_file
// desc = used only in ``convert_model`` task
// desc = output filename of converted model
......@@ -820,12 +839,14 @@ struct Config {
// desc = support multiple metrics, separated by ``,``
std::vector<std::string> metric;
// [no-save]
// check = >0
// alias = output_freq
// desc = frequency for metric output
// desc = **Note**: can be used only in CLI version
int metric_freq = 1;
// [no-save]
// alias = training_metric, is_training_metric, train_metric
// desc = set this to ``true`` to output metric result over training dataset
// desc = **Note**: can be used only in CLI version
......
......@@ -1751,7 +1751,7 @@ class Booster(object):
self.set_network(machines,
local_listen_port=params.get("local_listen_port", 12400),
listen_time_out=params.get("listen_time_out", 120),
num_machines=params.get("num_machines", num_machines))
num_machines=params.setdefault("num_machines", num_machines))
break
# construct booster object
train_set.construct()
......@@ -2641,7 +2641,7 @@ class Booster(object):
train_set = Dataset(data, label, silent=True)
new_params = copy.deepcopy(self.params)
new_params['refit_decay_rate'] = decay_rate
new_booster = Booster(new_params, train_set, silent=True)
new_booster = Booster(new_params, train_set)
# Copy models
_safe_call(_LIB.LGBM_BoosterMerge(
new_booster.handle,
......
......@@ -146,13 +146,13 @@ def train(params, train_set, num_boost_round=100,
if alias in params:
num_boost_round = params.pop(alias)
warnings.warn("Found `{}` in params. Will use it instead of argument".format(alias))
break
params["num_iterations"] = num_boost_round
for alias in _ConfigAliases.get("early_stopping_round"):
if alias in params:
early_stopping_rounds = params.pop(alias)
warnings.warn("Found `{}` in params. Will use it instead of argument".format(alias))
break
first_metric_only = params.pop('first_metric_only', False)
params["early_stopping_round"] = early_stopping_rounds
first_metric_only = params.get('first_metric_only', False)
if num_boost_round <= 0:
raise ValueError("num_boost_round should be greater than zero.")
......@@ -504,13 +504,13 @@ def cv(params, train_set, num_boost_round=100,
if alias in params:
warnings.warn("Found `{}` in params. Will use it instead of argument".format(alias))
num_boost_round = params.pop(alias)
break
params["num_iterations"] = num_boost_round
for alias in _ConfigAliases.get("early_stopping_round"):
if alias in params:
warnings.warn("Found `{}` in params. Will use it instead of argument".format(alias))
early_stopping_rounds = params.pop(alias)
break
first_metric_only = params.pop('first_metric_only', False)
params["early_stopping_round"] = early_stopping_rounds
first_metric_only = params.get('first_metric_only', False)
if num_boost_round <= 0:
raise ValueError("num_boost_round should be greater than zero.")
......
......@@ -641,9 +641,6 @@ std::string Config::SaveMembersToString() const {
str_buf << "[cegb_penalty_feature_lazy: " << Common::Join(cegb_penalty_feature_lazy, ",") << "]\n";
str_buf << "[cegb_penalty_feature_coupled: " << Common::Join(cegb_penalty_feature_coupled, ",") << "]\n";
str_buf << "[verbosity: " << verbosity << "]\n";
str_buf << "[input_model: " << input_model << "]\n";
str_buf << "[output_model: " << output_model << "]\n";
str_buf << "[snapshot_freq: " << snapshot_freq << "]\n";
str_buf << "[max_bin: " << max_bin << "]\n";
str_buf << "[max_bin_by_feature: " << Common::Join(max_bin_by_feature, ",") << "]\n";
str_buf << "[min_data_in_bin: " << min_data_in_bin << "]\n";
......@@ -663,18 +660,6 @@ std::string Config::SaveMembersToString() const {
str_buf << "[ignore_column: " << ignore_column << "]\n";
str_buf << "[categorical_feature: " << categorical_feature << "]\n";
str_buf << "[forcedbins_filename: " << forcedbins_filename << "]\n";
str_buf << "[save_binary: " << save_binary << "]\n";
str_buf << "[num_iteration_predict: " << num_iteration_predict << "]\n";
str_buf << "[predict_raw_score: " << predict_raw_score << "]\n";
str_buf << "[predict_leaf_index: " << predict_leaf_index << "]\n";
str_buf << "[predict_contrib: " << predict_contrib << "]\n";
str_buf << "[predict_disable_shape_check: " << predict_disable_shape_check << "]\n";
str_buf << "[pred_early_stop: " << pred_early_stop << "]\n";
str_buf << "[pred_early_stop_freq: " << pred_early_stop_freq << "]\n";
str_buf << "[pred_early_stop_margin: " << pred_early_stop_margin << "]\n";
str_buf << "[output_result: " << output_result << "]\n";
str_buf << "[convert_model_language: " << convert_model_language << "]\n";
str_buf << "[convert_model: " << convert_model << "]\n";
str_buf << "[objective_seed: " << objective_seed << "]\n";
str_buf << "[num_class: " << num_class << "]\n";
str_buf << "[is_unbalance: " << is_unbalance << "]\n";
......@@ -689,8 +674,6 @@ std::string Config::SaveMembersToString() const {
str_buf << "[lambdarank_truncation_level: " << lambdarank_truncation_level << "]\n";
str_buf << "[lambdarank_norm: " << lambdarank_norm << "]\n";
str_buf << "[label_gain: " << Common::Join(label_gain, ",") << "]\n";
str_buf << "[metric_freq: " << metric_freq << "]\n";
str_buf << "[is_provide_training_metric: " << is_provide_training_metric << "]\n";
str_buf << "[eval_at: " << Common::Join(eval_at, ",") << "]\n";
str_buf << "[multi_error_top_k: " << multi_error_top_k << "]\n";
str_buf << "[auc_mu_weights: " << Common::Join(auc_mu_weights, ",") << "]\n";
......
......@@ -747,6 +747,8 @@ class TestEngine(unittest.TestCase):
ret_origin = train_and_predict(init_model=gbm)
other_ret = []
gbm.save_model('lgb.model')
with open('lgb.model') as f: # check all params are logged into model file correctly
self.assertNotEqual(f.read().find("[num_iterations: 10]"), -1)
other_ret.append(train_and_predict(init_model='lgb.model'))
gbm_load = lgb.Booster(model_file='lgb.model')
other_ret.append(train_and_predict(init_model=gbm_load))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment