Unverified commit 7d700cd3, authored by Nikita Titov, committed by GitHub
Browse files

use specific checks in auto config (#2864)

parent b70636bc
......@@ -123,6 +123,34 @@ def get_alias(infos):
return pairs
def parse_check(check, reverse=False):
    """Split a constraint string into its value and its comparison sign.

    Parameters
    ----------
    check : string
        String representation of the constraint.
    reverse : bool, optional (default=False)
        Whether to reverse the sign of the constraint.

    Returns
    -------
    pair : tuple
        Parsed constraint in the form of tuple (value, sign).
    """
    # Assume a one-character sign first; if what follows it is not a number,
    # fall back to a two-character sign. The second float() call is left
    # unguarded on purpose so that a malformed constraint still raises.
    sign_len = 1
    try:
        float(check[sign_len:])
    except ValueError:
        sign_len = 2
        float(check[sign_len:])
    sign, value = check[:sign_len], check[sign_len:]
    if not reverse:
        return value, sign
    mirrored = {'<': '>', '>': '<', '<=': '>=', '>=': '<='}
    return value, mirrored[sign]
def set_one_var_from_string(name, param_type, checks):
"""Construct code for auto config file for one param value.
......@@ -145,8 +173,10 @@ def set_one_var_from_string(name, param_type, checks):
if "vector" not in param_type:
ret += " %s(params, \"%s\", &%s);\n" % (univar_mapper[param_type], name, name)
if len(checks) > 0:
check_mapper = {"<": "LT", ">": "GT", "<=": "LE", ">=": "GE"}
for check in checks:
ret += " CHECK(%s %s);\n" % (name, check)
value, sign = parse_check(check)
ret += " CHECK_%s(%s, %s);\n" % (check_mapper[sign], name, value)
ret += "\n"
else:
ret += " if (GetString(params, \"%s\", &tmp_str)) {\n" % (name)
......@@ -171,33 +201,6 @@ def gen_parameter_description(sections, descriptions, params_rst):
params_rst : string
Path to the file with parameters documentation.
"""
def parse_check(check, reverse=False):
    """Parse a constraint into a (value, sign) pair.

    Parameters
    ----------
    check : string
        String representation of the constraint.
    reverse : bool, optional (default=False)
        Whether to reverse the sign of the constraint.

    Returns
    -------
    pair : tuple
        Parsed constraint in the form of tuple (value, sign).
    """
    # Try a single-character sign; when the remainder does not parse as a
    # float, retry with a two-character sign (a second failure propagates).
    try:
        split_at = 1
        float(check[split_at:])
    except ValueError:
        split_at = 2
        float(check[split_at:])
    value = check[split_at:]
    sign = check[:split_at]
    if reverse:
        sign = {'<': '>', '>': '<', '<=': '>=', '>=': '<='}[sign]
    return value, sign
params_to_write = []
lvl_mapper = {1: '-', 2: '~'}
for (section_name, section_lvl), section_params in zip(sections, descriptions):
......
......@@ -298,14 +298,14 @@ void Config::GetMembersFromString(const std::unordered_map<std::string, std::str
}
GetInt(params, "num_iterations", &num_iterations);
CHECK(num_iterations >=0);
CHECK_GE(num_iterations, 0);
GetDouble(params, "learning_rate", &learning_rate);
CHECK(learning_rate >0.0);
CHECK_GT(learning_rate, 0.0);
GetInt(params, "num_leaves", &num_leaves);
CHECK(num_leaves >1);
CHECK(num_leaves <=131072);
CHECK_GT(num_leaves, 1);
CHECK_LE(num_leaves, 131072);
GetInt(params, "num_threads", &num_threads);
......@@ -318,34 +318,34 @@ void Config::GetMembersFromString(const std::unordered_map<std::string, std::str
GetInt(params, "max_depth", &max_depth);
GetInt(params, "min_data_in_leaf", &min_data_in_leaf);
CHECK(min_data_in_leaf >=0);
CHECK_GE(min_data_in_leaf, 0);
GetDouble(params, "min_sum_hessian_in_leaf", &min_sum_hessian_in_leaf);
CHECK(min_sum_hessian_in_leaf >=0.0);
CHECK_GE(min_sum_hessian_in_leaf, 0.0);
GetDouble(params, "bagging_fraction", &bagging_fraction);
CHECK(bagging_fraction >0.0);
CHECK(bagging_fraction <=1.0);
CHECK_GT(bagging_fraction, 0.0);
CHECK_LE(bagging_fraction, 1.0);
GetDouble(params, "pos_bagging_fraction", &pos_bagging_fraction);
CHECK(pos_bagging_fraction >0.0);
CHECK(pos_bagging_fraction <=1.0);
CHECK_GT(pos_bagging_fraction, 0.0);
CHECK_LE(pos_bagging_fraction, 1.0);
GetDouble(params, "neg_bagging_fraction", &neg_bagging_fraction);
CHECK(neg_bagging_fraction >0.0);
CHECK(neg_bagging_fraction <=1.0);
CHECK_GT(neg_bagging_fraction, 0.0);
CHECK_LE(neg_bagging_fraction, 1.0);
GetInt(params, "bagging_freq", &bagging_freq);
GetInt(params, "bagging_seed", &bagging_seed);
GetDouble(params, "feature_fraction", &feature_fraction);
CHECK(feature_fraction >0.0);
CHECK(feature_fraction <=1.0);
CHECK_GT(feature_fraction, 0.0);
CHECK_LE(feature_fraction, 1.0);
GetDouble(params, "feature_fraction_bynode", &feature_fraction_bynode);
CHECK(feature_fraction_bynode >0.0);
CHECK(feature_fraction_bynode <=1.0);
CHECK_GT(feature_fraction_bynode, 0.0);
CHECK_LE(feature_fraction_bynode, 1.0);
GetInt(params, "feature_fraction_seed", &feature_fraction_seed);
......@@ -360,23 +360,23 @@ void Config::GetMembersFromString(const std::unordered_map<std::string, std::str
GetDouble(params, "max_delta_step", &max_delta_step);
GetDouble(params, "lambda_l1", &lambda_l1);
CHECK(lambda_l1 >=0.0);
CHECK_GE(lambda_l1, 0.0);
GetDouble(params, "lambda_l2", &lambda_l2);
CHECK(lambda_l2 >=0.0);
CHECK_GE(lambda_l2, 0.0);
GetDouble(params, "min_gain_to_split", &min_gain_to_split);
CHECK(min_gain_to_split >=0.0);
CHECK_GE(min_gain_to_split, 0.0);
GetDouble(params, "drop_rate", &drop_rate);
CHECK(drop_rate >=0.0);
CHECK(drop_rate <=1.0);
CHECK_GE(drop_rate, 0.0);
CHECK_LE(drop_rate, 1.0);
GetInt(params, "max_drop", &max_drop);
GetDouble(params, "skip_drop", &skip_drop);
CHECK(skip_drop >=0.0);
CHECK(skip_drop <=1.0);
CHECK_GE(skip_drop, 0.0);
CHECK_LE(skip_drop, 1.0);
GetBool(params, "xgboost_dart_mode", &xgboost_dart_mode);
......@@ -385,30 +385,30 @@ void Config::GetMembersFromString(const std::unordered_map<std::string, std::str
GetInt(params, "drop_seed", &drop_seed);
GetDouble(params, "top_rate", &top_rate);
CHECK(top_rate >=0.0);
CHECK(top_rate <=1.0);
CHECK_GE(top_rate, 0.0);
CHECK_LE(top_rate, 1.0);
GetDouble(params, "other_rate", &other_rate);
CHECK(other_rate >=0.0);
CHECK(other_rate <=1.0);
CHECK_GE(other_rate, 0.0);
CHECK_LE(other_rate, 1.0);
GetInt(params, "min_data_per_group", &min_data_per_group);
CHECK(min_data_per_group >0);
CHECK_GT(min_data_per_group, 0);
GetInt(params, "max_cat_threshold", &max_cat_threshold);
CHECK(max_cat_threshold >0);
CHECK_GT(max_cat_threshold, 0);
GetDouble(params, "cat_l2", &cat_l2);
CHECK(cat_l2 >=0.0);
CHECK_GE(cat_l2, 0.0);
GetDouble(params, "cat_smooth", &cat_smooth);
CHECK(cat_smooth >=0.0);
CHECK_GE(cat_smooth, 0.0);
GetInt(params, "max_cat_to_onehot", &max_cat_to_onehot);
CHECK(max_cat_to_onehot >0);
CHECK_GT(max_cat_to_onehot, 0);
GetInt(params, "top_k", &top_k);
CHECK(top_k >0);
CHECK_GT(top_k, 0);
if (GetString(params, "monotone_constraints", &tmp_str)) {
monotone_constraints = Common::StringToArray<int8_t>(tmp_str, ',');
......@@ -421,14 +421,14 @@ void Config::GetMembersFromString(const std::unordered_map<std::string, std::str
GetString(params, "forcedsplits_filename", &forcedsplits_filename);
GetDouble(params, "refit_decay_rate", &refit_decay_rate);
CHECK(refit_decay_rate >=0.0);
CHECK(refit_decay_rate <=1.0);
CHECK_GE(refit_decay_rate, 0.0);
CHECK_LE(refit_decay_rate, 1.0);
GetDouble(params, "cegb_tradeoff", &cegb_tradeoff);
CHECK(cegb_tradeoff >=0.0);
CHECK_GE(cegb_tradeoff, 0.0);
GetDouble(params, "cegb_penalty_split", &cegb_penalty_split);
CHECK(cegb_penalty_split >=0.0);
CHECK_GE(cegb_penalty_split, 0.0);
if (GetString(params, "cegb_penalty_feature_lazy", &tmp_str)) {
cegb_penalty_feature_lazy = Common::StringToArray<double>(tmp_str, ',');
......@@ -447,17 +447,17 @@ void Config::GetMembersFromString(const std::unordered_map<std::string, std::str
GetInt(params, "snapshot_freq", &snapshot_freq);
GetInt(params, "max_bin", &max_bin);
CHECK(max_bin >1);
CHECK_GT(max_bin, 1);
if (GetString(params, "max_bin_by_feature", &tmp_str)) {
max_bin_by_feature = Common::StringToArray<int32_t>(tmp_str, ',');
}
GetInt(params, "min_data_in_bin", &min_data_in_bin);
CHECK(min_data_in_bin >0);
CHECK_GT(min_data_in_bin, 0);
GetInt(params, "bin_construct_sample_cnt", &bin_construct_sample_cnt);
CHECK(bin_construct_sample_cnt >0);
CHECK_GT(bin_construct_sample_cnt, 0);
GetInt(params, "data_random_seed", &data_random_seed);
......@@ -516,35 +516,35 @@ void Config::GetMembersFromString(const std::unordered_map<std::string, std::str
GetInt(params, "objective_seed", &objective_seed);
GetInt(params, "num_class", &num_class);
CHECK(num_class >0);
CHECK_GT(num_class, 0);
GetBool(params, "is_unbalance", &is_unbalance);
GetDouble(params, "scale_pos_weight", &scale_pos_weight);
CHECK(scale_pos_weight >0.0);
CHECK_GT(scale_pos_weight, 0.0);
GetDouble(params, "sigmoid", &sigmoid);
CHECK(sigmoid >0.0);
CHECK_GT(sigmoid, 0.0);
GetBool(params, "boost_from_average", &boost_from_average);
GetBool(params, "reg_sqrt", &reg_sqrt);
GetDouble(params, "alpha", &alpha);
CHECK(alpha >0.0);
CHECK_GT(alpha, 0.0);
GetDouble(params, "fair_c", &fair_c);
CHECK(fair_c >0.0);
CHECK_GT(fair_c, 0.0);
GetDouble(params, "poisson_max_delta_step", &poisson_max_delta_step);
CHECK(poisson_max_delta_step >0.0);
CHECK_GT(poisson_max_delta_step, 0.0);
GetDouble(params, "tweedie_variance_power", &tweedie_variance_power);
CHECK(tweedie_variance_power >=1.0);
CHECK(tweedie_variance_power <2.0);
CHECK_GE(tweedie_variance_power, 1.0);
CHECK_LT(tweedie_variance_power, 2.0);
GetInt(params, "lambdarank_truncation_level", &lambdarank_truncation_level);
CHECK(lambdarank_truncation_level >0);
CHECK_GT(lambdarank_truncation_level, 0);
GetBool(params, "lambdarank_norm", &lambdarank_norm);
......@@ -553,7 +553,7 @@ void Config::GetMembersFromString(const std::unordered_map<std::string, std::str
}
GetInt(params, "metric_freq", &metric_freq);
CHECK(metric_freq >0);
CHECK_GT(metric_freq, 0);
GetBool(params, "is_provide_training_metric", &is_provide_training_metric);
......@@ -562,20 +562,20 @@ void Config::GetMembersFromString(const std::unordered_map<std::string, std::str
}
GetInt(params, "multi_error_top_k", &multi_error_top_k);
CHECK(multi_error_top_k >0);
CHECK_GT(multi_error_top_k, 0);
if (GetString(params, "auc_mu_weights", &tmp_str)) {
auc_mu_weights = Common::StringToArray<double>(tmp_str, ',');
}
GetInt(params, "num_machines", &num_machines);
CHECK(num_machines >0);
CHECK_GT(num_machines, 0);
GetInt(params, "local_listen_port", &local_listen_port);
CHECK(local_listen_port >0);
CHECK_GT(local_listen_port, 0);
GetInt(params, "time_out", &time_out);
CHECK(time_out >0);
CHECK_GT(time_out, 0);
GetString(params, "machine_list_filename", &machine_list_filename);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment