Commit 42710827 authored by Guolin Ke's avatar Guolin Ke
Browse files

fix `max_drop`. add many checks for parameters.

parent 0af44ac8
...@@ -239,7 +239,6 @@ public: ...@@ -239,7 +239,6 @@ public:
struct BoostingConfig: public ConfigBase { struct BoostingConfig: public ConfigBase {
public: public:
virtual ~BoostingConfig() {} virtual ~BoostingConfig() {}
double sigmoid = 1.0f;
int output_freq = 1; int output_freq = 1;
bool is_provide_training_metric = false; bool is_provide_training_metric = false;
int num_iterations = 100; int num_iterations = 100;
......
...@@ -96,6 +96,9 @@ private: ...@@ -96,6 +96,9 @@ private:
for (int i = 0; i < iter_; ++i) { for (int i = 0; i < iter_; ++i) {
if (random_for_drop_.NextFloat() < drop_rate * tree_weight_[i] * inv_average_weight) { if (random_for_drop_.NextFloat() < drop_rate * tree_weight_[i] * inv_average_weight) {
drop_index_.push_back(num_init_iteration_ + i); drop_index_.push_back(num_init_iteration_ + i);
if (drop_index_.size() >= static_cast<size_t>(gbdt_config_->max_drop)) {
break;
}
} }
} }
} else { } else {
...@@ -105,6 +108,9 @@ private: ...@@ -105,6 +108,9 @@ private:
for (int i = 0; i < iter_; ++i) { for (int i = 0; i < iter_; ++i) {
if (random_for_drop_.NextFloat() < drop_rate) { if (random_for_drop_.NextFloat() < drop_rate) {
drop_index_.push_back(num_init_iteration_ + i); drop_index_.push_back(num_init_iteration_ + i);
if (drop_index_.size() >= static_cast<size_t>(gbdt_config_->max_drop)) {
break;
}
} }
} }
} }
......
...@@ -251,12 +251,14 @@ void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) { ...@@ -251,12 +251,14 @@ void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) {
GetInt(params, "max_bin", &max_bin); GetInt(params, "max_bin", &max_bin);
CHECK(max_bin > 0); CHECK(max_bin > 0);
GetInt(params, "num_class", &num_class); GetInt(params, "num_class", &num_class);
CHECK(num_class > 0);
GetInt(params, "data_random_seed", &data_random_seed); GetInt(params, "data_random_seed", &data_random_seed);
GetString(params, "data", &data_filename); GetString(params, "data", &data_filename);
GetString(params, "init_score_file", &initscore_filename); GetString(params, "init_score_file", &initscore_filename);
GetInt(params, "verbose", &verbosity); GetInt(params, "verbose", &verbosity);
GetInt(params, "num_iteration_predict", &num_iteration_predict); GetInt(params, "num_iteration_predict", &num_iteration_predict);
GetInt(params, "bin_construct_sample_cnt", &bin_construct_sample_cnt); GetInt(params, "bin_construct_sample_cnt", &bin_construct_sample_cnt);
CHECK(bin_construct_sample_cnt > 0);
GetBool(params, "is_pre_partition", &is_pre_partition); GetBool(params, "is_pre_partition", &is_pre_partition);
GetBool(params, "is_enable_sparse", &is_enable_sparse); GetBool(params, "is_enable_sparse", &is_enable_sparse);
GetDouble(params, "sparse_threshold", &sparse_threshold); GetDouble(params, "sparse_threshold", &sparse_threshold);
...@@ -290,9 +292,10 @@ void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) { ...@@ -290,9 +292,10 @@ void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) {
GetInt(params, "min_data_in_leaf", &min_data_in_leaf); GetInt(params, "min_data_in_leaf", &min_data_in_leaf);
GetInt(params, "min_data_in_bin", &min_data_in_bin); GetInt(params, "min_data_in_bin", &min_data_in_bin);
CHECK(min_data_in_bin > 0); CHECK(min_data_in_bin > 0);
CHECK(min_data_in_leaf >= 0);
GetDouble(params, "max_conflict_rate", &max_conflict_rate); GetDouble(params, "max_conflict_rate", &max_conflict_rate);
CHECK(max_conflict_rate >= 0);
GetBool(params, "enable_bundle", &enable_bundle); GetBool(params, "enable_bundle", &enable_bundle);
GetBool(params, "pred_early_stop", &pred_early_stop); GetBool(params, "pred_early_stop", &pred_early_stop);
GetInt(params, "pred_early_stop_freq", &pred_early_stop_freq); GetInt(params, "pred_early_stop_freq", &pred_early_stop_freq);
GetDouble(params, "pred_early_stop_margin", &pred_early_stop_margin); GetDouble(params, "pred_early_stop_margin", &pred_early_stop_margin);
...@@ -304,15 +307,21 @@ void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) { ...@@ -304,15 +307,21 @@ void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) {
void ObjectiveConfig::Set(const std::unordered_map<std::string, std::string>& params) { void ObjectiveConfig::Set(const std::unordered_map<std::string, std::string>& params) {
GetBool(params, "is_unbalance", &is_unbalance); GetBool(params, "is_unbalance", &is_unbalance);
GetDouble(params, "sigmoid", &sigmoid); GetDouble(params, "sigmoid", &sigmoid);
CHECK(sigmoid > 0);
GetDouble(params, "fair_c", &fair_c); GetDouble(params, "fair_c", &fair_c);
CHECK(fair_c > 0);
GetDouble(params, "gaussian_eta", &gaussian_eta); GetDouble(params, "gaussian_eta", &gaussian_eta);
CHECK(gaussian_eta > 0);
GetDouble(params, "poisson_max_delta_step", &poisson_max_delta_step); GetDouble(params, "poisson_max_delta_step", &poisson_max_delta_step);
CHECK(poisson_max_delta_step > 0);
GetInt(params, "max_position", &max_position); GetInt(params, "max_position", &max_position);
CHECK(max_position > 0); CHECK(max_position > 0);
GetInt(params, "num_class", &num_class); GetInt(params, "num_class", &num_class);
CHECK(num_class >= 1); CHECK(num_class > 0);
GetDouble(params, "scale_pos_weight", &scale_pos_weight); GetDouble(params, "scale_pos_weight", &scale_pos_weight);
CHECK(scale_pos_weight > 0);
GetDouble(params, "alpha", &alpha); GetDouble(params, "alpha", &alpha);
CHECK(alpha > 0 && alpha < 1);
GetBool(params, "reg_sqrt", &reg_sqrt); GetBool(params, "reg_sqrt", &reg_sqrt);
std::string tmp_str = ""; std::string tmp_str = "";
if (GetString(params, "label_gain", &tmp_str)) { if (GetString(params, "label_gain", &tmp_str)) {
...@@ -331,9 +340,13 @@ void ObjectiveConfig::Set(const std::unordered_map<std::string, std::string>& pa ...@@ -331,9 +340,13 @@ void ObjectiveConfig::Set(const std::unordered_map<std::string, std::string>& pa
void MetricConfig::Set(const std::unordered_map<std::string, std::string>& params) { void MetricConfig::Set(const std::unordered_map<std::string, std::string>& params) {
GetDouble(params, "sigmoid", &sigmoid); GetDouble(params, "sigmoid", &sigmoid);
CHECK(sigmoid > 0);
GetDouble(params, "fair_c", &fair_c); GetDouble(params, "fair_c", &fair_c);
CHECK(fair_c > 0);
GetInt(params, "num_class", &num_class); GetInt(params, "num_class", &num_class);
CHECK(num_class > 0);
GetDouble(params, "alpha", &alpha); GetDouble(params, "alpha", &alpha);
CHECK(alpha > 0 && alpha < 1);
std::string tmp_str = ""; std::string tmp_str = "";
if (GetString(params, "label_gain", &tmp_str)) { if (GetString(params, "label_gain", &tmp_str)) {
label_gain = Common::StringToArray<double>(tmp_str, ','); label_gain = Common::StringToArray<double>(tmp_str, ',');
...@@ -365,7 +378,8 @@ void MetricConfig::Set(const std::unordered_map<std::string, std::string>& param ...@@ -365,7 +378,8 @@ void MetricConfig::Set(const std::unordered_map<std::string, std::string>& param
void TreeConfig::Set(const std::unordered_map<std::string, std::string>& params) { void TreeConfig::Set(const std::unordered_map<std::string, std::string>& params) {
GetInt(params, "min_data_in_leaf", &min_data_in_leaf); GetInt(params, "min_data_in_leaf", &min_data_in_leaf);
GetDouble(params, "min_sum_hessian_in_leaf", &min_sum_hessian_in_leaf); GetDouble(params, "min_sum_hessian_in_leaf", &min_sum_hessian_in_leaf);
CHECK(min_sum_hessian_in_leaf > 0 || min_data_in_leaf > 0); CHECK(min_data_in_leaf > 0);
CHECK(min_sum_hessian_in_leaf >= 0);
GetDouble(params, "lambda_l1", &lambda_l1); GetDouble(params, "lambda_l1", &lambda_l1);
CHECK(lambda_l1 >= 0.0f); CHECK(lambda_l1 >= 0.0f);
GetDouble(params, "lambda_l2", &lambda_l2); GetDouble(params, "lambda_l2", &lambda_l2);
...@@ -380,6 +394,7 @@ void TreeConfig::Set(const std::unordered_map<std::string, std::string>& params) ...@@ -380,6 +394,7 @@ void TreeConfig::Set(const std::unordered_map<std::string, std::string>& params)
GetDouble(params, "histogram_pool_size", &histogram_pool_size); GetDouble(params, "histogram_pool_size", &histogram_pool_size);
GetInt(params, "max_depth", &max_depth); GetInt(params, "max_depth", &max_depth);
GetInt(params, "top_k", &top_k); GetInt(params, "top_k", &top_k);
CHECK(top_k > 0);
GetInt(params, "gpu_platform_id", &gpu_platform_id); GetInt(params, "gpu_platform_id", &gpu_platform_id);
GetInt(params, "gpu_device_id", &gpu_device_id); GetInt(params, "gpu_device_id", &gpu_device_id);
GetBool(params, "gpu_use_dp", &gpu_use_dp); GetBool(params, "gpu_use_dp", &gpu_use_dp);
...@@ -397,7 +412,6 @@ void TreeConfig::Set(const std::unordered_map<std::string, std::string>& params) ...@@ -397,7 +412,6 @@ void TreeConfig::Set(const std::unordered_map<std::string, std::string>& params)
void BoostingConfig::Set(const std::unordered_map<std::string, std::string>& params) { void BoostingConfig::Set(const std::unordered_map<std::string, std::string>& params) {
GetInt(params, "num_iterations", &num_iterations); GetInt(params, "num_iterations", &num_iterations);
GetDouble(params, "sigmoid", &sigmoid);
CHECK(num_iterations >= 0); CHECK(num_iterations >= 0);
GetInt(params, "bagging_seed", &bagging_seed); GetInt(params, "bagging_seed", &bagging_seed);
GetInt(params, "bagging_freq", &bagging_freq); GetInt(params, "bagging_freq", &bagging_freq);
...@@ -412,17 +426,22 @@ void BoostingConfig::Set(const std::unordered_map<std::string, std::string>& par ...@@ -412,17 +426,22 @@ void BoostingConfig::Set(const std::unordered_map<std::string, std::string>& par
CHECK(output_freq >= 0); CHECK(output_freq >= 0);
GetBool(params, "is_training_metric", &is_provide_training_metric); GetBool(params, "is_training_metric", &is_provide_training_metric);
GetInt(params, "num_class", &num_class); GetInt(params, "num_class", &num_class);
CHECK(num_class > 0);
GetInt(params, "drop_seed", &drop_seed); GetInt(params, "drop_seed", &drop_seed);
GetDouble(params, "drop_rate", &drop_rate); GetDouble(params, "drop_rate", &drop_rate);
GetDouble(params, "skip_drop", &skip_drop); GetDouble(params, "skip_drop", &skip_drop);
CHECK(drop_rate <= 1.0 && drop_rate >= 0.0);
CHECK(skip_drop <= 1.0 && skip_drop >= 0.0);
GetInt(params, "max_drop", &max_drop); GetInt(params, "max_drop", &max_drop);
CHECK(max_drop > 0);
GetBool(params, "xgboost_dart_mode", &xgboost_dart_mode); GetBool(params, "xgboost_dart_mode", &xgboost_dart_mode);
GetBool(params, "uniform_drop", &uniform_drop); GetBool(params, "uniform_drop", &uniform_drop);
GetDouble(params, "top_rate", &top_rate); GetDouble(params, "top_rate", &top_rate);
GetDouble(params, "other_rate", &other_rate); GetDouble(params, "other_rate", &other_rate);
CHECK(top_rate > 0);
CHECK(top_rate > 0);
CHECK(top_rate + top_rate <= 1.0);
GetBool(params, "boost_from_average", &boost_from_average); GetBool(params, "boost_from_average", &boost_from_average);
CHECK(drop_rate <= 1.0 && drop_rate >= 0.0);
CHECK(skip_drop <= 1.0 && skip_drop >= 0.0);
GetDeviceType(params, &device_type); GetDeviceType(params, &device_type);
GetTreeLearnerType(params, &tree_learner_type); GetTreeLearnerType(params, &tree_learner_type);
tree_config.Set(params); tree_config.Set(params);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment