Commit 846c42ef authored by Guolin Ke's avatar Guolin Ke
Browse files

a better algorithm for finding bins

parent ed958eb2
...@@ -85,7 +85,7 @@ enum TaskType { ...@@ -85,7 +85,7 @@ enum TaskType {
/*! \brief Config for input and output files */ /*! \brief Config for input and output files */
struct IOConfig: public ConfigBase { struct IOConfig: public ConfigBase {
public: public:
int max_bin = 255; int max_bin = 256;
int data_random_seed = 1; int data_random_seed = 1;
std::string data_filename = ""; std::string data_filename = "";
std::vector<std::string> valid_data_filenames; std::vector<std::string> valid_data_filenames;
......
...@@ -88,8 +88,7 @@ inline static const char* Atoi(const char* p, int* out) { ...@@ -88,8 +88,7 @@ inline static const char* Atoi(const char* p, int* out) {
if (*p == '-') { if (*p == '-') {
sign = -1; sign = -1;
++p; ++p;
} } else if (*p == '+') {
else if (*p == '+') {
++p; ++p;
} }
for (value = 0; *p >= '0' && *p <= '9'; ++p) { for (value = 0; *p >= '0' && *p <= '9'; ++p) {
...@@ -117,8 +116,7 @@ inline static const char* Atof(const char* p, float* out) { ...@@ -117,8 +116,7 @@ inline static const char* Atof(const char* p, float* out) {
if (*p == '-') { if (*p == '-') {
sign = -1.0f; sign = -1.0f;
++p; ++p;
} } else if (*p == '+') {
else if (*p == '+') {
++p; ++p;
} }
...@@ -165,21 +163,20 @@ inline static const char* Atof(const char* p, float* out) { ...@@ -165,21 +163,20 @@ inline static const char* Atof(const char* p, float* out) {
*out = sign * (frac ? (value / scale) : (value * scale)); *out = sign * (frac ? (value / scale) : (value * scale));
} else { } else {
size_t cnt = 0; size_t cnt = 0;
while (*(p + cnt) != '\0' && *(p + cnt) != ' ' while (*(p + cnt) != '\0' && *(p + cnt) != ' '
&& *(p + cnt) != '\t' && *(p + cnt) != ',' && *(p + cnt) != '\t' && *(p + cnt) != ','
&& *(p + cnt) != '\n' && *(p + cnt) != '\r' && *(p + cnt) != '\n' && *(p + cnt) != '\r'
&& *(p + cnt) != ':') { && *(p + cnt) != ':') {
++cnt; ++cnt;
} }
if(cnt > 0){ if (cnt > 0) {
std::string tmp_str(p, cnt); std::string tmp_str(p, cnt);
std::transform(tmp_str.begin(), tmp_str.end(), tmp_str.begin(), ::tolower); std::transform(tmp_str.begin(), tmp_str.end(), tmp_str.begin(), ::tolower);
if (tmp_str == std::string("na") || tmp_str == std::string("nan")) { if (tmp_str == std::string("na") || tmp_str == std::string("nan")) {
*out = 0.0f; *out = 0.0f;
} else if( tmp_str == std::string("inf") || tmp_str == std::string("infinity")) { } else if (tmp_str == std::string("inf") || tmp_str == std::string("infinity")) {
*out = sign * static_cast<float>(1e38); *out = sign * static_cast<float>(1e38);
} } else {
else {
Log::Fatal("Unknow token %s in data file", tmp_str.c_str()); Log::Fatal("Unknow token %s in data file", tmp_str.c_str());
} }
p += cnt; p += cnt;
...@@ -356,6 +353,28 @@ inline void Softmax(std::vector<float>* p_rec) { ...@@ -356,6 +353,28 @@ inline void Softmax(std::vector<float>* p_rec) {
} }
} }
template<typename T1, typename T2>
inline void SortForPair(std::vector<T1>& keys, std::vector<T2>& values, size_t start, bool is_reverse = false) {
std::vector<std::pair<T1, T2>> arr;
for (size_t i = start; i < keys.size(); ++i) {
arr.emplace_back(keys[i], values[i]);
}
if (!is_reverse) {
std::sort(arr.begin(), arr.end(), [](const std::pair<T1, T2>& a, const std::pair<T1, T2>& b) {
return a.first < b.first;
});
} else {
std::sort(arr.begin(), arr.end(), [](const std::pair<T1, T2>& a, const std::pair<T1, T2>& b) {
return a.first > b.first;
});
}
for (size_t i = start; i < arr.size(); ++i) {
keys[i] = arr[i].first;
values[i] = arr[i].second;
}
}
} // namespace Common } // namespace Common
} // namespace LightGBM } // namespace LightGBM
......
#include <LightGBM/utils/common.h>
#include <LightGBM/bin.h> #include <LightGBM/bin.h>
#include "dense_bin.hpp" #include "dense_bin.hpp"
...@@ -39,23 +40,24 @@ BinMapper::~BinMapper() { ...@@ -39,23 +40,24 @@ BinMapper::~BinMapper() {
} }
void BinMapper::FindBin(std::vector<float>* values, int max_bin) { void BinMapper::FindBin(std::vector<float>* values, int max_bin) {
std::vector<float>& ref_values = (*values);
size_t sample_size = values->size(); size_t sample_size = values->size();
// find distinct_values first // find distinct_values first
float* distinct_values = new float[sample_size]; std::vector<float> distinct_values;
int *counts = new int[sample_size]; std::vector<int> counts;
int num_values = 1;
std::sort(values->begin(), values->end()); std::sort(ref_values.begin(), ref_values.end());
distinct_values[0] = (*values)[0]; distinct_values.push_back(ref_values[0]);
counts[0] = 1; counts.push_back(1);
for (size_t i = 1; i < values->size(); ++i) { for (size_t i = 1; i < ref_values.size(); ++i) {
if ((*values)[i] != (*values)[i - 1]) { if (ref_values[i] != ref_values[i - 1]) {
distinct_values[num_values] = (*values)[i]; distinct_values.push_back(ref_values[i]);
counts[num_values] = 1; counts.push_back(1);
++num_values;
} else { } else {
++counts[num_values - 1]; ++counts.back();
} }
} }
int num_values = static_cast<int>(distinct_values.size());
int cnt_in_bin0 = 0; int cnt_in_bin0 = 0;
if (num_values <= max_bin) { if (num_values <= max_bin) {
...@@ -68,54 +70,60 @@ void BinMapper::FindBin(std::vector<float>* values, int max_bin) { ...@@ -68,54 +70,60 @@ void BinMapper::FindBin(std::vector<float>* values, int max_bin) {
cnt_in_bin0 = counts[0]; cnt_in_bin0 = counts[0];
bin_upper_bound_[num_values - 1] = std::numeric_limits<float>::infinity(); bin_upper_bound_[num_values - 1] = std::numeric_limits<float>::infinity();
} else { } else {
// need find bins
num_bin_ = max_bin;
bin_upper_bound_ = new float[max_bin];
float * bin_lower_bound = new float[max_bin];
// mean size for one bin // mean size for one bin
float mean_bin_size = sample_size / static_cast<float>(max_bin); float mean_bin_size = sample_size / static_cast<float>(max_bin);
int rest_sample_cnt = static_cast<int>(sample_size); int rest_sample_cnt = static_cast<int>(sample_size);
int cur_cnt_inbin = 0;
int bin_cnt = 0; int bin_cnt = 0;
bin_lower_bound[0] = distinct_values[0];
for (int i = 0; i < num_values - 1; ++i) { num_bin_ = max_bin;
rest_sample_cnt -= counts[i]; std::vector<float> upper_bounds(max_bin, std::numeric_limits<float>::infinity());
cur_cnt_inbin += counts[i]; std::vector<float> lower_bounds(max_bin, std::numeric_limits<float>::infinity());
// need a new bin // sort by count, descent
if (cur_cnt_inbin >= mean_bin_size) { Common::SortForPair(counts, distinct_values, 0, true);
bin_upper_bound_[bin_cnt] = distinct_values[i]; // fetch big slot as unique bin
if (bin_cnt == 0) { cnt_in_bin0 = cur_cnt_inbin; } while (counts[bin_cnt] > mean_bin_size) {
++bin_cnt; upper_bounds[bin_cnt] = distinct_values[bin_cnt];
bin_lower_bound[bin_cnt] = distinct_values[i + 1]; lower_bounds[bin_cnt] = distinct_values[bin_cnt];
cur_cnt_inbin = 0; rest_sample_cnt -= counts[bin_cnt];
mean_bin_size = rest_sample_cnt / static_cast<float>(max_bin - bin_cnt); ++bin_cnt;
}
// process reminder bins
if (bin_cnt < max_bin) {
// sort rest by values
Common::SortForPair<float, int>(distinct_values, counts, bin_cnt, false);
mean_bin_size = rest_sample_cnt / static_cast<float>(max_bin - bin_cnt);
lower_bounds[bin_cnt] = distinct_values[bin_cnt];
int cur_cnt_inbin = 0;
for (int i = bin_cnt; i < num_values - 1; ++i) {
rest_sample_cnt -= counts[i];
cur_cnt_inbin += counts[i];
// need a new bin
if (cur_cnt_inbin >= mean_bin_size) {
upper_bounds[bin_cnt] = distinct_values[i];
if (bin_cnt == 0) { cnt_in_bin0 = cur_cnt_inbin; }
++bin_cnt;
lower_bounds[bin_cnt] = distinct_values[i + 1];
if (bin_cnt >= max_bin - 1) break;
cur_cnt_inbin = 0;
mean_bin_size = rest_sample_cnt / static_cast<float>(max_bin - bin_cnt);
}
} }
cur_cnt_inbin += counts[num_values - 1];
} }
cur_cnt_inbin += counts[num_values - 1]; Common::SortForPair<float, float>(lower_bounds, upper_bounds, 0, false);
// update bin upper bound // update bin upper bound
for (int i = 0; i < bin_cnt; ++i) { bin_upper_bound_ = new float[bin_cnt];
bin_upper_bound_[i] = (bin_upper_bound_[i] + bin_lower_bound[i + 1]) / 2.0f; for (int i = 0; i < bin_cnt - 1; ++i) {
bin_upper_bound_[i] = (upper_bounds[i] + lower_bounds[i + 1]) / 2.0f;
} }
// last bin upper bound // last bin upper bound
bin_upper_bound_[bin_cnt] = std::numeric_limits<float>::infinity(); bin_upper_bound_[bin_cnt - 1] = std::numeric_limits<float>::infinity();
++bin_cnt;
delete[] bin_lower_bound; CHECK(bin_cnt <= max_bin);
// if no so much bin
if (bin_cnt < max_bin) {
// old bin data
float* tmp_bin_upper_bound = bin_upper_bound_;
num_bin_ = bin_cnt;
bin_upper_bound_ = new float[num_bin_];
// copy back
for (int i = 0; i < num_bin_; ++i) {
bin_upper_bound_[i] = tmp_bin_upper_bound[i];
}
// free old space
delete[] tmp_bin_upper_bound;
}
} }
delete[] distinct_values;
delete[] counts;
// check trival(num_bin_ == 1) feature // check trival(num_bin_ == 1) feature
if (num_bin_ <= 1) { if (num_bin_ <= 1) {
is_trival_ = true; is_trival_ = true;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment