"tests/vscode:/vscode.git/clone" did not exist on "d92d844447292e1b2ce8d57e5747b4e0fa233c09"
Commit 846c42ef authored by Guolin Ke's avatar Guolin Ke
Browse files

a better algorithm for finding bins

parent ed958eb2
......@@ -85,7 +85,7 @@ enum TaskType {
/*! \brief Config for input and output files */
struct IOConfig: public ConfigBase {
public:
int max_bin = 255;
int max_bin = 256;
int data_random_seed = 1;
std::string data_filename = "";
std::vector<std::string> valid_data_filenames;
......
......@@ -88,8 +88,7 @@ inline static const char* Atoi(const char* p, int* out) {
if (*p == '-') {
sign = -1;
++p;
}
else if (*p == '+') {
} else if (*p == '+') {
++p;
}
for (value = 0; *p >= '0' && *p <= '9'; ++p) {
......@@ -117,8 +116,7 @@ inline static const char* Atof(const char* p, float* out) {
if (*p == '-') {
sign = -1.0f;
++p;
}
else if (*p == '+') {
} else if (*p == '+') {
++p;
}
......@@ -171,15 +169,14 @@ inline static const char* Atof(const char* p, float* out) {
&& *(p + cnt) != ':') {
++cnt;
}
if(cnt > 0){
if (cnt > 0) {
std::string tmp_str(p, cnt);
std::transform(tmp_str.begin(), tmp_str.end(), tmp_str.begin(), ::tolower);
if (tmp_str == std::string("na") || tmp_str == std::string("nan")) {
*out = 0.0f;
} else if( tmp_str == std::string("inf") || tmp_str == std::string("infinity")) {
} else if (tmp_str == std::string("inf") || tmp_str == std::string("infinity")) {
*out = sign * static_cast<float>(1e38);
}
else {
} else {
Log::Fatal("Unknow token %s in data file", tmp_str.c_str());
}
p += cnt;
......@@ -356,6 +353,28 @@ inline void Softmax(std::vector<float>* p_rec) {
}
}
template<typename T1, typename T2>
inline void SortForPair(std::vector<T1>& keys, std::vector<T2>& values, size_t start, bool is_reverse = false) {
std::vector<std::pair<T1, T2>> arr;
for (size_t i = start; i < keys.size(); ++i) {
arr.emplace_back(keys[i], values[i]);
}
if (!is_reverse) {
std::sort(arr.begin(), arr.end(), [](const std::pair<T1, T2>& a, const std::pair<T1, T2>& b) {
return a.first < b.first;
});
} else {
std::sort(arr.begin(), arr.end(), [](const std::pair<T1, T2>& a, const std::pair<T1, T2>& b) {
return a.first > b.first;
});
}
for (size_t i = start; i < arr.size(); ++i) {
keys[i] = arr[i].first;
values[i] = arr[i].second;
}
}
} // namespace Common
} // namespace LightGBM
......
#include <LightGBM/utils/common.h>
#include <LightGBM/bin.h>
#include "dense_bin.hpp"
......@@ -39,23 +40,24 @@ BinMapper::~BinMapper() {
}
void BinMapper::FindBin(std::vector<float>* values, int max_bin) {
std::vector<float>& ref_values = (*values);
size_t sample_size = values->size();
// find distinct_values first
float* distinct_values = new float[sample_size];
int *counts = new int[sample_size];
int num_values = 1;
std::sort(values->begin(), values->end());
distinct_values[0] = (*values)[0];
counts[0] = 1;
for (size_t i = 1; i < values->size(); ++i) {
if ((*values)[i] != (*values)[i - 1]) {
distinct_values[num_values] = (*values)[i];
counts[num_values] = 1;
++num_values;
std::vector<float> distinct_values;
std::vector<int> counts;
std::sort(ref_values.begin(), ref_values.end());
distinct_values.push_back(ref_values[0]);
counts.push_back(1);
for (size_t i = 1; i < ref_values.size(); ++i) {
if (ref_values[i] != ref_values[i - 1]) {
distinct_values.push_back(ref_values[i]);
counts.push_back(1);
} else {
++counts[num_values - 1];
++counts.back();
}
}
int num_values = static_cast<int>(distinct_values.size());
int cnt_in_bin0 = 0;
if (num_values <= max_bin) {
......@@ -68,54 +70,60 @@ void BinMapper::FindBin(std::vector<float>* values, int max_bin) {
cnt_in_bin0 = counts[0];
bin_upper_bound_[num_values - 1] = std::numeric_limits<float>::infinity();
} else {
// need find bins
num_bin_ = max_bin;
bin_upper_bound_ = new float[max_bin];
float * bin_lower_bound = new float[max_bin];
// mean size for one bin
float mean_bin_size = sample_size / static_cast<float>(max_bin);
int rest_sample_cnt = static_cast<int>(sample_size);
int cur_cnt_inbin = 0;
int bin_cnt = 0;
bin_lower_bound[0] = distinct_values[0];
for (int i = 0; i < num_values - 1; ++i) {
num_bin_ = max_bin;
std::vector<float> upper_bounds(max_bin, std::numeric_limits<float>::infinity());
std::vector<float> lower_bounds(max_bin, std::numeric_limits<float>::infinity());
// sort by count, descent
Common::SortForPair(counts, distinct_values, 0, true);
// fetch big slot as unique bin
while (counts[bin_cnt] > mean_bin_size) {
upper_bounds[bin_cnt] = distinct_values[bin_cnt];
lower_bounds[bin_cnt] = distinct_values[bin_cnt];
rest_sample_cnt -= counts[bin_cnt];
++bin_cnt;
}
// process reminder bins
if (bin_cnt < max_bin) {
// sort rest by values
Common::SortForPair<float, int>(distinct_values, counts, bin_cnt, false);
mean_bin_size = rest_sample_cnt / static_cast<float>(max_bin - bin_cnt);
lower_bounds[bin_cnt] = distinct_values[bin_cnt];
int cur_cnt_inbin = 0;
for (int i = bin_cnt; i < num_values - 1; ++i) {
rest_sample_cnt -= counts[i];
cur_cnt_inbin += counts[i];
// need a new bin
if (cur_cnt_inbin >= mean_bin_size) {
bin_upper_bound_[bin_cnt] = distinct_values[i];
upper_bounds[bin_cnt] = distinct_values[i];
if (bin_cnt == 0) { cnt_in_bin0 = cur_cnt_inbin; }
++bin_cnt;
bin_lower_bound[bin_cnt] = distinct_values[i + 1];
lower_bounds[bin_cnt] = distinct_values[i + 1];
if (bin_cnt >= max_bin - 1) break;
cur_cnt_inbin = 0;
mean_bin_size = rest_sample_cnt / static_cast<float>(max_bin - bin_cnt);
}
}
cur_cnt_inbin += counts[num_values - 1];
}
Common::SortForPair<float, float>(lower_bounds, upper_bounds, 0, false);
// update bin upper bound
for (int i = 0; i < bin_cnt; ++i) {
bin_upper_bound_[i] = (bin_upper_bound_[i] + bin_lower_bound[i + 1]) / 2.0f;
bin_upper_bound_ = new float[bin_cnt];
for (int i = 0; i < bin_cnt - 1; ++i) {
bin_upper_bound_[i] = (upper_bounds[i] + lower_bounds[i + 1]) / 2.0f;
}
// last bin upper bound
bin_upper_bound_[bin_cnt] = std::numeric_limits<float>::infinity();
++bin_cnt;
delete[] bin_lower_bound;
// if no so much bin
if (bin_cnt < max_bin) {
// old bin data
float* tmp_bin_upper_bound = bin_upper_bound_;
num_bin_ = bin_cnt;
bin_upper_bound_ = new float[num_bin_];
// copy back
for (int i = 0; i < num_bin_; ++i) {
bin_upper_bound_[i] = tmp_bin_upper_bound[i];
}
// free old space
delete[] tmp_bin_upper_bound;
}
bin_upper_bound_[bin_cnt - 1] = std::numeric_limits<float>::infinity();
CHECK(bin_cnt <= max_bin);
}
delete[] distinct_values;
delete[] counts;
// check trival(num_bin_ == 1) feature
if (num_bin_ <= 1) {
is_trival_ = true;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment