"src/vscode:/vscode.git/clone" did not exist on "7ed1ed3eb65d99bbcfd3273a99f4ceb9bdfed890"
Commit d1e0bab5 authored by Guolin Ke
Browse files

fix some bugs in bin construction.

parent 0f4ea846
...@@ -67,8 +67,8 @@ void BinMapper::FindBin(std::vector<double>* values, size_t total_sample_cnt, in ...@@ -67,8 +67,8 @@ void BinMapper::FindBin(std::vector<double>* values, size_t total_sample_cnt, in
} }
int num_values = static_cast<int>(distinct_values.size()); int num_values = static_cast<int>(distinct_values.size());
int cnt_in_bin0 = 0; int cnt_in_bin0 = 0;
if (num_values <= max_bin) { if (num_values <= max_bin) {
std::sort(distinct_values.begin(), distinct_values.end());
// use distinct value is enough // use distinct value is enough
num_bin_ = num_values; num_bin_ = num_values;
bin_upper_bound_ = new double[num_values]; bin_upper_bound_ = new double[num_values];
...@@ -78,12 +78,11 @@ void BinMapper::FindBin(std::vector<double>* values, size_t total_sample_cnt, in ...@@ -78,12 +78,11 @@ void BinMapper::FindBin(std::vector<double>* values, size_t total_sample_cnt, in
cnt_in_bin0 = counts[0]; cnt_in_bin0 = counts[0];
bin_upper_bound_[num_values - 1] = std::numeric_limits<double>::infinity(); bin_upper_bound_[num_values - 1] = std::numeric_limits<double>::infinity();
} else { } else {
double min_lower_bound = std::numeric_limits<double>::infinity();
// mean size for one bin // mean size for one bin
double mean_bin_size = sample_size / static_cast<double>(max_bin); double mean_bin_size = sample_size / static_cast<double>(max_bin);
int rest_sample_cnt = static_cast<int>(sample_size); int rest_sample_cnt = static_cast<int>(sample_size);
int bin_cnt = 0; int bin_cnt = 0;
num_bin_ = max_bin;
std::vector<double> upper_bounds(max_bin, std::numeric_limits<double>::infinity()); std::vector<double> upper_bounds(max_bin, std::numeric_limits<double>::infinity());
std::vector<double> lower_bounds(max_bin, std::numeric_limits<double>::infinity()); std::vector<double> lower_bounds(max_bin, std::numeric_limits<double>::infinity());
// sort by count, descent // sort by count, descent
...@@ -92,6 +91,10 @@ void BinMapper::FindBin(std::vector<double>* values, size_t total_sample_cnt, in ...@@ -92,6 +91,10 @@ void BinMapper::FindBin(std::vector<double>* values, size_t total_sample_cnt, in
while (counts[bin_cnt] > mean_bin_size) { while (counts[bin_cnt] > mean_bin_size) {
upper_bounds[bin_cnt] = distinct_values[bin_cnt]; upper_bounds[bin_cnt] = distinct_values[bin_cnt];
lower_bounds[bin_cnt] = distinct_values[bin_cnt]; lower_bounds[bin_cnt] = distinct_values[bin_cnt];
if (lower_bounds[bin_cnt] < min_lower_bound) {
min_lower_bound = lower_bounds[bin_cnt];
cnt_in_bin0 = counts[bin_cnt];
}
rest_sample_cnt -= counts[bin_cnt]; rest_sample_cnt -= counts[bin_cnt];
++bin_cnt; ++bin_cnt;
} }
...@@ -108,7 +111,10 @@ void BinMapper::FindBin(std::vector<double>* values, size_t total_sample_cnt, in ...@@ -108,7 +111,10 @@ void BinMapper::FindBin(std::vector<double>* values, size_t total_sample_cnt, in
// need a new bin // need a new bin
if (cur_cnt_inbin >= mean_bin_size) { if (cur_cnt_inbin >= mean_bin_size) {
upper_bounds[bin_cnt] = distinct_values[i]; upper_bounds[bin_cnt] = distinct_values[i];
if (bin_cnt == 0) { cnt_in_bin0 = cur_cnt_inbin; } if (lower_bounds[bin_cnt] < min_lower_bound) {
min_lower_bound = lower_bounds[bin_cnt];
cnt_in_bin0 = cur_cnt_inbin;
}
++bin_cnt; ++bin_cnt;
lower_bounds[bin_cnt] = distinct_values[i + 1]; lower_bounds[bin_cnt] = distinct_values[i + 1];
if (bin_cnt >= max_bin - 1) break; if (bin_cnt >= max_bin - 1) break;
...@@ -117,7 +123,6 @@ void BinMapper::FindBin(std::vector<double>* values, size_t total_sample_cnt, in ...@@ -117,7 +123,6 @@ void BinMapper::FindBin(std::vector<double>* values, size_t total_sample_cnt, in
} }
} }
cur_cnt_inbin += counts[num_values - 1]; cur_cnt_inbin += counts[num_values - 1];
} }
Common::SortForPair<double, double>(lower_bounds, upper_bounds, 0, false); Common::SortForPair<double, double>(lower_bounds, upper_bounds, 0, false);
// update bin upper bound // update bin upper bound
......
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
namespace LightGBM { namespace LightGBM {
DatasetLoader::DatasetLoader(const IOConfig& io_config, const PredictFunction& predict_fun) DatasetLoader::DatasetLoader(const IOConfig& io_config, const PredictFunction& predict_fun)
:io_config_(io_config), predict_fun_(predict_fun){ :io_config_(io_config), predict_fun_(predict_fun), random_(io_config_.data_random_seed){
} }
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment