Commit b0b0010a authored by Guolin Ke
Browse files

merge from master

parents c04830a8 405f45a0
...@@ -116,21 +116,21 @@ public: ...@@ -116,21 +116,21 @@ public:
void GetPredictAt(int data_idx, score_t* out_result, data_size_t* out_len) override; void GetPredictAt(int data_idx, score_t* out_result, data_size_t* out_len) override;
/*! /*!
* \brief Predtion for one record without sigmoid transformation * \brief Prediction for one record without sigmoid transformation
* \param feature_values Feature value on this record * \param feature_values Feature value on this record
* \return Prediction result for this record * \return Prediction result for this record
*/ */
std::vector<double> PredictRaw(const double* feature_values) const override; std::vector<double> PredictRaw(const double* feature_values) const override;
/*! /*!
* \brief Predtion for one record with sigmoid transformation if enabled * \brief Prediction for one record with sigmoid transformation if enabled
* \param feature_values Feature value on this record * \param feature_values Feature value on this record
* \return Prediction result for this record * \return Prediction result for this record
*/ */
std::vector<double> Predict(const double* feature_values) const override; std::vector<double> Predict(const double* feature_values) const override;
/*! /*!
* \brief Predtion for one record with leaf index * \brief Prediction for one record with leaf index
* \param feature_values Feature value on this record * \param feature_values Feature value on this record
* \return Predicted leaf index for this record * \return Predicted leaf index for this record
*/ */
...@@ -261,7 +261,7 @@ protected: ...@@ -261,7 +261,7 @@ protected:
std::vector<data_size_t> bag_data_indices_; std::vector<data_size_t> bag_data_indices_;
/*! \brief Number of in-bag data */ /*! \brief Number of in-bag data */
data_size_t bag_data_cnt_; data_size_t bag_data_cnt_;
/*! \brief Number of traning data */ /*! \brief Number of training data */
data_size_t num_data_; data_size_t num_data_;
/*! \brief Number of classes */ /*! \brief Number of classes */
int num_class_; int num_class_;
...@@ -269,7 +269,7 @@ protected: ...@@ -269,7 +269,7 @@ protected:
Random random_; Random random_;
/*! /*!
* \brief Sigmoid parameter, used for prediction. * \brief Sigmoid parameter, used for prediction.
* if > 0 meas output score will transform by sigmoid function * if > 0 means output score will transform by sigmoid function
*/ */
double sigmoid_; double sigmoid_;
/*! \brief Index of label column */ /*! \brief Index of label column */
......
...@@ -94,20 +94,32 @@ void BinMapper::FindBin(std::vector<double>* values, size_t total_sample_cnt, in ...@@ -94,20 +94,32 @@ void BinMapper::FindBin(std::vector<double>* values, size_t total_sample_cnt, in
} else { } else {
// mean size for one bin // mean size for one bin
double mean_bin_size = sample_size / static_cast<double>(max_bin); double mean_bin_size = sample_size / static_cast<double>(max_bin);
double static_mean_bin_size = mean_bin_size; int rest_bin_cnt = max_bin;
int rest_sample_cnt = static_cast<int>(sample_size);
std::vector<bool> is_big_count_value(num_values, false);
for (int i = 0; i < num_values; ++i) {
if (counts[i] >= mean_bin_size) {
is_big_count_value[i] = true;
--rest_bin_cnt;
rest_sample_cnt -= counts[i];
}
}
mean_bin_size = rest_sample_cnt / static_cast<double>(rest_bin_cnt);
std::vector<double> upper_bounds(max_bin, std::numeric_limits<double>::infinity()); std::vector<double> upper_bounds(max_bin, std::numeric_limits<double>::infinity());
std::vector<double> lower_bounds(max_bin, std::numeric_limits<double>::infinity()); std::vector<double> lower_bounds(max_bin, std::numeric_limits<double>::infinity());
int rest_sample_cnt = static_cast<int>(sample_size);
int bin_cnt = 0; int bin_cnt = 0;
lower_bounds[bin_cnt] = distinct_values[0]; lower_bounds[bin_cnt] = distinct_values[0];
int cur_cnt_inbin = 0; int cur_cnt_inbin = 0;
for (int i = 0; i < num_values - 1; ++i) { for (int i = 0; i < num_values - 1; ++i) {
if (!is_big_count_value[i]) {
rest_sample_cnt -= counts[i]; rest_sample_cnt -= counts[i];
}
cur_cnt_inbin += counts[i]; cur_cnt_inbin += counts[i];
// need a new bin // need a new bin
if (counts[i] >= static_mean_bin_size || cur_cnt_inbin >= mean_bin_size || if (is_big_count_value[i] || cur_cnt_inbin >= mean_bin_size ||
(counts[i + 1] >= static_mean_bin_size && cur_cnt_inbin >= std::max(1.0, mean_bin_size * 0.5f))) { (is_big_count_value[i + 1] && cur_cnt_inbin >= std::max(1.0, mean_bin_size * 0.5f))) {
upper_bounds[bin_cnt] = distinct_values[i]; upper_bounds[bin_cnt] = distinct_values[i];
if (bin_cnt == 0) { if (bin_cnt == 0) {
cnt_in_bin0 = cur_cnt_inbin; cnt_in_bin0 = cur_cnt_inbin;
...@@ -116,7 +128,10 @@ void BinMapper::FindBin(std::vector<double>* values, size_t total_sample_cnt, in ...@@ -116,7 +128,10 @@ void BinMapper::FindBin(std::vector<double>* values, size_t total_sample_cnt, in
lower_bounds[bin_cnt] = distinct_values[i + 1]; lower_bounds[bin_cnt] = distinct_values[i + 1];
if (bin_cnt >= max_bin - 1) { break; } if (bin_cnt >= max_bin - 1) { break; }
cur_cnt_inbin = 0; cur_cnt_inbin = 0;
mean_bin_size = rest_sample_cnt / static_cast<double>(max_bin - bin_cnt); if (!is_big_count_value[i]) {
--rest_bin_cnt;
mean_bin_size = rest_sample_cnt / static_cast<double>(rest_bin_cnt);
}
} }
} }
// //
......
...@@ -279,7 +279,7 @@ inline VAL_T SparseBinIterator<VAL_T>::InnerGet(data_size_t idx) { ...@@ -279,7 +279,7 @@ inline VAL_T SparseBinIterator<VAL_T>::InnerGet(data_size_t idx) {
while (cur_pos_ < idx && i_delta_ < bin_data_->num_vals_) { while (cur_pos_ < idx && i_delta_ < bin_data_->num_vals_) {
bin_data_->NextNonzero(&i_delta_, &cur_pos_); bin_data_->NextNonzero(&i_delta_, &cur_pos_);
} }
if (cur_pos_ == idx && i_delta_ < bin_data_->num_vals_) { if (cur_pos_ == idx && i_delta_ < bin_data_->num_vals_ && i_delta_ >= 0) {
return bin_data_->vals_[i_delta_]; return bin_data_->vals_[i_delta_];
} else { } else {
return 0; return 0;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment