Commit 0a9d4cc2 authored by Nikita Titov, committed by Guolin Ke
Browse files

[docs] generate parameters description from config file. Final stage (#1421)

* removed excess whitespaces

* don't use built-in name for variable

* simplified line parsing

* changed link to relative

* run parameter_generator.py

* removed old targets

* use tuples instead of list where possible

* hotfix: descriptions were being erased and only the last one was kept

* run parameter_generator.py

* separated checks from aliases section
parent c0147cbe
This diff is collapsed.
...@@ -3,7 +3,7 @@ Documentation ...@@ -3,7 +3,7 @@ Documentation
Documentation for LightGBM is generated using `Sphinx <http://www.sphinx-doc.org/>`__. Documentation for LightGBM is generated using `Sphinx <http://www.sphinx-doc.org/>`__.
List of parameters and their descriptions in `Parameters.rst <https://github.com/Microsoft/LightGBM/blob/master/docs/Parameters.rst>`__ List of parameters and their descriptions in `Parameters.rst <./Parameters.rst>`__
is generated automatically from comments in `config file <https://github.com/Microsoft/LightGBM/blob/master/include/LightGBM/config.h>`__ is generated automatically from comments in `config file <https://github.com/Microsoft/LightGBM/blob/master/include/LightGBM/config.h>`__
by `this script <https://github.com/Microsoft/LightGBM/blob/master/helper/parameter_generator.py>`__. by `this script <https://github.com/Microsoft/LightGBM/blob/master/helper/parameter_generator.py>`__.
......
...@@ -30,18 +30,18 @@ def GetParameterInfos(config_hpp): ...@@ -30,18 +30,18 @@ def GetParameterInfos(config_hpp):
elif cur_key is not None: elif cur_key is not None:
line = line.strip() line = line.strip()
if line.startswith("//"): if line.startswith("//"):
tokens = line[2:].split("=") key, _, val = line[2:].partition("=")
key = tokens[0].strip() key = key.strip()
val = '='.join(tokens[1:]).strip() val = val.strip()
if key not in cur_info: if key not in cur_info:
if key == "descl2": if key == "descl2" and "desc" not in cur_info:
cur_info["desc"] = [] cur_info["desc"] = []
else: elif key != "descl2":
cur_info[key] = [] cur_info[key] = []
if key == "desc": if key == "desc":
cur_info["desc"].append(["l1", val]) cur_info["desc"].append(("l1", val))
elif key == "descl2": elif key == "descl2":
cur_info["desc"].append(["l2", val]) cur_info["desc"].append(("l2", val))
else: else:
cur_info[key].append(val) cur_info[key].append(val)
elif line: elif line:
...@@ -79,22 +79,22 @@ def GetAlias(infos): ...@@ -79,22 +79,22 @@ def GetAlias(infos):
name = y["name"][0] name = y["name"][0]
alias = y["alias"][0].split(',') alias = y["alias"][0].split(',')
for name2 in alias: for name2 in alias:
pairs.append([name2.strip(), name]) pairs.append((name2.strip(), name))
return pairs return pairs
def SetOneVarFromString(name, type, checks): def SetOneVarFromString(name, param_type, checks):
ret = "" ret = ""
univar_mapper = {"int": "GetInt", "double": "GetDouble", "bool": "GetBool", "std::string": "GetString"} univar_mapper = {"int": "GetInt", "double": "GetDouble", "bool": "GetBool", "std::string": "GetString"}
if "vector" not in type: if "vector" not in param_type:
ret += " %s(params, \"%s\", &%s);\n" % (univar_mapper[type], name, name) ret += " %s(params, \"%s\", &%s);\n" % (univar_mapper[param_type], name, name)
if len(checks) > 0: if len(checks) > 0:
for check in checks: for check in checks:
ret += " CHECK(%s %s);\n" % (name, check) ret += " CHECK(%s %s);\n" % (name, check)
ret += "\n" ret += "\n"
else: else:
ret += " if (GetString(params, \"%s\", &tmp_str)) {\n" % (name) ret += " if (GetString(params, \"%s\", &tmp_str)) {\n" % (name)
type2 = type.split("<")[1][:-1] type2 = param_type.split("<")[1][:-1]
if type2 == "std::string": if type2 == "std::string":
ret += " %s = Common::Split(tmp_str.c_str(), ',');\n" % (name) ret += " %s = Common::Split(tmp_str.c_str(), ',');\n" % (name)
else: else:
...@@ -141,10 +141,10 @@ def GenParameterDescription(sections, descriptions, params_rst): ...@@ -141,10 +141,10 @@ def GenParameterDescription(sections, descriptions, params_rst):
if checks_len > 1: if checks_len > 1:
number1, sign1 = parse_check(checks[0]) number1, sign1 = parse_check(checks[0])
number2, sign2 = parse_check(checks[1], reverse=True) number2, sign2 = parse_check(checks[1], reverse=True)
checks_str = ', ``{0} {1} {2} {3} {4}``'.format(number2, sign2, name, sign1, number1) checks_str = ', constraints: ``{0} {1} {2} {3} {4}``'.format(number2, sign2, name, sign1, number1)
elif checks_len == 1: elif checks_len == 1:
number, sign = parse_check(checks[0]) number, sign = parse_check(checks[0])
checks_str = ', ``{0} {1} {2}``'.format(name, sign, number) checks_str = ', constraints: ``{0} {1} {2}``'.format(name, sign, number)
else: else:
checks_str = '' checks_str = ''
main_desc = '- ``{0}``, default = ``{1}``, type = {2}{3}{4}{5}'.format(name, default, param_type, options_str, aliases_str, checks_str) main_desc = '- ``{0}``, default = ``{1}``, type = {2}{3}{4}{5}'.format(name, default, param_type, options_str, aliases_str, checks_str)
...@@ -173,12 +173,12 @@ def GenParameterCode(config_hpp, config_out_cpp): ...@@ -173,12 +173,12 @@ def GenParameterCode(config_hpp, config_out_cpp):
# alias table # alias table
str_to_write += "std::unordered_map<std::string, std::string> Config::alias_table({\n" str_to_write += "std::unordered_map<std::string, std::string> Config::alias_table({\n"
for pair in alias: for pair in alias:
str_to_write += " {\"%s\", \"%s\"}, \n" % (pair[0], pair[1]) str_to_write += " {\"%s\", \"%s\"},\n" % (pair[0], pair[1])
str_to_write += "});\n\n" str_to_write += "});\n\n"
# names # names
str_to_write += "std::unordered_set<std::string> Config::parameter_set({\n" str_to_write += "std::unordered_set<std::string> Config::parameter_set({\n"
for name in names: for name in names:
str_to_write += " \"%s\", \n" % (name) str_to_write += " \"%s\",\n" % (name)
str_to_write += "});\n\n" str_to_write += "});\n\n"
# from strings # from strings
str_to_write += "void Config::GetMembersFromString(const std::unordered_map<std::string, std::string>& params) {\n" str_to_write += "void Config::GetMembersFromString(const std::unordered_map<std::string, std::string>& params) {\n"
...@@ -187,12 +187,12 @@ def GenParameterCode(config_hpp, config_out_cpp): ...@@ -187,12 +187,12 @@ def GenParameterCode(config_hpp, config_out_cpp):
for y in x: for y in x:
if "[doc-only]" in y: if "[doc-only]" in y:
continue continue
type = y["inner_type"][0] param_type = y["inner_type"][0]
name = y["name"][0] name = y["name"][0]
checks = [] checks = []
if "check" in y: if "check" in y:
checks = y["check"] checks = y["check"]
tmp = SetOneVarFromString(name, type, checks) tmp = SetOneVarFromString(name, param_type, checks)
str_to_write += tmp str_to_write += tmp
# tails # tails
str_to_write += "}\n\n" str_to_write += "}\n\n"
...@@ -202,10 +202,10 @@ def GenParameterCode(config_hpp, config_out_cpp): ...@@ -202,10 +202,10 @@ def GenParameterCode(config_hpp, config_out_cpp):
for y in x: for y in x:
if "[doc-only]" in y: if "[doc-only]" in y:
continue continue
type = y["inner_type"][0] param_type = y["inner_type"][0]
name = y["name"][0] name = y["name"][0]
if "vector" in type: if "vector" in param_type:
if "int8" in type: if "int8" in param_type:
str_to_write += " str_buf << \"[%s: \" << Common::Join(Common::ArrayCast<int8_t, int>(%s),\",\") << \"]\\n\";\n" % (name, name) str_to_write += " str_buf << \"[%s: \" << Common::Join(Common::ArrayCast<int8_t, int>(%s),\",\") << \"]\\n\";\n" % (name, name)
else: else:
str_to_write += " str_buf << \"[%s: \" << Common::Join(%s,\",\") << \"]\\n\";\n" % (name, name) str_to_write += " str_buf << \"[%s: \" << Common::Join(%s,\",\") << \"]\\n\";\n" % (name, name)
......
...@@ -4,9 +4,9 @@ namespace LightGBM { ...@@ -4,9 +4,9 @@ namespace LightGBM {
std::unordered_map<std::string, std::string> Config::alias_table({ std::unordered_map<std::string, std::string> Config::alias_table({
{"config_file", "config"}, {"config_file", "config"},
{"task_type", "task"}, {"task_type", "task"},
{"application", "objective"},
{"app", "objective"},
{"objective_type", "objective"}, {"objective_type", "objective"},
{"app", "objective"},
{"application", "objective"},
{"boosting_type", "boosting"}, {"boosting_type", "boosting"},
{"boost", "boosting"}, {"boost", "boosting"},
{"train", "data"}, {"train", "data"},
...@@ -14,6 +14,7 @@ std::unordered_map<std::string, std::string> Config::alias_table({ ...@@ -14,6 +14,7 @@ std::unordered_map<std::string, std::string> Config::alias_table({
{"data_filename", "data"}, {"data_filename", "data"},
{"test", "valid"}, {"test", "valid"},
{"valid_data", "valid"}, {"valid_data", "valid"},
{"valid_data_file", "valid"},
{"test_data", "valid"}, {"test_data", "valid"},
{"valid_filenames", "valid"}, {"valid_filenames", "valid"},
{"num_iteration", "num_iterations"}, {"num_iteration", "num_iterations"},
...@@ -56,50 +57,62 @@ std::unordered_map<std::string, std::string> Config::alias_table({ ...@@ -56,50 +57,62 @@ std::unordered_map<std::string, std::string> Config::alias_table({
{"topk", "top_k"}, {"topk", "top_k"},
{"mc", "monotone_constraints"}, {"mc", "monotone_constraints"},
{"monotone_constraint", "monotone_constraints"}, {"monotone_constraint", "monotone_constraints"},
{"fs", "forcedsplits_filename"},
{"forced_splits_filename", "forcedsplits_filename"}, {"forced_splits_filename", "forcedsplits_filename"},
{"forced_splits_file", "forcedsplits_filename"}, {"forced_splits_file", "forcedsplits_filename"},
{"forced_splits", "forcedsplits_filename"}, {"forced_splits", "forcedsplits_filename"},
{"verbose", "verbosity"},
{"subsample_for_bin", "bin_construct_sample_cnt"},
{"model_output", "output_model"}, {"model_output", "output_model"},
{"model_out", "output_model"}, {"model_out", "output_model"},
{"model_input", "input_model"}, {"model_input", "input_model"},
{"model_in", "input_model"}, {"model_in", "input_model"},
{"predict_result", "output_result"}, {"predict_result", "output_result"},
{"prediction_result", "output_result"}, {"prediction_result", "output_result"},
{"init_score_filename", "initscore_filename"},
{"init_score_file", "initscore_filename"},
{"init_score", "initscore_filename"},
{"input_init_score", "initscore_filename"},
{"valid_data_init_scores", "valid_data_initscores"},
{"valid_init_score_file", "valid_data_initscores"},
{"valid_init_score", "valid_data_initscores"},
{"is_pre_partition", "pre_partition"}, {"is_pre_partition", "pre_partition"},
{"is_enable_bundle", "enable_bundle"},
{"bundle", "enable_bundle"},
{"is_sparse", "is_enable_sparse"}, {"is_sparse", "is_enable_sparse"},
{"enable_sparse", "is_enable_sparse"}, {"enable_sparse", "is_enable_sparse"},
{"sparse", "is_enable_sparse"},
{"two_round_loading", "two_round"}, {"two_round_loading", "two_round"},
{"use_two_round_loading", "two_round"}, {"use_two_round_loading", "two_round"},
{"is_save_binary", "save_binary"}, {"is_save_binary", "save_binary"},
{"is_save_binary_file", "save_binary"}, {"is_save_binary_file", "save_binary"},
{"verbose", "verbosity"}, {"load_from_binary_file", "enable_load_from_binary_file"},
{"binary_load", "enable_load_from_binary_file"},
{"load_binary", "enable_load_from_binary_file"},
{"has_header", "header"}, {"has_header", "header"},
{"label", "label_column"}, {"label", "label_column"},
{"weight", "weight_column"}, {"weight", "weight_column"},
{"query_column", "group_column"},
{"group", "group_column"}, {"group", "group_column"},
{"group_id", "group_column"},
{"query_column", "group_column"},
{"query", "group_column"}, {"query", "group_column"},
{"query_id", "group_column"},
{"ignore_feature", "ignore_column"}, {"ignore_feature", "ignore_column"},
{"blacklist", "ignore_column"}, {"blacklist", "ignore_column"},
{"categorical_column", "categorical_feature"},
{"cat_feature", "categorical_feature"}, {"cat_feature", "categorical_feature"},
{"categorical_column", "categorical_feature"},
{"cat_column", "categorical_feature"}, {"cat_column", "categorical_feature"},
{"raw_score", "predict_raw_score"},
{"is_predict_raw_score", "predict_raw_score"}, {"is_predict_raw_score", "predict_raw_score"},
{"predict_rawscore", "predict_raw_score"}, {"predict_rawscore", "predict_raw_score"},
{"leaf_index", "predict_leaf_index"}, {"raw_score", "predict_raw_score"},
{"is_predict_leaf_index", "predict_leaf_index"}, {"is_predict_leaf_index", "predict_leaf_index"},
{"contrib", "predict_contrib"}, {"leaf_index", "predict_leaf_index"},
{"is_predict_contrib", "predict_contrib"}, {"is_predict_contrib", "predict_contrib"},
{"subsample_for_bin", "bin_construct_sample_cnt"}, {"contrib", "predict_contrib"},
{"init_score_filename", "initscore_filename"}, {"convert_model_file", "convert_model"},
{"init_score_file", "initscore_filename"},
{"init_score", "initscore_filename"},
{"valid_data_init_scores", "valid_data_initscores"},
{"valid_init_score_file", "valid_data_initscores"},
{"valid_init_score", "valid_data_initscores"},
{"num_classes", "num_class"}, {"num_classes", "num_class"},
{"unbalanced_sets", "is_unbalance"}, {"unbalanced_sets", "is_unbalance"},
{"metrics", "metric"},
{"metric_types", "metric"}, {"metric_types", "metric"},
{"output_freq", "metric_freq"}, {"output_freq", "metric_freq"},
{"training_metric", "is_provide_training_metric"}, {"training_metric", "is_provide_training_metric"},
...@@ -109,8 +122,11 @@ std::unordered_map<std::string, std::string> Config::alias_table({ ...@@ -109,8 +122,11 @@ std::unordered_map<std::string, std::string> Config::alias_table({
{"ndcg_at", "eval_at"}, {"ndcg_at", "eval_at"},
{"num_machine", "num_machines"}, {"num_machine", "num_machines"},
{"local_port", "local_listen_port"}, {"local_port", "local_listen_port"},
{"port", "local_listen_port"},
{"machine_list_file", "machine_list_filename"},
{"machine_list", "machine_list_filename"},
{"mlist", "machine_list_filename"}, {"mlist", "machine_list_filename"},
{"works", "machines"}, {"workers", "machines"},
{"nodes", "machines"}, {"nodes", "machines"},
}); });
...@@ -157,18 +173,28 @@ std::unordered_set<std::string> Config::parameter_set({ ...@@ -157,18 +173,28 @@ std::unordered_set<std::string> Config::parameter_set({
"top_k", "top_k",
"monotone_constraints", "monotone_constraints",
"forcedsplits_filename", "forcedsplits_filename",
"verbosity",
"max_bin", "max_bin",
"min_data_in_bin", "min_data_in_bin",
"bin_construct_sample_cnt",
"histogram_pool_size",
"data_random_seed", "data_random_seed",
"output_model", "output_model",
"snapshot_freq",
"input_model", "input_model",
"output_result", "output_result",
"initscore_filename",
"valid_data_initscores",
"pre_partition", "pre_partition",
"enable_bundle",
"max_conflict_rate",
"is_enable_sparse", "is_enable_sparse",
"sparse_threshold", "sparse_threshold",
"use_missing",
"zero_as_missing",
"two_round", "two_round",
"save_binary", "save_binary",
"verbosity", "enable_load_from_binary_file",
"header", "header",
"label_column", "label_column",
"weight_column", "weight_column",
...@@ -182,30 +208,20 @@ std::unordered_set<std::string> Config::parameter_set({ ...@@ -182,30 +208,20 @@ std::unordered_set<std::string> Config::parameter_set({
"pred_early_stop", "pred_early_stop",
"pred_early_stop_freq", "pred_early_stop_freq",
"pred_early_stop_margin", "pred_early_stop_margin",
"bin_construct_sample_cnt",
"use_missing",
"zero_as_missing",
"initscore_filename",
"valid_data_initscores",
"histogram_pool_size",
"enable_load_from_binary_file",
"enable_bundle",
"max_conflict_rate",
"snapshot_freq",
"convert_model_language", "convert_model_language",
"convert_model", "convert_model",
"num_class", "num_class",
"is_unbalance",
"scale_pos_weight",
"sigmoid", "sigmoid",
"boost_from_average",
"reg_sqrt",
"alpha", "alpha",
"fair_c", "fair_c",
"poisson_max_delta_step", "poisson_max_delta_step",
"boost_from_average",
"is_unbalance",
"scale_pos_weight",
"reg_sqrt",
"tweedie_variance_power", "tweedie_variance_power",
"label_gain",
"max_position", "max_position",
"label_gain",
"metric", "metric",
"metric_freq", "metric_freq",
"is_provide_training_metric", "is_provide_training_metric",
...@@ -232,7 +248,7 @@ void Config::GetMembersFromString(const std::unordered_map<std::string, std::str ...@@ -232,7 +248,7 @@ void Config::GetMembersFromString(const std::unordered_map<std::string, std::str
CHECK(num_iterations >=0); CHECK(num_iterations >=0);
GetDouble(params, "learning_rate", &learning_rate); GetDouble(params, "learning_rate", &learning_rate);
CHECK(learning_rate >0); CHECK(learning_rate >0.0);
GetInt(params, "num_leaves", &num_leaves); GetInt(params, "num_leaves", &num_leaves);
CHECK(num_leaves >1); CHECK(num_leaves >1);
...@@ -245,9 +261,10 @@ void Config::GetMembersFromString(const std::unordered_map<std::string, std::str ...@@ -245,9 +261,10 @@ void Config::GetMembersFromString(const std::unordered_map<std::string, std::str
CHECK(min_data_in_leaf >=0); CHECK(min_data_in_leaf >=0);
GetDouble(params, "min_sum_hessian_in_leaf", &min_sum_hessian_in_leaf); GetDouble(params, "min_sum_hessian_in_leaf", &min_sum_hessian_in_leaf);
CHECK(min_sum_hessian_in_leaf >=0.0);
GetDouble(params, "bagging_fraction", &bagging_fraction); GetDouble(params, "bagging_fraction", &bagging_fraction);
CHECK(bagging_fraction >0); CHECK(bagging_fraction >0.0);
CHECK(bagging_fraction <=1.0); CHECK(bagging_fraction <=1.0);
GetInt(params, "bagging_freq", &bagging_freq); GetInt(params, "bagging_freq", &bagging_freq);
...@@ -255,7 +272,7 @@ void Config::GetMembersFromString(const std::unordered_map<std::string, std::str ...@@ -255,7 +272,7 @@ void Config::GetMembersFromString(const std::unordered_map<std::string, std::str
GetInt(params, "bagging_seed", &bagging_seed); GetInt(params, "bagging_seed", &bagging_seed);
GetDouble(params, "feature_fraction", &feature_fraction); GetDouble(params, "feature_fraction", &feature_fraction);
CHECK(feature_fraction >0); CHECK(feature_fraction >0.0);
CHECK(feature_fraction <=1.0); CHECK(feature_fraction <=1.0);
GetInt(params, "feature_fraction_seed", &feature_fraction_seed); GetInt(params, "feature_fraction_seed", &feature_fraction_seed);
...@@ -265,21 +282,22 @@ void Config::GetMembersFromString(const std::unordered_map<std::string, std::str ...@@ -265,21 +282,22 @@ void Config::GetMembersFromString(const std::unordered_map<std::string, std::str
GetDouble(params, "max_delta_step", &max_delta_step); GetDouble(params, "max_delta_step", &max_delta_step);
GetDouble(params, "lambda_l1", &lambda_l1); GetDouble(params, "lambda_l1", &lambda_l1);
CHECK(lambda_l1 >=0); CHECK(lambda_l1 >=0.0);
GetDouble(params, "lambda_l2", &lambda_l2); GetDouble(params, "lambda_l2", &lambda_l2);
CHECK(lambda_l2 >=0); CHECK(lambda_l2 >=0.0);
GetDouble(params, "min_gain_to_split", &min_gain_to_split); GetDouble(params, "min_gain_to_split", &min_gain_to_split);
CHECK(min_gain_to_split >=0.0);
GetDouble(params, "drop_rate", &drop_rate); GetDouble(params, "drop_rate", &drop_rate);
CHECK(drop_rate >=0); CHECK(drop_rate >=0.0);
CHECK(drop_rate <=1.0); CHECK(drop_rate <=1.0);
GetInt(params, "max_drop", &max_drop); GetInt(params, "max_drop", &max_drop);
GetDouble(params, "skip_drop", &skip_drop); GetDouble(params, "skip_drop", &skip_drop);
CHECK(skip_drop >=0); CHECK(skip_drop >=0.0);
CHECK(skip_drop <=1.0); CHECK(skip_drop <=1.0);
GetBool(params, "xgboost_dart_mode", &xgboost_dart_mode); GetBool(params, "xgboost_dart_mode", &xgboost_dart_mode);
...@@ -289,11 +307,11 @@ void Config::GetMembersFromString(const std::unordered_map<std::string, std::str ...@@ -289,11 +307,11 @@ void Config::GetMembersFromString(const std::unordered_map<std::string, std::str
GetInt(params, "drop_seed", &drop_seed); GetInt(params, "drop_seed", &drop_seed);
GetDouble(params, "top_rate", &top_rate); GetDouble(params, "top_rate", &top_rate);
CHECK(top_rate >=0); CHECK(top_rate >=0.0);
CHECK(top_rate <=1.0); CHECK(top_rate <=1.0);
GetDouble(params, "other_rate", &other_rate); GetDouble(params, "other_rate", &other_rate);
CHECK(other_rate >=0); CHECK(other_rate >=0.0);
CHECK(other_rate <=1.0); CHECK(other_rate <=1.0);
GetInt(params, "min_data_per_group", &min_data_per_group); GetInt(params, "min_data_per_group", &min_data_per_group);
...@@ -303,15 +321,16 @@ void Config::GetMembersFromString(const std::unordered_map<std::string, std::str ...@@ -303,15 +321,16 @@ void Config::GetMembersFromString(const std::unordered_map<std::string, std::str
CHECK(max_cat_threshold >0); CHECK(max_cat_threshold >0);
GetDouble(params, "cat_l2", &cat_l2); GetDouble(params, "cat_l2", &cat_l2);
CHECK(cat_l2 >=0); CHECK(cat_l2 >=0.0);
GetDouble(params, "cat_smooth", &cat_smooth); GetDouble(params, "cat_smooth", &cat_smooth);
CHECK(cat_smooth >=0); CHECK(cat_smooth >=0.0);
GetInt(params, "max_cat_to_onehot", &max_cat_to_onehot); GetInt(params, "max_cat_to_onehot", &max_cat_to_onehot);
CHECK(max_cat_to_onehot >0); CHECK(max_cat_to_onehot >0);
GetInt(params, "top_k", &top_k); GetInt(params, "top_k", &top_k);
CHECK(top_k >0);
if (GetString(params, "monotone_constraints", &tmp_str)) { if (GetString(params, "monotone_constraints", &tmp_str)) {
monotone_constraints = Common::StringToArray<int8_t>(tmp_str, ','); monotone_constraints = Common::StringToArray<int8_t>(tmp_str, ',');
...@@ -319,33 +338,58 @@ void Config::GetMembersFromString(const std::unordered_map<std::string, std::str ...@@ -319,33 +338,58 @@ void Config::GetMembersFromString(const std::unordered_map<std::string, std::str
GetString(params, "forcedsplits_filename", &forcedsplits_filename); GetString(params, "forcedsplits_filename", &forcedsplits_filename);
GetInt(params, "verbosity", &verbosity);
GetInt(params, "max_bin", &max_bin); GetInt(params, "max_bin", &max_bin);
CHECK(max_bin >1); CHECK(max_bin >1);
GetInt(params, "min_data_in_bin", &min_data_in_bin); GetInt(params, "min_data_in_bin", &min_data_in_bin);
CHECK(min_data_in_bin >0); CHECK(min_data_in_bin >0);
GetInt(params, "bin_construct_sample_cnt", &bin_construct_sample_cnt);
CHECK(bin_construct_sample_cnt >0);
GetDouble(params, "histogram_pool_size", &histogram_pool_size);
GetInt(params, "data_random_seed", &data_random_seed); GetInt(params, "data_random_seed", &data_random_seed);
GetString(params, "output_model", &output_model); GetString(params, "output_model", &output_model);
GetInt(params, "snapshot_freq", &snapshot_freq);
GetString(params, "input_model", &input_model); GetString(params, "input_model", &input_model);
GetString(params, "output_result", &output_result); GetString(params, "output_result", &output_result);
GetString(params, "initscore_filename", &initscore_filename);
if (GetString(params, "valid_data_initscores", &tmp_str)) {
valid_data_initscores = Common::Split(tmp_str.c_str(), ',');
}
GetBool(params, "pre_partition", &pre_partition); GetBool(params, "pre_partition", &pre_partition);
GetBool(params, "enable_bundle", &enable_bundle);
GetDouble(params, "max_conflict_rate", &max_conflict_rate);
CHECK(max_conflict_rate >=0.0);
CHECK(max_conflict_rate <1.0);
GetBool(params, "is_enable_sparse", &is_enable_sparse); GetBool(params, "is_enable_sparse", &is_enable_sparse);
GetDouble(params, "sparse_threshold", &sparse_threshold); GetDouble(params, "sparse_threshold", &sparse_threshold);
CHECK(sparse_threshold >0); CHECK(sparse_threshold >0.0);
CHECK(sparse_threshold <=1); CHECK(sparse_threshold <=1.0);
GetBool(params, "use_missing", &use_missing);
GetBool(params, "zero_as_missing", &zero_as_missing);
GetBool(params, "two_round", &two_round); GetBool(params, "two_round", &two_round);
GetBool(params, "save_binary", &save_binary); GetBool(params, "save_binary", &save_binary);
GetInt(params, "verbosity", &verbosity); GetBool(params, "enable_load_from_binary_file", &enable_load_from_binary_file);
GetBool(params, "header", &header); GetBool(params, "header", &header);
...@@ -373,64 +417,46 @@ void Config::GetMembersFromString(const std::unordered_map<std::string, std::str ...@@ -373,64 +417,46 @@ void Config::GetMembersFromString(const std::unordered_map<std::string, std::str
GetDouble(params, "pred_early_stop_margin", &pred_early_stop_margin); GetDouble(params, "pred_early_stop_margin", &pred_early_stop_margin);
GetInt(params, "bin_construct_sample_cnt", &bin_construct_sample_cnt);
CHECK(bin_construct_sample_cnt >0);
GetBool(params, "use_missing", &use_missing);
GetBool(params, "zero_as_missing", &zero_as_missing);
GetString(params, "initscore_filename", &initscore_filename);
if (GetString(params, "valid_data_initscores", &tmp_str)) {
valid_data_initscores = Common::Split(tmp_str.c_str(), ',');
}
GetDouble(params, "histogram_pool_size", &histogram_pool_size);
GetBool(params, "enable_load_from_binary_file", &enable_load_from_binary_file);
GetBool(params, "enable_bundle", &enable_bundle);
GetDouble(params, "max_conflict_rate", &max_conflict_rate);
CHECK(max_conflict_rate >=0);
CHECK(max_conflict_rate <1);
GetInt(params, "snapshot_freq", &snapshot_freq);
GetString(params, "convert_model_language", &convert_model_language); GetString(params, "convert_model_language", &convert_model_language);
GetString(params, "convert_model", &convert_model); GetString(params, "convert_model", &convert_model);
GetInt(params, "num_class", &num_class); GetInt(params, "num_class", &num_class);
CHECK(num_class >0);
GetDouble(params, "sigmoid", &sigmoid); GetBool(params, "is_unbalance", &is_unbalance);
CHECK(sigmoid >0);
GetDouble(params, "alpha", &alpha);
GetDouble(params, "fair_c", &fair_c); GetDouble(params, "scale_pos_weight", &scale_pos_weight);
CHECK(scale_pos_weight >0.0);
GetDouble(params, "poisson_max_delta_step", &poisson_max_delta_step); GetDouble(params, "sigmoid", &sigmoid);
CHECK(sigmoid >0.0);
GetBool(params, "boost_from_average", &boost_from_average); GetBool(params, "boost_from_average", &boost_from_average);
GetBool(params, "is_unbalance", &is_unbalance); GetBool(params, "reg_sqrt", &reg_sqrt);
GetDouble(params, "scale_pos_weight", &scale_pos_weight); GetDouble(params, "alpha", &alpha);
CHECK(scale_pos_weight >0); CHECK(alpha >0.0);
CHECK(alpha <1.0);
GetBool(params, "reg_sqrt", &reg_sqrt); GetDouble(params, "fair_c", &fair_c);
CHECK(fair_c >0.0);
GetDouble(params, "poisson_max_delta_step", &poisson_max_delta_step);
CHECK(poisson_max_delta_step >0.0);
GetDouble(params, "tweedie_variance_power", &tweedie_variance_power); GetDouble(params, "tweedie_variance_power", &tweedie_variance_power);
CHECK(tweedie_variance_power >=1.0);
CHECK(tweedie_variance_power <2.0);
GetInt(params, "max_position", &max_position);
CHECK(max_position >0);
if (GetString(params, "label_gain", &tmp_str)) { if (GetString(params, "label_gain", &tmp_str)) {
label_gain = Common::StringToArray<double>(tmp_str, ','); label_gain = Common::StringToArray<double>(tmp_str, ',');
} }
GetInt(params, "max_position", &max_position);
CHECK(max_position >0);
GetInt(params, "metric_freq", &metric_freq); GetInt(params, "metric_freq", &metric_freq);
CHECK(metric_freq >0); CHECK(metric_freq >0);
...@@ -441,10 +467,13 @@ void Config::GetMembersFromString(const std::unordered_map<std::string, std::str ...@@ -441,10 +467,13 @@ void Config::GetMembersFromString(const std::unordered_map<std::string, std::str
} }
GetInt(params, "num_machines", &num_machines); GetInt(params, "num_machines", &num_machines);
CHECK(num_machines >0);
GetInt(params, "local_listen_port", &local_listen_port); GetInt(params, "local_listen_port", &local_listen_port);
CHECK(local_listen_port >0);
GetInt(params, "time_out", &time_out); GetInt(params, "time_out", &time_out);
CHECK(time_out >0);
GetString(params, "machine_list_filename", &machine_list_filename); GetString(params, "machine_list_filename", &machine_list_filename);
...@@ -495,18 +524,28 @@ std::string Config::SaveMembersToString() const { ...@@ -495,18 +524,28 @@ std::string Config::SaveMembersToString() const {
str_buf << "[top_k: " << top_k << "]\n"; str_buf << "[top_k: " << top_k << "]\n";
str_buf << "[monotone_constraints: " << Common::Join(Common::ArrayCast<int8_t, int>(monotone_constraints),",") << "]\n"; str_buf << "[monotone_constraints: " << Common::Join(Common::ArrayCast<int8_t, int>(monotone_constraints),",") << "]\n";
str_buf << "[forcedsplits_filename: " << forcedsplits_filename << "]\n"; str_buf << "[forcedsplits_filename: " << forcedsplits_filename << "]\n";
str_buf << "[verbosity: " << verbosity << "]\n";
str_buf << "[max_bin: " << max_bin << "]\n"; str_buf << "[max_bin: " << max_bin << "]\n";
str_buf << "[min_data_in_bin: " << min_data_in_bin << "]\n"; str_buf << "[min_data_in_bin: " << min_data_in_bin << "]\n";
str_buf << "[bin_construct_sample_cnt: " << bin_construct_sample_cnt << "]\n";
str_buf << "[histogram_pool_size: " << histogram_pool_size << "]\n";
str_buf << "[data_random_seed: " << data_random_seed << "]\n"; str_buf << "[data_random_seed: " << data_random_seed << "]\n";
str_buf << "[output_model: " << output_model << "]\n"; str_buf << "[output_model: " << output_model << "]\n";
str_buf << "[snapshot_freq: " << snapshot_freq << "]\n";
str_buf << "[input_model: " << input_model << "]\n"; str_buf << "[input_model: " << input_model << "]\n";
str_buf << "[output_result: " << output_result << "]\n"; str_buf << "[output_result: " << output_result << "]\n";
str_buf << "[initscore_filename: " << initscore_filename << "]\n";
str_buf << "[valid_data_initscores: " << Common::Join(valid_data_initscores,",") << "]\n";
str_buf << "[pre_partition: " << pre_partition << "]\n"; str_buf << "[pre_partition: " << pre_partition << "]\n";
str_buf << "[enable_bundle: " << enable_bundle << "]\n";
str_buf << "[max_conflict_rate: " << max_conflict_rate << "]\n";
str_buf << "[is_enable_sparse: " << is_enable_sparse << "]\n"; str_buf << "[is_enable_sparse: " << is_enable_sparse << "]\n";
str_buf << "[sparse_threshold: " << sparse_threshold << "]\n"; str_buf << "[sparse_threshold: " << sparse_threshold << "]\n";
str_buf << "[use_missing: " << use_missing << "]\n";
str_buf << "[zero_as_missing: " << zero_as_missing << "]\n";
str_buf << "[two_round: " << two_round << "]\n"; str_buf << "[two_round: " << two_round << "]\n";
str_buf << "[save_binary: " << save_binary << "]\n"; str_buf << "[save_binary: " << save_binary << "]\n";
str_buf << "[verbosity: " << verbosity << "]\n"; str_buf << "[enable_load_from_binary_file: " << enable_load_from_binary_file << "]\n";
str_buf << "[header: " << header << "]\n"; str_buf << "[header: " << header << "]\n";
str_buf << "[label_column: " << label_column << "]\n"; str_buf << "[label_column: " << label_column << "]\n";
str_buf << "[weight_column: " << weight_column << "]\n"; str_buf << "[weight_column: " << weight_column << "]\n";
...@@ -520,30 +559,20 @@ std::string Config::SaveMembersToString() const { ...@@ -520,30 +559,20 @@ std::string Config::SaveMembersToString() const {
str_buf << "[pred_early_stop: " << pred_early_stop << "]\n"; str_buf << "[pred_early_stop: " << pred_early_stop << "]\n";
str_buf << "[pred_early_stop_freq: " << pred_early_stop_freq << "]\n"; str_buf << "[pred_early_stop_freq: " << pred_early_stop_freq << "]\n";
str_buf << "[pred_early_stop_margin: " << pred_early_stop_margin << "]\n"; str_buf << "[pred_early_stop_margin: " << pred_early_stop_margin << "]\n";
str_buf << "[bin_construct_sample_cnt: " << bin_construct_sample_cnt << "]\n";
str_buf << "[use_missing: " << use_missing << "]\n";
str_buf << "[zero_as_missing: " << zero_as_missing << "]\n";
str_buf << "[initscore_filename: " << initscore_filename << "]\n";
str_buf << "[valid_data_initscores: " << Common::Join(valid_data_initscores,",") << "]\n";
str_buf << "[histogram_pool_size: " << histogram_pool_size << "]\n";
str_buf << "[enable_load_from_binary_file: " << enable_load_from_binary_file << "]\n";
str_buf << "[enable_bundle: " << enable_bundle << "]\n";
str_buf << "[max_conflict_rate: " << max_conflict_rate << "]\n";
str_buf << "[snapshot_freq: " << snapshot_freq << "]\n";
str_buf << "[convert_model_language: " << convert_model_language << "]\n"; str_buf << "[convert_model_language: " << convert_model_language << "]\n";
str_buf << "[convert_model: " << convert_model << "]\n"; str_buf << "[convert_model: " << convert_model << "]\n";
str_buf << "[num_class: " << num_class << "]\n"; str_buf << "[num_class: " << num_class << "]\n";
str_buf << "[is_unbalance: " << is_unbalance << "]\n";
str_buf << "[scale_pos_weight: " << scale_pos_weight << "]\n";
str_buf << "[sigmoid: " << sigmoid << "]\n"; str_buf << "[sigmoid: " << sigmoid << "]\n";
str_buf << "[boost_from_average: " << boost_from_average << "]\n";
str_buf << "[reg_sqrt: " << reg_sqrt << "]\n";
str_buf << "[alpha: " << alpha << "]\n"; str_buf << "[alpha: " << alpha << "]\n";
str_buf << "[fair_c: " << fair_c << "]\n"; str_buf << "[fair_c: " << fair_c << "]\n";
str_buf << "[poisson_max_delta_step: " << poisson_max_delta_step << "]\n"; str_buf << "[poisson_max_delta_step: " << poisson_max_delta_step << "]\n";
str_buf << "[boost_from_average: " << boost_from_average << "]\n";
str_buf << "[is_unbalance: " << is_unbalance << "]\n";
str_buf << "[scale_pos_weight: " << scale_pos_weight << "]\n";
str_buf << "[reg_sqrt: " << reg_sqrt << "]\n";
str_buf << "[tweedie_variance_power: " << tweedie_variance_power << "]\n"; str_buf << "[tweedie_variance_power: " << tweedie_variance_power << "]\n";
str_buf << "[label_gain: " << Common::Join(label_gain,",") << "]\n";
str_buf << "[max_position: " << max_position << "]\n"; str_buf << "[max_position: " << max_position << "]\n";
str_buf << "[label_gain: " << Common::Join(label_gain,",") << "]\n";
str_buf << "[metric_freq: " << metric_freq << "]\n"; str_buf << "[metric_freq: " << metric_freq << "]\n";
str_buf << "[is_provide_training_metric: " << is_provide_training_metric << "]\n"; str_buf << "[is_provide_training_metric: " << is_provide_training_metric << "]\n";
str_buf << "[eval_at: " << Common::Join(eval_at,",") << "]\n"; str_buf << "[eval_at: " << Common::Join(eval_at,",") << "]\n";
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment