"...git@developer.sourcefind.cn:tianlh/lightgbm-dcu.git" did not exist on "b6deb9a857edc71da5f5f17a295043b83606da04"
Unverified Commit 20f94c52 authored by Guolin Ke's avatar Guolin Ke Committed by GitHub
Browse files

fix the bug in bin with small values (#2342)

* fix the bug in bin with small values

* Update bin.cpp

* Update test_engine.py
parent 86c6a2d0
...@@ -181,6 +181,9 @@ namespace LightGBM { ...@@ -181,6 +181,9 @@ namespace LightGBM {
int left_max_bin = static_cast<int>(static_cast<double>(left_cnt_data) / (total_sample_cnt - cnt_zero) * (max_bin - 1)); int left_max_bin = static_cast<int>(static_cast<double>(left_cnt_data) / (total_sample_cnt - cnt_zero) * (max_bin - 1));
left_max_bin = std::max(1, left_max_bin); left_max_bin = std::max(1, left_max_bin);
bin_upper_bound = GreedyFindBin(distinct_values, counts, left_cnt, left_max_bin, left_cnt_data, min_data_in_bin); bin_upper_bound = GreedyFindBin(distinct_values, counts, left_cnt, left_max_bin, left_cnt_data, min_data_in_bin);
if (bin_upper_bound.size() > 0) {
bin_upper_bound.back() = -kZeroThreshold;
}
} }
int right_start = -1; int right_start = -1;
...@@ -191,32 +194,16 @@ namespace LightGBM { ...@@ -191,32 +194,16 @@ namespace LightGBM {
} }
} }
if (bin_upper_bound.size() == 0) { int right_max_bin = max_bin - 1 - static_cast<int>(bin_upper_bound.size());
if (max_bin > 2) { if (right_start >= 0 && right_max_bin > 0) {
// create zero bin
bin_upper_bound.push_back(-kZeroThreshold);
bin_upper_bound.push_back(kZeroThreshold);
}
else if (max_bin > 1) {
bin_upper_bound.push_back(kZeroThreshold);
}
} else {
bin_upper_bound.back() = -kZeroThreshold;
if (max_bin > 2) {
// create zero bin
bin_upper_bound.push_back(kZeroThreshold);
}
}
int right_max_bin = max_bin - static_cast<int>(bin_upper_bound.size());
if ((right_start >= 0) && (right_max_bin > 0)) {
auto right_bounds = GreedyFindBin(distinct_values + right_start, counts + right_start, auto right_bounds = GreedyFindBin(distinct_values + right_start, counts + right_start,
num_distinct_values - right_start, right_max_bin, right_cnt_data, min_data_in_bin); num_distinct_values - right_start, right_max_bin, right_cnt_data, min_data_in_bin);
bin_upper_bound.push_back(kZeroThreshold);
bin_upper_bound.insert(bin_upper_bound.end(), right_bounds.begin(), right_bounds.end()); bin_upper_bound.insert(bin_upper_bound.end(), right_bounds.begin(), right_bounds.end());
} else { } else {
bin_upper_bound.push_back(std::numeric_limits<double>::infinity()); bin_upper_bound.push_back(std::numeric_limits<double>::infinity());
} }
CHECK(bin_upper_bound.size() <= max_bin); CHECK(bin_upper_bound.size() <= static_cast<size_t>(max_bin));
return bin_upper_bound; return bin_upper_bound;
} }
......
...@@ -921,7 +921,7 @@ class TestEngine(unittest.TestCase): ...@@ -921,7 +921,7 @@ class TestEngine(unittest.TestCase):
} }
lgb_data = lgb.Dataset(X, label=y) lgb_data = lgb.Dataset(X, label=y)
est = lgb.train(params, lgb_data, num_boost_round=1) est = lgb.train(params, lgb_data, num_boost_round=1)
self.assertEqual(len(np.unique(est.predict(X))), 99) self.assertEqual(len(np.unique(est.predict(X))), 100)
params['max_bin_by_feature'] = [2, 100] params['max_bin_by_feature'] = [2, 100]
lgb_data = lgb.Dataset(X, label=y) lgb_data = lgb.Dataset(X, label=y)
est = lgb.train(params, lgb_data, num_boost_round=1) est = lgb.train(params, lgb_data, num_boost_round=1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment