Unverified Commit 83627ff0 authored by José Morales's avatar José Morales Committed by GitHub
Browse files

reproducible parameter alias resolution for wrappers (fixes #5304) (#5338)

* dump sorted parameter aliases

* update lgb.check.wrapper_param

* update _choose_param_value to look like lgb.check.wrapper_param

* apply suggestions from review

* reduce diff

* move DumpAliases to config

* remove unnecessary check

* restore parameter check
parent 212d1457
...@@ -177,7 +177,7 @@ lgb.check.eval <- function(params, eval) { ...@@ -177,7 +177,7 @@ lgb.check.eval <- function(params, eval) {
# ways, the first item in this list is used: # ways, the first item in this list is used:
# #
# 1. the main (non-alias) parameter found in `params` # 1. the main (non-alias) parameter found in `params`
# 2. the first alias of that parameter found in `params` # 2. the alias with the highest priority found in `params`
# 3. the keyword argument passed in # 3. the keyword argument passed in
# #
# For example, "num_iterations" can also be provided to lgb.train() # For example, "num_iterations" can also be provided to lgb.train()
...@@ -185,7 +185,7 @@ lgb.check.eval <- function(params, eval) { ...@@ -185,7 +185,7 @@ lgb.check.eval <- function(params, eval) {
# based on the first match in this list: # based on the first match in this list:
# #
# 1. params[["num_iterations]] # 1. params[["num_iterations]]
# 2. the first alias of "num_iterations" found in params # 2. the highest priority alias of "num_iterations" found in params
# 3. the nrounds keyword argument # 3. the nrounds keyword argument
# #
# If multiple aliases are found in `params` for the same parameter, they are # If multiple aliases are found in `params` for the same parameter, they are
...@@ -197,7 +197,7 @@ lgb.check.eval <- function(params, eval) { ...@@ -197,7 +197,7 @@ lgb.check.eval <- function(params, eval) {
lgb.check.wrapper_param <- function(main_param_name, params, alternative_kwarg_value) { lgb.check.wrapper_param <- function(main_param_name, params, alternative_kwarg_value) {
aliases <- .PARAMETER_ALIASES()[[main_param_name]] aliases <- .PARAMETER_ALIASES()[[main_param_name]]
aliases_provided <- names(params)[names(params) %in% aliases] aliases_provided <- aliases[aliases %in% names(params)]
aliases_provided <- aliases_provided[aliases_provided != main_param_name] aliases_provided <- aliases_provided[aliases_provided != main_param_name]
# prefer the main parameter # prefer the main parameter
......
...@@ -56,6 +56,7 @@ test_that(".PARAMETER_ALIASES() returns a named list of character vectors, where ...@@ -56,6 +56,7 @@ test_that(".PARAMETER_ALIASES() returns a named list of character vectors, where
expect_true(all(sapply(param_aliases, is.character))) expect_true(all(sapply(param_aliases, is.character)))
expect_true(length(unique(names(param_aliases))) == length(param_aliases)) expect_true(length(unique(names(param_aliases))) == length(param_aliases))
expect_equal(sort(param_aliases[["task"]]), c("task", "task_type")) expect_equal(sort(param_aliases[["task"]]), c("task", "task_type"))
expect_equal(param_aliases[["bagging_fraction"]], c("bagging_fraction", "bagging", "sub_row", "subsample"))
}) })
test_that(".PARAMETER_ALIASES() uses the internal session cache", { test_that(".PARAMETER_ALIASES() uses the internal session cache", {
......
...@@ -123,7 +123,7 @@ test_that("lgb.check.wrapper_param() prefers alias to keyword arg", { ...@@ -123,7 +123,7 @@ test_that("lgb.check.wrapper_param() prefers alias to keyword arg", {
expect_equal(params[["num_iterations"]], num_tree) expect_equal(params[["num_iterations"]], num_tree)
expect_identical(params, list(num_iterations = num_tree)) expect_identical(params, list(num_iterations = num_tree))
# switching the order should switch which one is chosen # switching the order shouldn't switch which one is chosen
params2 <- lgb.check.wrapper_param( params2 <- lgb.check.wrapper_param(
main_param_name = "num_iterations" main_param_name = "num_iterations"
, params = list( , params = list(
...@@ -132,6 +132,6 @@ test_that("lgb.check.wrapper_param() prefers alias to keyword arg", { ...@@ -132,6 +132,6 @@ test_that("lgb.check.wrapper_param() prefers alias to keyword arg", {
) )
, alternative_kwarg_value = kwarg_val , alternative_kwarg_value = kwarg_val
) )
expect_equal(params2[["num_iterations"]], n_estimators) expect_equal(params2[["num_iterations"]], num_tree)
expect_identical(params2, list(num_iterations = n_estimators)) expect_identical(params2, list(num_iterations = num_tree))
}) })
...@@ -359,20 +359,20 @@ def gen_parameter_code( ...@@ -359,20 +359,20 @@ def gen_parameter_code(
str_to_write += " return str_buf.str();\n" str_to_write += " return str_buf.str();\n"
str_to_write += "}\n\n" str_to_write += "}\n\n"
str_to_write += "const std::string Config::DumpAliases() {\n" str_to_write += """const std::unordered_map<std::string, std::vector<std::string>>& Config::parameter2aliases() {
str_to_write += " std::stringstream str_buf;\n" static std::unordered_map<std::string, std::vector<std::string>> map({"""
str_to_write += ' str_buf << "{";\n' for name in names:
for idx, name in enumerate(names): str_to_write += '\n {"' + name + '", '
if idx > 0: if names_with_aliases[name]:
str_to_write += ', ";\n' str_to_write += '{"' + '", "'.join(names_with_aliases[name]) + '"}},'
aliases = '\\", \\"'.join([alias for alias in names_with_aliases[name]]) else:
aliases = f'[\\"{aliases}\\"]' if aliases else '[]' str_to_write += '{}},'
str_to_write += f' str_buf << "\\"{name}\\": {aliases}' str_to_write += """
str_to_write += '";\n' });
str_to_write += ' str_buf << "}";\n' return map;
str_to_write += " return str_buf.str();\n" }
str_to_write += "}\n\n"
"""
str_to_write += "} // namespace LightGBM\n" str_to_write += "} // namespace LightGBM\n"
with open(config_out_cpp, "w") as config_out_cpp_file: with open(config_out_cpp, "w") as config_out_cpp_file:
config_out_cpp_file.write(str_to_write) config_out_cpp_file.write(str_to_write)
......
...@@ -78,6 +78,14 @@ struct Config { ...@@ -78,6 +78,14 @@ struct Config {
const std::unordered_map<std::string, std::string>& params, const std::unordered_map<std::string, std::string>& params,
const std::string& name, bool* out); const std::string& name, bool* out);
/*!
* \brief Sort aliases by length and then alphabetically
* \param x Alias 1
* \param y Alias 2
* \return true if x has higher priority than y
*/
inline static bool SortAlias(const std::string& x, const std::string& y);
static void KV2Map(std::unordered_map<std::string, std::string>* params, const char* kv); static void KV2Map(std::unordered_map<std::string, std::string>* params, const char* kv);
static std::unordered_map<std::string, std::string> Str2Map(const char* parameters); static std::unordered_map<std::string, std::string> Str2Map(const char* parameters);
...@@ -1063,6 +1071,7 @@ struct Config { ...@@ -1063,6 +1071,7 @@ struct Config {
bool is_data_based_parallel = false; bool is_data_based_parallel = false;
LIGHTGBM_EXPORT void Set(const std::unordered_map<std::string, std::string>& params); LIGHTGBM_EXPORT void Set(const std::unordered_map<std::string, std::string>& params);
static const std::unordered_map<std::string, std::string>& alias_table(); static const std::unordered_map<std::string, std::string>& alias_table();
static const std::unordered_map<std::string, std::vector<std::string>>& parameter2aliases();
static const std::unordered_set<std::string>& parameter_set(); static const std::unordered_set<std::string>& parameter_set();
std::vector<std::vector<double>> auc_mu_weights_matrix; std::vector<std::vector<double>> auc_mu_weights_matrix;
std::vector<std::vector<int>> interaction_constraints_vector; std::vector<std::vector<int>> interaction_constraints_vector;
...@@ -1131,6 +1140,10 @@ inline bool Config::GetBool( ...@@ -1131,6 +1140,10 @@ inline bool Config::GetBool(
return false; return false;
} }
inline bool Config::SortAlias(const std::string& x, const std::string& y) {
return x.size() < y.size() || (x.size() == y.size() && x < y);
}
struct ParameterAlias { struct ParameterAlias {
static void KeyAliasTransform(std::unordered_map<std::string, std::string>* params) { static void KeyAliasTransform(std::unordered_map<std::string, std::string>* params) {
std::unordered_map<std::string, std::string> tmp_map; std::unordered_map<std::string, std::string> tmp_map;
...@@ -1139,9 +1152,7 @@ struct ParameterAlias { ...@@ -1139,9 +1152,7 @@ struct ParameterAlias {
if (alias != Config::alias_table().end()) { // found alias if (alias != Config::alias_table().end()) { // found alias
auto alias_set = tmp_map.find(alias->second); auto alias_set = tmp_map.find(alias->second);
if (alias_set != tmp_map.end()) { // alias already set if (alias_set != tmp_map.end()) { // alias already set
// set priority by length & alphabetically to ensure reproducible behavior if (Config::SortAlias(alias_set->second, pair.first)) {
if (alias_set->second.size() < pair.first.size() ||
(alias_set->second.size() == pair.first.size() && alias_set->second < pair.first)) {
Log::Warning("%s is set with %s=%s, %s=%s will be ignored. Current value: %s=%s", Log::Warning("%s is set with %s=%s, %s=%s will be ignored. Current value: %s=%s",
alias->second.c_str(), alias_set->second.c_str(), params->at(alias_set->second).c_str(), alias->second.c_str(), alias_set->second.c_str(), params->at(alias_set->second).c_str(),
pair.first.c_str(), pair.second.c_str(), alias->second.c_str(), params->at(alias_set->second).c_str()); pair.first.c_str(), pair.second.c_str(), alias->second.c_str(), params->at(alias_set->second).c_str());
......
...@@ -345,7 +345,7 @@ class _ConfigAliases: ...@@ -345,7 +345,7 @@ class _ConfigAliases:
aliases = None aliases = None
@staticmethod @staticmethod
def _get_all_param_aliases() -> Dict[str, Set[str]]: def _get_all_param_aliases() -> Dict[str, List[str]]:
buffer_len = 1 << 20 buffer_len = 1 << 20
tmp_out_len = ctypes.c_int64(0) tmp_out_len = ctypes.c_int64(0)
string_buffer = ctypes.create_string_buffer(buffer_len) string_buffer = ctypes.create_string_buffer(buffer_len)
...@@ -365,7 +365,7 @@ class _ConfigAliases: ...@@ -365,7 +365,7 @@ class _ConfigAliases:
ptr_string_buffer)) ptr_string_buffer))
aliases = json.loads( aliases = json.loads(
string_buffer.value.decode('utf-8'), string_buffer.value.decode('utf-8'),
object_hook=lambda obj: {k: set(v) | {k} for k, v in obj.items()} object_hook=lambda obj: {k: [k] + v for k, v in obj.items()}
) )
return aliases return aliases
...@@ -375,9 +375,15 @@ class _ConfigAliases: ...@@ -375,9 +375,15 @@ class _ConfigAliases:
cls.aliases = cls._get_all_param_aliases() cls.aliases = cls._get_all_param_aliases()
ret = set() ret = set()
for i in args: for i in args:
ret |= cls.aliases.get(i, {i}) ret.update(cls.get_sorted(i))
return ret return ret
@classmethod
def get_sorted(cls, name: str) -> List[str]:
if cls.aliases is None:
cls.aliases = cls._get_all_param_aliases()
return cls.aliases.get(name, [name])
@classmethod @classmethod
def get_by_alias(cls, *args) -> Set[str]: def get_by_alias(cls, *args) -> Set[str]:
if cls.aliases is None: if cls.aliases is None:
...@@ -386,7 +392,7 @@ class _ConfigAliases: ...@@ -386,7 +392,7 @@ class _ConfigAliases:
for arg in args: for arg in args:
for aliases in cls.aliases.values(): for aliases in cls.aliases.values():
if arg in aliases: if arg in aliases:
ret |= aliases ret.update(aliases)
break break
return ret return ret
...@@ -412,7 +418,8 @@ def _choose_param_value(main_param_name: str, params: Dict[str, Any], default_va ...@@ -412,7 +418,8 @@ def _choose_param_value(main_param_name: str, params: Dict[str, Any], default_va
# avoid side effects on passed-in parameters # avoid side effects on passed-in parameters
params = deepcopy(params) params = deepcopy(params)
aliases = _ConfigAliases.get(main_param_name) - {main_param_name} aliases = _ConfigAliases.get_sorted(main_param_name)
aliases = [a for a in aliases if a != main_param_name]
# if main_param_name was provided, keep that value and remove all aliases # if main_param_name was provided, keep that value and remove all aliases
if main_param_name in params.keys(): if main_param_name in params.keys():
......
...@@ -411,4 +411,29 @@ std::string Config::ToString() const { ...@@ -411,4 +411,29 @@ std::string Config::ToString() const {
return str_buf.str(); return str_buf.str();
} }
const std::string Config::DumpAliases() {
auto map = Config::parameter2aliases();
for (auto& pair : map) {
std::sort(pair.second.begin(), pair.second.end(), SortAlias);
}
std::stringstream str_buf;
str_buf << "{\n";
bool first = true;
for (const auto& pair : map) {
if (first) {
str_buf << " \"";
first = false;
} else {
str_buf << " , \"";
}
str_buf << pair.first << "\": [";
if (pair.second.size() > 0) {
str_buf << "\"" << CommonC::Join(pair.second, "\", \"") << "\"";
}
str_buf << "]\n";
}
str_buf << "}\n";
return str_buf.str();
}
} // namespace LightGBM } // namespace LightGBM
...@@ -756,143 +756,142 @@ std::string Config::SaveMembersToString() const { ...@@ -756,143 +756,142 @@ std::string Config::SaveMembersToString() const {
return str_buf.str(); return str_buf.str();
} }
const std::string Config::DumpAliases() { const std::unordered_map<std::string, std::vector<std::string>>& Config::parameter2aliases() {
std::stringstream str_buf; static std::unordered_map<std::string, std::vector<std::string>> map({
str_buf << "{"; {"config", {"config_file"}},
str_buf << "\"config\": [\"config_file\"], "; {"task", {"task_type"}},
str_buf << "\"task\": [\"task_type\"], "; {"objective", {"objective_type", "app", "application", "loss"}},
str_buf << "\"objective\": [\"objective_type\", \"app\", \"application\", \"loss\"], "; {"boosting", {"boosting_type", "boost"}},
str_buf << "\"boosting\": [\"boosting_type\", \"boost\"], "; {"data", {"train", "train_data", "train_data_file", "data_filename"}},
str_buf << "\"data\": [\"train\", \"train_data\", \"train_data_file\", \"data_filename\"], "; {"valid", {"test", "valid_data", "valid_data_file", "test_data", "test_data_file", "valid_filenames"}},
str_buf << "\"valid\": [\"test\", \"valid_data\", \"valid_data_file\", \"test_data\", \"test_data_file\", \"valid_filenames\"], "; {"num_iterations", {"num_iteration", "n_iter", "num_tree", "num_trees", "num_round", "num_rounds", "nrounds", "num_boost_round", "n_estimators", "max_iter"}},
str_buf << "\"num_iterations\": [\"num_iteration\", \"n_iter\", \"num_tree\", \"num_trees\", \"num_round\", \"num_rounds\", \"nrounds\", \"num_boost_round\", \"n_estimators\", \"max_iter\"], "; {"learning_rate", {"shrinkage_rate", "eta"}},
str_buf << "\"learning_rate\": [\"shrinkage_rate\", \"eta\"], "; {"num_leaves", {"num_leaf", "max_leaves", "max_leaf", "max_leaf_nodes"}},
str_buf << "\"num_leaves\": [\"num_leaf\", \"max_leaves\", \"max_leaf\", \"max_leaf_nodes\"], "; {"tree_learner", {"tree", "tree_type", "tree_learner_type"}},
str_buf << "\"tree_learner\": [\"tree\", \"tree_type\", \"tree_learner_type\"], "; {"num_threads", {"num_thread", "nthread", "nthreads", "n_jobs"}},
str_buf << "\"num_threads\": [\"num_thread\", \"nthread\", \"nthreads\", \"n_jobs\"], "; {"device_type", {"device"}},
str_buf << "\"device_type\": [\"device\"], "; {"seed", {"random_seed", "random_state"}},
str_buf << "\"seed\": [\"random_seed\", \"random_state\"], "; {"deterministic", {}},
str_buf << "\"deterministic\": [], "; {"force_col_wise", {}},
str_buf << "\"force_col_wise\": [], "; {"force_row_wise", {}},
str_buf << "\"force_row_wise\": [], "; {"histogram_pool_size", {"hist_pool_size"}},
str_buf << "\"histogram_pool_size\": [\"hist_pool_size\"], "; {"max_depth", {}},
str_buf << "\"max_depth\": [], "; {"min_data_in_leaf", {"min_data_per_leaf", "min_data", "min_child_samples", "min_samples_leaf"}},
str_buf << "\"min_data_in_leaf\": [\"min_data_per_leaf\", \"min_data\", \"min_child_samples\", \"min_samples_leaf\"], "; {"min_sum_hessian_in_leaf", {"min_sum_hessian_per_leaf", "min_sum_hessian", "min_hessian", "min_child_weight"}},
str_buf << "\"min_sum_hessian_in_leaf\": [\"min_sum_hessian_per_leaf\", \"min_sum_hessian\", \"min_hessian\", \"min_child_weight\"], "; {"bagging_fraction", {"sub_row", "subsample", "bagging"}},
str_buf << "\"bagging_fraction\": [\"sub_row\", \"subsample\", \"bagging\"], "; {"pos_bagging_fraction", {"pos_sub_row", "pos_subsample", "pos_bagging"}},
str_buf << "\"pos_bagging_fraction\": [\"pos_sub_row\", \"pos_subsample\", \"pos_bagging\"], "; {"neg_bagging_fraction", {"neg_sub_row", "neg_subsample", "neg_bagging"}},
str_buf << "\"neg_bagging_fraction\": [\"neg_sub_row\", \"neg_subsample\", \"neg_bagging\"], "; {"bagging_freq", {"subsample_freq"}},
str_buf << "\"bagging_freq\": [\"subsample_freq\"], "; {"bagging_seed", {"bagging_fraction_seed"}},
str_buf << "\"bagging_seed\": [\"bagging_fraction_seed\"], "; {"feature_fraction", {"sub_feature", "colsample_bytree"}},
str_buf << "\"feature_fraction\": [\"sub_feature\", \"colsample_bytree\"], "; {"feature_fraction_bynode", {"sub_feature_bynode", "colsample_bynode"}},
str_buf << "\"feature_fraction_bynode\": [\"sub_feature_bynode\", \"colsample_bynode\"], "; {"feature_fraction_seed", {}},
str_buf << "\"feature_fraction_seed\": [], "; {"extra_trees", {"extra_tree"}},
str_buf << "\"extra_trees\": [\"extra_tree\"], "; {"extra_seed", {}},
str_buf << "\"extra_seed\": [], "; {"early_stopping_round", {"early_stopping_rounds", "early_stopping", "n_iter_no_change"}},
str_buf << "\"early_stopping_round\": [\"early_stopping_rounds\", \"early_stopping\", \"n_iter_no_change\"], "; {"first_metric_only", {}},
str_buf << "\"first_metric_only\": [], "; {"max_delta_step", {"max_tree_output", "max_leaf_output"}},
str_buf << "\"max_delta_step\": [\"max_tree_output\", \"max_leaf_output\"], "; {"lambda_l1", {"reg_alpha", "l1_regularization"}},
str_buf << "\"lambda_l1\": [\"reg_alpha\", \"l1_regularization\"], "; {"lambda_l2", {"reg_lambda", "lambda", "l2_regularization"}},
str_buf << "\"lambda_l2\": [\"reg_lambda\", \"lambda\", \"l2_regularization\"], "; {"linear_lambda", {}},
str_buf << "\"linear_lambda\": [], "; {"min_gain_to_split", {"min_split_gain"}},
str_buf << "\"min_gain_to_split\": [\"min_split_gain\"], "; {"drop_rate", {"rate_drop"}},
str_buf << "\"drop_rate\": [\"rate_drop\"], "; {"max_drop", {}},
str_buf << "\"max_drop\": [], "; {"skip_drop", {}},
str_buf << "\"skip_drop\": [], "; {"xgboost_dart_mode", {}},
str_buf << "\"xgboost_dart_mode\": [], "; {"uniform_drop", {}},
str_buf << "\"uniform_drop\": [], "; {"drop_seed", {}},
str_buf << "\"drop_seed\": [], "; {"top_rate", {}},
str_buf << "\"top_rate\": [], "; {"other_rate", {}},
str_buf << "\"other_rate\": [], "; {"min_data_per_group", {}},
str_buf << "\"min_data_per_group\": [], "; {"max_cat_threshold", {}},
str_buf << "\"max_cat_threshold\": [], "; {"cat_l2", {}},
str_buf << "\"cat_l2\": [], "; {"cat_smooth", {}},
str_buf << "\"cat_smooth\": [], "; {"max_cat_to_onehot", {}},
str_buf << "\"max_cat_to_onehot\": [], "; {"top_k", {"topk"}},
str_buf << "\"top_k\": [\"topk\"], "; {"monotone_constraints", {"mc", "monotone_constraint", "monotonic_cst"}},
str_buf << "\"monotone_constraints\": [\"mc\", \"monotone_constraint\", \"monotonic_cst\"], "; {"monotone_constraints_method", {"monotone_constraining_method", "mc_method"}},
str_buf << "\"monotone_constraints_method\": [\"monotone_constraining_method\", \"mc_method\"], "; {"monotone_penalty", {"monotone_splits_penalty", "ms_penalty", "mc_penalty"}},
str_buf << "\"monotone_penalty\": [\"monotone_splits_penalty\", \"ms_penalty\", \"mc_penalty\"], "; {"feature_contri", {"feature_contrib", "fc", "fp", "feature_penalty"}},
str_buf << "\"feature_contri\": [\"feature_contrib\", \"fc\", \"fp\", \"feature_penalty\"], "; {"forcedsplits_filename", {"fs", "forced_splits_filename", "forced_splits_file", "forced_splits"}},
str_buf << "\"forcedsplits_filename\": [\"fs\", \"forced_splits_filename\", \"forced_splits_file\", \"forced_splits\"], "; {"refit_decay_rate", {}},
str_buf << "\"refit_decay_rate\": [], "; {"cegb_tradeoff", {}},
str_buf << "\"cegb_tradeoff\": [], "; {"cegb_penalty_split", {}},
str_buf << "\"cegb_penalty_split\": [], "; {"cegb_penalty_feature_lazy", {}},
str_buf << "\"cegb_penalty_feature_lazy\": [], "; {"cegb_penalty_feature_coupled", {}},
str_buf << "\"cegb_penalty_feature_coupled\": [], "; {"path_smooth", {}},
str_buf << "\"path_smooth\": [], "; {"interaction_constraints", {}},
str_buf << "\"interaction_constraints\": [], "; {"verbosity", {"verbose"}},
str_buf << "\"verbosity\": [\"verbose\"], "; {"input_model", {"model_input", "model_in"}},
str_buf << "\"input_model\": [\"model_input\", \"model_in\"], "; {"output_model", {"model_output", "model_out"}},
str_buf << "\"output_model\": [\"model_output\", \"model_out\"], "; {"saved_feature_importance_type", {}},
str_buf << "\"saved_feature_importance_type\": [], "; {"snapshot_freq", {"save_period"}},
str_buf << "\"snapshot_freq\": [\"save_period\"], "; {"linear_tree", {"linear_trees"}},
str_buf << "\"linear_tree\": [\"linear_trees\"], "; {"max_bin", {"max_bins"}},
str_buf << "\"max_bin\": [\"max_bins\"], "; {"max_bin_by_feature", {}},
str_buf << "\"max_bin_by_feature\": [], "; {"min_data_in_bin", {}},
str_buf << "\"min_data_in_bin\": [], "; {"bin_construct_sample_cnt", {"subsample_for_bin"}},
str_buf << "\"bin_construct_sample_cnt\": [\"subsample_for_bin\"], "; {"data_random_seed", {"data_seed"}},
str_buf << "\"data_random_seed\": [\"data_seed\"], "; {"is_enable_sparse", {"is_sparse", "enable_sparse", "sparse"}},
str_buf << "\"is_enable_sparse\": [\"is_sparse\", \"enable_sparse\", \"sparse\"], "; {"enable_bundle", {"is_enable_bundle", "bundle"}},
str_buf << "\"enable_bundle\": [\"is_enable_bundle\", \"bundle\"], "; {"use_missing", {}},
str_buf << "\"use_missing\": [], "; {"zero_as_missing", {}},
str_buf << "\"zero_as_missing\": [], "; {"feature_pre_filter", {}},
str_buf << "\"feature_pre_filter\": [], "; {"pre_partition", {"is_pre_partition"}},
str_buf << "\"pre_partition\": [\"is_pre_partition\"], "; {"two_round", {"two_round_loading", "use_two_round_loading"}},
str_buf << "\"two_round\": [\"two_round_loading\", \"use_two_round_loading\"], "; {"header", {"has_header"}},
str_buf << "\"header\": [\"has_header\"], "; {"label_column", {"label"}},
str_buf << "\"label_column\": [\"label\"], "; {"weight_column", {"weight"}},
str_buf << "\"weight_column\": [\"weight\"], "; {"group_column", {"group", "group_id", "query_column", "query", "query_id"}},
str_buf << "\"group_column\": [\"group\", \"group_id\", \"query_column\", \"query\", \"query_id\"], "; {"ignore_column", {"ignore_feature", "blacklist"}},
str_buf << "\"ignore_column\": [\"ignore_feature\", \"blacklist\"], "; {"categorical_feature", {"cat_feature", "categorical_column", "cat_column", "categorical_features"}},
str_buf << "\"categorical_feature\": [\"cat_feature\", \"categorical_column\", \"cat_column\", \"categorical_features\"], "; {"forcedbins_filename", {}},
str_buf << "\"forcedbins_filename\": [], "; {"save_binary", {"is_save_binary", "is_save_binary_file"}},
str_buf << "\"save_binary\": [\"is_save_binary\", \"is_save_binary_file\"], "; {"precise_float_parser", {}},
str_buf << "\"precise_float_parser\": [], "; {"parser_config_file", {}},
str_buf << "\"parser_config_file\": [], "; {"start_iteration_predict", {}},
str_buf << "\"start_iteration_predict\": [], "; {"num_iteration_predict", {}},
str_buf << "\"num_iteration_predict\": [], "; {"predict_raw_score", {"is_predict_raw_score", "predict_rawscore", "raw_score"}},
str_buf << "\"predict_raw_score\": [\"is_predict_raw_score\", \"predict_rawscore\", \"raw_score\"], "; {"predict_leaf_index", {"is_predict_leaf_index", "leaf_index"}},
str_buf << "\"predict_leaf_index\": [\"is_predict_leaf_index\", \"leaf_index\"], "; {"predict_contrib", {"is_predict_contrib", "contrib"}},
str_buf << "\"predict_contrib\": [\"is_predict_contrib\", \"contrib\"], "; {"predict_disable_shape_check", {}},
str_buf << "\"predict_disable_shape_check\": [], "; {"pred_early_stop", {}},
str_buf << "\"pred_early_stop\": [], "; {"pred_early_stop_freq", {}},
str_buf << "\"pred_early_stop_freq\": [], "; {"pred_early_stop_margin", {}},
str_buf << "\"pred_early_stop_margin\": [], "; {"output_result", {"predict_result", "prediction_result", "predict_name", "prediction_name", "pred_name", "name_pred"}},
str_buf << "\"output_result\": [\"predict_result\", \"prediction_result\", \"predict_name\", \"prediction_name\", \"pred_name\", \"name_pred\"], "; {"convert_model_language", {}},
str_buf << "\"convert_model_language\": [], "; {"convert_model", {"convert_model_file"}},
str_buf << "\"convert_model\": [\"convert_model_file\"], "; {"objective_seed", {}},
str_buf << "\"objective_seed\": [], "; {"num_class", {"num_classes"}},
str_buf << "\"num_class\": [\"num_classes\"], "; {"is_unbalance", {"unbalance", "unbalanced_sets"}},
str_buf << "\"is_unbalance\": [\"unbalance\", \"unbalanced_sets\"], "; {"scale_pos_weight", {}},
str_buf << "\"scale_pos_weight\": [], "; {"sigmoid", {}},
str_buf << "\"sigmoid\": [], "; {"boost_from_average", {}},
str_buf << "\"boost_from_average\": [], "; {"reg_sqrt", {}},
str_buf << "\"reg_sqrt\": [], "; {"alpha", {}},
str_buf << "\"alpha\": [], "; {"fair_c", {}},
str_buf << "\"fair_c\": [], "; {"poisson_max_delta_step", {}},
str_buf << "\"poisson_max_delta_step\": [], "; {"tweedie_variance_power", {}},
str_buf << "\"tweedie_variance_power\": [], "; {"lambdarank_truncation_level", {}},
str_buf << "\"lambdarank_truncation_level\": [], "; {"lambdarank_norm", {}},
str_buf << "\"lambdarank_norm\": [], "; {"label_gain", {}},
str_buf << "\"label_gain\": [], "; {"metric", {"metrics", "metric_types"}},
str_buf << "\"metric\": [\"metrics\", \"metric_types\"], "; {"metric_freq", {"output_freq"}},
str_buf << "\"metric_freq\": [\"output_freq\"], "; {"is_provide_training_metric", {"training_metric", "is_training_metric", "train_metric"}},
str_buf << "\"is_provide_training_metric\": [\"training_metric\", \"is_training_metric\", \"train_metric\"], "; {"eval_at", {"ndcg_eval_at", "ndcg_at", "map_eval_at", "map_at"}},
str_buf << "\"eval_at\": [\"ndcg_eval_at\", \"ndcg_at\", \"map_eval_at\", \"map_at\"], "; {"multi_error_top_k", {}},
str_buf << "\"multi_error_top_k\": [], "; {"auc_mu_weights", {}},
str_buf << "\"auc_mu_weights\": [], "; {"num_machines", {"num_machine"}},
str_buf << "\"num_machines\": [\"num_machine\"], "; {"local_listen_port", {"local_port", "port"}},
str_buf << "\"local_listen_port\": [\"local_port\", \"port\"], "; {"time_out", {}},
str_buf << "\"time_out\": [], "; {"machine_list_filename", {"machine_list_file", "machine_list", "mlist"}},
str_buf << "\"machine_list_filename\": [\"machine_list_file\", \"machine_list\", \"mlist\"], "; {"machines", {"workers", "nodes"}},
str_buf << "\"machines\": [\"workers\", \"nodes\"], "; {"gpu_platform_id", {}},
str_buf << "\"gpu_platform_id\": [], "; {"gpu_device_id", {}},
str_buf << "\"gpu_device_id\": [], "; {"gpu_use_dp", {}},
str_buf << "\"gpu_use_dp\": [], "; {"num_gpu", {}},
str_buf << "\"num_gpu\": []"; });
str_buf << "}"; return map;
return str_buf.str();
} }
} // namespace LightGBM } // namespace LightGBM
...@@ -473,7 +473,8 @@ def test_choose_param_value(): ...@@ -473,7 +473,8 @@ def test_choose_param_value():
"local_listen_port": 1234, "local_listen_port": 1234,
"port": 2222, "port": 2222,
"metric": "auc", "metric": "auc",
"num_trees": 81 "num_trees": 81,
"n_iter": 13,
} }
# should resolve duplicate aliases, and prefer the main parameter # should resolve duplicate aliases, and prefer the main parameter
...@@ -485,15 +486,16 @@ def test_choose_param_value(): ...@@ -485,15 +486,16 @@ def test_choose_param_value():
assert params["local_listen_port"] == 1234 assert params["local_listen_port"] == 1234
assert "port" not in params assert "port" not in params
# should choose a value from an alias and set that value on main param # should choose the highest priority alias and set that value on main param
# if only an alias is used # if only aliases are used
params = lgb.basic._choose_param_value( params = lgb.basic._choose_param_value(
main_param_name="num_iterations", main_param_name="num_iterations",
params=params, params=params,
default_value=17 default_value=17
) )
assert params["num_iterations"] == 81 assert params["num_iterations"] == 13
assert "num_trees" not in params assert "num_trees" not in params
assert "n_iter" not in params
# should use the default if main param and aliases are missing # should use the default if main param and aliases are missing
params = lgb.basic._choose_param_value( params = lgb.basic._choose_param_value(
...@@ -508,7 +510,8 @@ def test_choose_param_value(): ...@@ -508,7 +510,8 @@ def test_choose_param_value():
"local_listen_port": 1234, "local_listen_port": 1234,
"port": 2222, "port": 2222,
"metric": "auc", "metric": "auc",
"num_trees": 81 "num_trees": 81,
"n_iter": 13,
} }
assert original_params == expected_params assert original_params == expected_params
...@@ -644,10 +647,13 @@ def test_param_aliases(): ...@@ -644,10 +647,13 @@ def test_param_aliases():
aliases = lgb.basic._ConfigAliases.aliases aliases = lgb.basic._ConfigAliases.aliases
assert isinstance(aliases, dict) assert isinstance(aliases, dict)
assert len(aliases) > 100 assert len(aliases) > 100
assert all(isinstance(i, set) for i in aliases.values()) assert all(isinstance(i, list) for i in aliases.values())
assert all(len(i) >= 1 for i in aliases.values()) assert all(len(i) >= 1 for i in aliases.values())
assert all(k in v for k, v in aliases.items()) assert all(k in v for k, v in aliases.items())
assert lgb.basic._ConfigAliases.get('config', 'task') == {'config', 'config_file', 'task', 'task_type'} assert lgb.basic._ConfigAliases.get('config', 'task') == {'config', 'config_file', 'task', 'task_type'}
assert lgb.basic._ConfigAliases.get_sorted('min_data_in_leaf') == [
'min_data_in_leaf', 'min_data', 'min_samples_leaf', 'min_child_samples', 'min_data_per_leaf'
]
def _bad_gradients(preds, _): def _bad_gradients(preds, _):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment