Unverified Commit a3a353d6 authored by Guolin Ke's avatar Guolin Ke Committed by GitHub
Browse files

At least 2 features are chosen in subcolumn (#2409)

* at least 2 features are chosen in subcolumn

* Update serial_tree_learner.cpp

* rounding
parent a119639a
......@@ -277,9 +277,10 @@ std::vector<int8_t> SerialTreeLearner::GetUsedFeatures(bool is_tree_level) {
return ret;
}
std::memset(ret.data(), 0, sizeof(int8_t) * num_features_);
const int min_used_features = std::min(2, static_cast<int>(valid_feature_indices_.size()));
if (is_tree_level) {
int used_feature_cnt = static_cast<int>(valid_feature_indices_.size() * config_->feature_fraction);
used_feature_cnt = std::max(used_feature_cnt, 1);
int used_feature_cnt = static_cast<int>(std::round(valid_feature_indices_.size() * config_->feature_fraction));
used_feature_cnt = std::max(used_feature_cnt, min_used_features);
used_feature_indices_ = random_.Sample(static_cast<int>(valid_feature_indices_.size()), used_feature_cnt);
int omp_loop_size = static_cast<int>(used_feature_indices_.size());
#pragma omp parallel for schedule(static, 512) if (omp_loop_size >= 1024)
......@@ -290,8 +291,8 @@ std::vector<int8_t> SerialTreeLearner::GetUsedFeatures(bool is_tree_level) {
ret[inner_feature_index] = 1;
}
} else if(used_feature_indices_.size() <= 0) {
int used_feature_cnt = static_cast<int>(valid_feature_indices_.size() * config_->feature_fraction_bynode);
used_feature_cnt = std::max(used_feature_cnt, 1);
int used_feature_cnt = static_cast<int>(std::round(valid_feature_indices_.size() * config_->feature_fraction_bynode));
used_feature_cnt = std::max(used_feature_cnt, min_used_features);
auto sampled_indices = random_.Sample(static_cast<int>(valid_feature_indices_.size()), used_feature_cnt);
int omp_loop_size = static_cast<int>(sampled_indices.size());
#pragma omp parallel for schedule(static, 512) if (omp_loop_size >= 1024)
......@@ -302,8 +303,8 @@ std::vector<int8_t> SerialTreeLearner::GetUsedFeatures(bool is_tree_level) {
ret[inner_feature_index] = 1;
}
} else {
int used_feature_cnt = static_cast<int>(used_feature_indices_.size() * config_->feature_fraction_bynode);
used_feature_cnt = std::max(used_feature_cnt, 1);
int used_feature_cnt = static_cast<int>(std::round(used_feature_indices_.size() * config_->feature_fraction_bynode));
used_feature_cnt = std::max(used_feature_cnt, min_used_features);
auto sampled_indices = random_.Sample(static_cast<int>(used_feature_indices_.size()), used_feature_cnt);
int omp_loop_size = static_cast<int>(sampled_indices.size());
#pragma omp parallel for schedule(static, 512) if (omp_loop_size >= 1024)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment