Commit db9ec217 authored by Guolin Ke's avatar Guolin Ke
Browse files

reduce parameters in categorical split

parent 8dc2bed8
...@@ -24,7 +24,7 @@ Categorical Feature Support ...@@ -24,7 +24,7 @@ Categorical Feature Support
- Converting to ``int`` type is needed first, and there is support for non-negative numbers only. - Converting to ``int`` type is needed first, and there is support for non-negative numbers only.
It is better to convert into continuous ranges. It is better to convert into continuous ranges.
- Use ``min_data_per_group``, ``cat_smooth_ratio`` to deal with over-fitting - Use ``min_data_per_group``, ``cat_smooth`` to deal with over-fitting
(when ``#data`` is small or ``#category`` is large). (when ``#data`` is small or ``#category`` is large).
- For categorical features with high cardinality (``#category`` is large), it is better to convert it to numerical features. - For categorical features with high cardinality (``#category`` is large), it is better to convert it to numerical features.
......
...@@ -253,33 +253,19 @@ Learning Control Parameters ...@@ -253,33 +253,19 @@ Learning Control Parameters
- min number of data per categorical group - min number of data per categorical group
- ``max_cat_threshold``, default=\ ``128``, type=int - ``max_cat_threshold``, default=\ ``32``, type=int
- use for the categorical features - use for the categorical features
- limit the max threshold points in categorical features - limit the max threshold points in categorical features
- ``min_cat_smooth``, default=\ ``5``, type=double - ``cat_smooth``, default=\ ``10``, type=double
- use for the categorical features
- refer to the description of the parameter ``cat_smooth_ratio``
- ``max_cat_smooth``, default=\ ``50``, type=double
- use for the categorical features
- refer to the description of the parameter ``cat_smooth_ratio``
- ``cat_smooth_ratio``, default=\ ``0.01``, type=double
- use for the categorical features - use for the categorical features
- this can reduce the effect of noise in categorical features, especially for categories with few data - this can reduce the effect of noise in categorical features, especially for categories with few data
- the smooth denominator is ``a = min(max_cat_smooth, max(min_cat_smooth, num_data / num_category * cat_smooth_ratio))`` - ``cat_l2``, default=\ ``10``, type=double
- ``cat_l2``, default=\ ``1``, type=double
- L2 regularization in categorical split - L2 regularization in categorical split
...@@ -294,7 +280,7 @@ IO Parameters ...@@ -294,7 +280,7 @@ IO Parameters
- LightGBM will auto compress memory according to ``max_bin``. - LightGBM will auto compress memory according to ``max_bin``.
For example, LightGBM will use ``uint8_t`` for feature value if ``max_bin=255`` For example, LightGBM will use ``uint8_t`` for feature value if ``max_bin=255``
- ``min_data_in_bin``, default=\ ``5``, type=int - ``min_data_in_bin``, default=\ ``3``, type=int
- min number of data inside one bin, use this to avoid one-data-one-bin (may cause over-fitting) - min number of data inside one bin, use this to avoid one-data-one-bin (may cause over-fitting)
......
...@@ -122,7 +122,7 @@ public: ...@@ -122,7 +122,7 @@ public:
bool is_predict_contrib = false; bool is_predict_contrib = false;
bool is_predict_raw_score = false; bool is_predict_raw_score = false;
int min_data_in_leaf = 20; int min_data_in_leaf = 20;
int min_data_in_bin = 5; int min_data_in_bin = 3;
double max_conflict_rate = 0.0f; double max_conflict_rate = 0.0f;
bool enable_bundle = true; bool enable_bundle = true;
bool has_header = false; bool has_header = false;
...@@ -226,11 +226,9 @@ public: ...@@ -226,11 +226,9 @@ public:
/*! \brief Set to true to use double precision math on GPU (default using single precision) */ /*! \brief Set to true to use double precision math on GPU (default using single precision) */
bool gpu_use_dp = false; bool gpu_use_dp = false;
int min_data_per_group = 100; int min_data_per_group = 100;
int max_cat_threshold = 128; int max_cat_threshold = 32;
double cat_smooth_ratio = 0.01; double cat_l2 = 10;
double cat_l2 = 1; double cat_smooth = 10;
double min_cat_smooth = 5;
double max_cat_smooth = 50;
LIGHTGBM_EXPORT void Set(const std::unordered_map<std::string, std::string>& params) override; LIGHTGBM_EXPORT void Set(const std::unordered_map<std::string, std::string>& params) override;
}; };
...@@ -473,7 +471,7 @@ struct ParameterAlias { ...@@ -473,7 +471,7 @@ struct ParameterAlias {
"max_conflict_rate", "poisson_max_delta_step", "gaussian_eta", "max_conflict_rate", "poisson_max_delta_step", "gaussian_eta",
"histogram_pool_size", "output_freq", "is_provide_training_metric", "machine_list_filename", "machines", "histogram_pool_size", "output_freq", "is_provide_training_metric", "machine_list_filename", "machines",
"zero_as_missing", "init_score_file", "valid_init_score_file", "is_predict_contrib", "zero_as_missing", "init_score_file", "valid_init_score_file", "is_predict_contrib",
"max_cat_threshold", "cat_smooth_ratio", "min_cat_smooth", "max_cat_smooth", "min_data_per_group", "cat_l2" "max_cat_threshold", "cat_smooth", "min_data_per_group", "cat_l2"
}); });
std::unordered_map<std::string, std::string> tmp_map; std::unordered_map<std::string, std::string> tmp_map;
for (const auto& pair : *params) { for (const auto& pair : *params) {
......
...@@ -329,6 +329,9 @@ namespace LightGBM { ...@@ -329,6 +329,9 @@ namespace LightGBM {
cnt_in_bin.clear(); cnt_in_bin.clear();
while (cur_cat < distinct_values_int.size() while (cur_cat < distinct_values_int.size()
&& (used_cnt < cut_cnt || num_bin_ < max_bin)) { && (used_cnt < cut_cnt || num_bin_ < max_bin)) {
if (counts_int[cur_cat] < min_data_in_bin && cur_cat > 1) {
break;
}
bin_2_categorical_.push_back(distinct_values_int[cur_cat]); bin_2_categorical_.push_back(distinct_values_int[cur_cat]);
categorical_2_bin_[distinct_values_int[cur_cat]] = static_cast<unsigned int>(num_bin_); categorical_2_bin_[distinct_values_int[cur_cat]] = static_cast<unsigned int>(num_bin_);
used_cnt += counts_int[cur_cat]; used_cnt += counts_int[cur_cat];
......
...@@ -381,16 +381,12 @@ void TreeConfig::Set(const std::unordered_map<std::string, std::string>& params) ...@@ -381,16 +381,12 @@ void TreeConfig::Set(const std::unordered_map<std::string, std::string>& params)
GetInt(params, "gpu_device_id", &gpu_device_id); GetInt(params, "gpu_device_id", &gpu_device_id);
GetBool(params, "gpu_use_dp", &gpu_use_dp); GetBool(params, "gpu_use_dp", &gpu_use_dp);
GetInt(params, "max_cat_threshold", &max_cat_threshold); GetInt(params, "max_cat_threshold", &max_cat_threshold);
GetDouble(params, "cat_smooth_ratio", &cat_smooth_ratio);
GetDouble(params, "cat_l2", &cat_l2); GetDouble(params, "cat_l2", &cat_l2);
GetDouble(params, "min_cat_smooth", &min_cat_smooth); GetDouble(params, "cat_smooth", &cat_smooth);
GetDouble(params, "max_cat_smooth", &max_cat_smooth);
GetInt(params, "min_data_per_group", &min_data_per_group); GetInt(params, "min_data_per_group", &min_data_per_group);
CHECK(max_cat_threshold > 0); CHECK(max_cat_threshold > 0);
CHECK(cat_smooth_ratio >= 0);
CHECK(cat_l2 >= 0.0f); CHECK(cat_l2 >= 0.0f);
CHECK(min_cat_smooth >= 1); CHECK(cat_smooth >= 1);
CHECK(max_cat_smooth > min_cat_smooth);
CHECK(min_data_per_group > 0); CHECK(min_data_per_group > 0);
} }
......
...@@ -116,13 +116,9 @@ public: ...@@ -116,13 +116,9 @@ public:
if (is_full_categorical) ++used_bin; if (is_full_categorical) ++used_bin;
const double smooth_hess = std::max(meta_->tree_config->min_cat_smooth,
std::min(meta_->tree_config->cat_smooth_ratio * num_data, meta_->tree_config->max_cat_smooth));
const int min_data_per_cat = static_cast<int>(smooth_hess);
std::vector<int> sorted_idx; std::vector<int> sorted_idx;
for (int i = 0; i < used_bin; ++i) { for (int i = 0; i < used_bin; ++i) {
if (data_[i].cnt >= min_data_per_cat) { if (data_[i].cnt >= meta_->tree_config->cat_smooth) {
sorted_idx.push_back(i); sorted_idx.push_back(i);
} }
} }
...@@ -130,12 +126,12 @@ public: ...@@ -130,12 +126,12 @@ public:
const double l2 = meta_->tree_config->lambda_l2 + meta_->tree_config->cat_l2; const double l2 = meta_->tree_config->lambda_l2 + meta_->tree_config->cat_l2;
auto ctr_fun = [&smooth_hess](double sum_grad, double sum_hess) { auto ctr_fun = [this](double sum_grad, double sum_hess) {
return (sum_grad) / (sum_hess + smooth_hess); return (sum_grad) / (sum_hess + meta_->tree_config->cat_smooth);
}; };
std::sort(sorted_idx.begin(), sorted_idx.end(), std::sort(sorted_idx.begin(), sorted_idx.end(),
[this, &ctr_fun](int i, int j) { [this, &ctr_fun](int i, int j) {
return ctr_fun(data_[i].sum_gradients, data_[i].cnt) < ctr_fun(data_[j].sum_gradients, data_[j].cnt); return ctr_fun(data_[i].sum_gradients, data_[i].sum_hessians) < ctr_fun(data_[j].sum_gradients, data_[j].sum_hessians);
}); });
std::vector<int> find_direction(1, 1); std::vector<int> find_direction(1, 1);
......
...@@ -229,7 +229,7 @@ class TestEngine(unittest.TestCase): ...@@ -229,7 +229,7 @@ class TestEngine(unittest.TestCase):
'learning_rate': 1, 'learning_rate': 1,
'min_data_in_bin': 1, 'min_data_in_bin': 1,
'min_data_per_group': 1, 'min_data_per_group': 1,
'min_cat_smooth': 1, 'cat_smooth': 1,
'cat_l2': 0, 'cat_l2': 0,
'zero_as_missing': True, 'zero_as_missing': True,
'categorical_column': 0 'categorical_column': 0
...@@ -262,7 +262,7 @@ class TestEngine(unittest.TestCase): ...@@ -262,7 +262,7 @@ class TestEngine(unittest.TestCase):
'learning_rate': 1, 'learning_rate': 1,
'min_data_in_bin': 1, 'min_data_in_bin': 1,
'min_data_per_group': 1, 'min_data_per_group': 1,
'min_cat_smooth': 1, 'cat_smooth': 1,
'cat_l2': 0, 'cat_l2': 0,
'zero_as_missing': False, 'zero_as_missing': False,
'categorical_column': 0 'categorical_column': 0
...@@ -295,7 +295,7 @@ class TestEngine(unittest.TestCase): ...@@ -295,7 +295,7 @@ class TestEngine(unittest.TestCase):
'learning_rate': 1, 'learning_rate': 1,
'min_data_in_bin': 1, 'min_data_in_bin': 1,
'min_data_per_group': 1, 'min_data_per_group': 1,
'min_cat_smooth': 1, 'cat_smooth': 1,
'cat_l2': 0, 'cat_l2': 0,
'zero_as_missing': False, 'zero_as_missing': False,
'categorical_column': 0 'categorical_column': 0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment