Commit 42710827 authored by Guolin Ke's avatar Guolin Ke
Browse files

fix `max_drop`. add many checks for parameters.

parent 0af44ac8
......@@ -239,7 +239,6 @@ public:
struct BoostingConfig: public ConfigBase {
public:
virtual ~BoostingConfig() {}
double sigmoid = 1.0f;
int output_freq = 1;
bool is_provide_training_metric = false;
int num_iterations = 100;
......
......@@ -96,6 +96,9 @@ private:
for (int i = 0; i < iter_; ++i) {
if (random_for_drop_.NextFloat() < drop_rate * tree_weight_[i] * inv_average_weight) {
drop_index_.push_back(num_init_iteration_ + i);
if (drop_index_.size() >= static_cast<size_t>(gbdt_config_->max_drop)) {
break;
}
}
}
} else {
......@@ -105,6 +108,9 @@ private:
for (int i = 0; i < iter_; ++i) {
if (random_for_drop_.NextFloat() < drop_rate) {
drop_index_.push_back(num_init_iteration_ + i);
if (drop_index_.size() >= static_cast<size_t>(gbdt_config_->max_drop)) {
break;
}
}
}
}
......
......@@ -251,12 +251,14 @@ void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) {
GetInt(params, "max_bin", &max_bin);
CHECK(max_bin > 0);
GetInt(params, "num_class", &num_class);
CHECK(num_class > 0);
GetInt(params, "data_random_seed", &data_random_seed);
GetString(params, "data", &data_filename);
GetString(params, "init_score_file", &initscore_filename);
GetInt(params, "verbose", &verbosity);
GetInt(params, "num_iteration_predict", &num_iteration_predict);
GetInt(params, "bin_construct_sample_cnt", &bin_construct_sample_cnt);
CHECK(bin_construct_sample_cnt > 0);
GetBool(params, "is_pre_partition", &is_pre_partition);
GetBool(params, "is_enable_sparse", &is_enable_sparse);
GetDouble(params, "sparse_threshold", &sparse_threshold);
......@@ -290,9 +292,10 @@ void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) {
GetInt(params, "min_data_in_leaf", &min_data_in_leaf);
GetInt(params, "min_data_in_bin", &min_data_in_bin);
CHECK(min_data_in_bin > 0);
CHECK(min_data_in_leaf >= 0);
GetDouble(params, "max_conflict_rate", &max_conflict_rate);
CHECK(max_conflict_rate >= 0);
GetBool(params, "enable_bundle", &enable_bundle);
GetBool(params, "pred_early_stop", &pred_early_stop);
GetInt(params, "pred_early_stop_freq", &pred_early_stop_freq);
GetDouble(params, "pred_early_stop_margin", &pred_early_stop_margin);
......@@ -304,15 +307,21 @@ void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) {
void ObjectiveConfig::Set(const std::unordered_map<std::string, std::string>& params) {
GetBool(params, "is_unbalance", &is_unbalance);
GetDouble(params, "sigmoid", &sigmoid);
CHECK(sigmoid > 0);
GetDouble(params, "fair_c", &fair_c);
CHECK(fair_c > 0);
GetDouble(params, "gaussian_eta", &gaussian_eta);
CHECK(gaussian_eta > 0);
GetDouble(params, "poisson_max_delta_step", &poisson_max_delta_step);
CHECK(poisson_max_delta_step > 0);
GetInt(params, "max_position", &max_position);
CHECK(max_position > 0);
GetInt(params, "num_class", &num_class);
CHECK(num_class >= 1);
CHECK(num_class > 0);
GetDouble(params, "scale_pos_weight", &scale_pos_weight);
CHECK(scale_pos_weight > 0);
GetDouble(params, "alpha", &alpha);
CHECK(alpha > 0 && alpha < 1);
GetBool(params, "reg_sqrt", &reg_sqrt);
std::string tmp_str = "";
if (GetString(params, "label_gain", &tmp_str)) {
......@@ -331,9 +340,13 @@ void ObjectiveConfig::Set(const std::unordered_map<std::string, std::string>& pa
void MetricConfig::Set(const std::unordered_map<std::string, std::string>& params) {
GetDouble(params, "sigmoid", &sigmoid);
CHECK(sigmoid > 0);
GetDouble(params, "fair_c", &fair_c);
CHECK(fair_c > 0);
GetInt(params, "num_class", &num_class);
CHECK(num_class > 0);
GetDouble(params, "alpha", &alpha);
CHECK(alpha > 0 && alpha < 1);
std::string tmp_str = "";
if (GetString(params, "label_gain", &tmp_str)) {
label_gain = Common::StringToArray<double>(tmp_str, ',');
......@@ -365,7 +378,8 @@ void MetricConfig::Set(const std::unordered_map<std::string, std::string>& param
void TreeConfig::Set(const std::unordered_map<std::string, std::string>& params) {
GetInt(params, "min_data_in_leaf", &min_data_in_leaf);
GetDouble(params, "min_sum_hessian_in_leaf", &min_sum_hessian_in_leaf);
CHECK(min_sum_hessian_in_leaf > 0 || min_data_in_leaf > 0);
CHECK(min_data_in_leaf > 0);
CHECK(min_sum_hessian_in_leaf >= 0);
GetDouble(params, "lambda_l1", &lambda_l1);
CHECK(lambda_l1 >= 0.0f);
GetDouble(params, "lambda_l2", &lambda_l2);
......@@ -380,6 +394,7 @@ void TreeConfig::Set(const std::unordered_map<std::string, std::string>& params)
GetDouble(params, "histogram_pool_size", &histogram_pool_size);
GetInt(params, "max_depth", &max_depth);
GetInt(params, "top_k", &top_k);
CHECK(top_k > 0);
GetInt(params, "gpu_platform_id", &gpu_platform_id);
GetInt(params, "gpu_device_id", &gpu_device_id);
GetBool(params, "gpu_use_dp", &gpu_use_dp);
......@@ -397,7 +412,6 @@ void TreeConfig::Set(const std::unordered_map<std::string, std::string>& params)
void BoostingConfig::Set(const std::unordered_map<std::string, std::string>& params) {
GetInt(params, "num_iterations", &num_iterations);
GetDouble(params, "sigmoid", &sigmoid);
CHECK(num_iterations >= 0);
GetInt(params, "bagging_seed", &bagging_seed);
GetInt(params, "bagging_freq", &bagging_freq);
......@@ -412,17 +426,22 @@ void BoostingConfig::Set(const std::unordered_map<std::string, std::string>& par
CHECK(output_freq >= 0);
GetBool(params, "is_training_metric", &is_provide_training_metric);
GetInt(params, "num_class", &num_class);
CHECK(num_class > 0);
GetInt(params, "drop_seed", &drop_seed);
GetDouble(params, "drop_rate", &drop_rate);
GetDouble(params, "skip_drop", &skip_drop);
CHECK(drop_rate <= 1.0 && drop_rate >= 0.0);
CHECK(skip_drop <= 1.0 && skip_drop >= 0.0);
GetInt(params, "max_drop", &max_drop);
CHECK(max_drop > 0);
GetBool(params, "xgboost_dart_mode", &xgboost_dart_mode);
GetBool(params, "uniform_drop", &uniform_drop);
GetDouble(params, "top_rate", &top_rate);
GetDouble(params, "other_rate", &other_rate);
CHECK(top_rate > 0);
CHECK(top_rate > 0);
CHECK(top_rate + top_rate <= 1.0);
GetBool(params, "boost_from_average", &boost_from_average);
CHECK(drop_rate <= 1.0 && drop_rate >= 0.0);
CHECK(skip_drop <= 1.0 && skip_drop >= 0.0);
GetDeviceType(params, &device_type);
GetTreeLearnerType(params, &tree_learner_type);
tree_config.Set(params);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment