Commit b0b0010a authored by Guolin Ke's avatar Guolin Ke
Browse files

merge from master

parents c04830a8 405f45a0
......@@ -116,21 +116,21 @@ public:
void GetPredictAt(int data_idx, score_t* out_result, data_size_t* out_len) override;
/*!
* \brief Predtion for one record without sigmoid transformation
* \brief Prediction for one record without sigmoid transformation
* \param feature_values Feature value on this record
* \return Prediction result for this record
*/
std::vector<double> PredictRaw(const double* feature_values) const override;
/*!
* \brief Predtion for one record with sigmoid transformation if enabled
* \brief Prediction for one record with sigmoid transformation if enabled
* \param feature_values Feature value on this record
* \return Prediction result for this record
*/
std::vector<double> Predict(const double* feature_values) const override;
/*!
* \brief Predtion for one record with leaf index
* \brief Prediction for one record with leaf index
* \param feature_values Feature value on this record
* \return Predicted leaf index for this record
*/
......@@ -261,7 +261,7 @@ protected:
std::vector<data_size_t> bag_data_indices_;
/*! \brief Number of in-bag data */
data_size_t bag_data_cnt_;
/*! \brief Number of traning data */
/*! \brief Number of training data */
data_size_t num_data_;
/*! \brief Number of classes */
int num_class_;
......@@ -269,7 +269,7 @@ protected:
Random random_;
/*!
* \brief Sigmoid parameter, used for prediction.
* if > 0 meas output score will transform by sigmoid function
* if > 0 means output score will transform by sigmoid function
*/
double sigmoid_;
/*! \brief Index of label column */
......
......@@ -94,20 +94,32 @@ void BinMapper::FindBin(std::vector<double>* values, size_t total_sample_cnt, in
} else {
// mean size for one bin
double mean_bin_size = sample_size / static_cast<double>(max_bin);
double static_mean_bin_size = mean_bin_size;
int rest_bin_cnt = max_bin;
int rest_sample_cnt = static_cast<int>(sample_size);
std::vector<bool> is_big_count_value(num_values, false);
for (int i = 0; i < num_values; ++i) {
if (counts[i] >= mean_bin_size) {
is_big_count_value[i] = true;
--rest_bin_cnt;
rest_sample_cnt -= counts[i];
}
}
mean_bin_size = rest_sample_cnt / static_cast<double>(rest_bin_cnt);
std::vector<double> upper_bounds(max_bin, std::numeric_limits<double>::infinity());
std::vector<double> lower_bounds(max_bin, std::numeric_limits<double>::infinity());
int rest_sample_cnt = static_cast<int>(sample_size);
int bin_cnt = 0;
lower_bounds[bin_cnt] = distinct_values[0];
int cur_cnt_inbin = 0;
for (int i = 0; i < num_values - 1; ++i) {
if (!is_big_count_value[i]) {
rest_sample_cnt -= counts[i];
}
cur_cnt_inbin += counts[i];
// need a new bin
if (counts[i] >= static_mean_bin_size || cur_cnt_inbin >= mean_bin_size ||
(counts[i + 1] >= static_mean_bin_size && cur_cnt_inbin >= std::max(1.0, mean_bin_size * 0.5f))) {
if (is_big_count_value[i] || cur_cnt_inbin >= mean_bin_size ||
(is_big_count_value[i + 1] && cur_cnt_inbin >= std::max(1.0, mean_bin_size * 0.5f))) {
upper_bounds[bin_cnt] = distinct_values[i];
if (bin_cnt == 0) {
cnt_in_bin0 = cur_cnt_inbin;
......@@ -116,7 +128,10 @@ void BinMapper::FindBin(std::vector<double>* values, size_t total_sample_cnt, in
lower_bounds[bin_cnt] = distinct_values[i + 1];
if (bin_cnt >= max_bin - 1) { break; }
cur_cnt_inbin = 0;
mean_bin_size = rest_sample_cnt / static_cast<double>(max_bin - bin_cnt);
if (!is_big_count_value[i]) {
--rest_bin_cnt;
mean_bin_size = rest_sample_cnt / static_cast<double>(rest_bin_cnt);
}
}
}
//
......
......@@ -279,7 +279,7 @@ inline VAL_T SparseBinIterator<VAL_T>::InnerGet(data_size_t idx) {
while (cur_pos_ < idx && i_delta_ < bin_data_->num_vals_) {
bin_data_->NextNonzero(&i_delta_, &cur_pos_);
}
if (cur_pos_ == idx && i_delta_ < bin_data_->num_vals_) {
if (cur_pos_ == idx && i_delta_ < bin_data_->num_vals_ && i_delta_ >= 0) {
return bin_data_->vals_[i_delta_];
} else {
return 0;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment