Commit c4c83bc7 authored by Guolin Ke's avatar Guolin Ke
Browse files

fix bug in finding best split for categorical feature.

parent ef778069
...@@ -131,12 +131,12 @@ public: ...@@ -131,12 +131,12 @@ public:
output->left_output = CalculateSplittedLeafOutput(best_sum_left_gradient, best_sum_left_hessian); output->left_output = CalculateSplittedLeafOutput(best_sum_left_gradient, best_sum_left_hessian);
output->left_count = best_left_count; output->left_count = best_left_count;
output->left_sum_gradient = best_sum_left_gradient; output->left_sum_gradient = best_sum_left_gradient;
output->left_sum_hessian = best_sum_left_hessian; output->left_sum_hessian = best_sum_left_hessian - kEpsilon;
output->right_output = CalculateSplittedLeafOutput(sum_gradient - best_sum_left_gradient, output->right_output = CalculateSplittedLeafOutput(sum_gradient - best_sum_left_gradient,
sum_hessian - best_sum_left_hessian); sum_hessian - best_sum_left_hessian);
output->right_count = num_data - best_left_count; output->right_count = num_data - best_left_count;
output->right_sum_gradient = sum_gradient - best_sum_left_gradient; output->right_sum_gradient = sum_gradient - best_sum_left_gradient;
output->right_sum_hessian = sum_hessian - best_sum_left_hessian; output->right_sum_hessian = sum_hessian - best_sum_left_hessian - kEpsilon;
output->gain = best_gain - gain_shift; output->gain = best_gain - gain_shift;
} else { } else {
output->feature = meta_->feature_idx; output->feature = meta_->feature_idx;
...@@ -148,6 +148,9 @@ public: ...@@ -148,6 +148,9 @@ public:
SplitInfo* output) { SplitInfo* output) {
double best_gain = kMinScore; double best_gain = kMinScore;
uint32_t best_threshold = static_cast<uint32_t>(meta_->num_bin); uint32_t best_threshold = static_cast<uint32_t>(meta_->num_bin);
data_size_t best_left_count = 0;
double best_sum_left_gradient = 0.0f;
double best_sum_left_hessian = 0.0f;
double gain_shift = GetLeafSplitGain(sum_gradient, sum_hessian); double gain_shift = GetLeafSplitGain(sum_gradient, sum_hessian);
double min_gain_shift = gain_shift + meta_->tree_config->min_gain_to_split; double min_gain_shift = gain_shift + meta_->tree_config->min_gain_to_split;
is_splittable_ = false; is_splittable_ = false;
...@@ -179,6 +182,9 @@ public: ...@@ -179,6 +182,9 @@ public:
// better split point // better split point
if (current_gain > best_gain) { if (current_gain > best_gain) {
best_threshold = static_cast<uint32_t>(t + bias); best_threshold = static_cast<uint32_t>(t + bias);
best_sum_left_gradient = data_[t].sum_gradients;
best_sum_left_hessian = data_[t].sum_hessians + kEpsilon;
best_left_count = data_[t].cnt;
best_gain = current_gain; best_gain = current_gain;
} }
} }
...@@ -186,7 +192,7 @@ public: ...@@ -186,7 +192,7 @@ public:
if (bias == 1) { if (bias == 1) {
t = meta_->num_bin - 1 - bias; t = meta_->num_bin - 1 - bias;
double sum_bin0_gradient = sum_gradient; double sum_bin0_gradient = sum_gradient;
double sum_bin0_hessian = sum_hessian; double sum_bin0_hessian = sum_hessian - 2 * kEpsilon;
data_size_t cnt_bin0 = num_data; data_size_t cnt_bin0 = num_data;
for (; t >= 0; --t) { for (; t >= 0; --t) {
sum_bin0_gradient -= data_[t].sum_gradients; sum_bin0_gradient -= data_[t].sum_gradients;
...@@ -207,6 +213,9 @@ public: ...@@ -207,6 +213,9 @@ public:
// better split point // better split point
if (current_gain > best_gain) { if (current_gain > best_gain) {
best_threshold = static_cast<uint32_t>(0); best_threshold = static_cast<uint32_t>(0);
best_sum_left_gradient = sum_bin0_gradient;
best_sum_left_hessian = sum_bin0_hessian + kEpsilon;
best_left_count = cnt_bin0;
best_gain = current_gain; best_gain = current_gain;
} }
} }
...@@ -216,17 +225,15 @@ public: ...@@ -216,17 +225,15 @@ public:
// update split information // update split information
output->feature = meta_->feature_idx; output->feature = meta_->feature_idx;
output->threshold = best_threshold; output->threshold = best_threshold;
output->left_output = CalculateSplittedLeafOutput(data_[best_threshold].sum_gradients, output->left_output = CalculateSplittedLeafOutput(best_sum_left_gradient, best_sum_left_hessian);
data_[best_threshold].sum_hessians + kEpsilon); output->left_count = best_left_count;
output->left_count = data_[best_threshold].cnt; output->left_sum_gradient = best_sum_left_gradient;
output->left_sum_gradient = data_[best_threshold].sum_gradients; output->left_sum_hessian = best_sum_left_hessian - kEpsilon;
output->left_sum_hessian = data_[best_threshold].sum_hessians + kEpsilon; output->right_output = CalculateSplittedLeafOutput(sum_gradient - best_sum_left_gradient,
sum_hessian - best_sum_left_hessian);
output->right_output = CalculateSplittedLeafOutput(sum_gradient - data_[best_threshold].sum_gradients, output->right_count = num_data - best_left_count;
sum_hessian - data_[best_threshold].sum_hessians - kEpsilon); output->right_sum_gradient = sum_gradient - best_sum_left_gradient;
output->right_count = num_data - data_[best_threshold].cnt; output->right_sum_hessian = sum_hessian - best_sum_left_hessian - kEpsilon;
output->right_sum_gradient = sum_gradient - data_[best_threshold].sum_gradients;
output->right_sum_hessian = sum_hessian - data_[best_threshold].sum_hessians - kEpsilon;
output->gain = best_gain - gain_shift; output->gain = best_gain - gain_shift;
} else { } else {
output->feature = meta_->feature_idx; output->feature = meta_->feature_idx;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment