"...git@developer.sourcefind.cn:tsoc/hg-misc-tools.git" did not exist on "bf9049710e08a3a53cf19f574b728f270f27b4ce"
Unverified commit a3a353d6, authored by Guolin Ke and committed by GitHub
Browse files

Choose at least 2 features when sub-sampling columns (#2409)

* at least 2 features are chosen in subcolumn

* Update serial_tree_learner.cpp

* round the sampled feature count (std::round) instead of truncating it
parent a119639a
...@@ -277,9 +277,10 @@ std::vector<int8_t> SerialTreeLearner::GetUsedFeatures(bool is_tree_level) { ...@@ -277,9 +277,10 @@ std::vector<int8_t> SerialTreeLearner::GetUsedFeatures(bool is_tree_level) {
return ret; return ret;
} }
std::memset(ret.data(), 0, sizeof(int8_t) * num_features_); std::memset(ret.data(), 0, sizeof(int8_t) * num_features_);
const int min_used_features = std::min(2, static_cast<int>(valid_feature_indices_.size()));
if (is_tree_level) { if (is_tree_level) {
int used_feature_cnt = static_cast<int>(valid_feature_indices_.size() * config_->feature_fraction); int used_feature_cnt = static_cast<int>(std::round(valid_feature_indices_.size() * config_->feature_fraction));
used_feature_cnt = std::max(used_feature_cnt, 1); used_feature_cnt = std::max(used_feature_cnt, min_used_features);
used_feature_indices_ = random_.Sample(static_cast<int>(valid_feature_indices_.size()), used_feature_cnt); used_feature_indices_ = random_.Sample(static_cast<int>(valid_feature_indices_.size()), used_feature_cnt);
int omp_loop_size = static_cast<int>(used_feature_indices_.size()); int omp_loop_size = static_cast<int>(used_feature_indices_.size());
#pragma omp parallel for schedule(static, 512) if (omp_loop_size >= 1024) #pragma omp parallel for schedule(static, 512) if (omp_loop_size >= 1024)
...@@ -290,8 +291,8 @@ std::vector<int8_t> SerialTreeLearner::GetUsedFeatures(bool is_tree_level) { ...@@ -290,8 +291,8 @@ std::vector<int8_t> SerialTreeLearner::GetUsedFeatures(bool is_tree_level) {
ret[inner_feature_index] = 1; ret[inner_feature_index] = 1;
} }
} else if(used_feature_indices_.size() <= 0) { } else if(used_feature_indices_.size() <= 0) {
int used_feature_cnt = static_cast<int>(valid_feature_indices_.size() * config_->feature_fraction_bynode); int used_feature_cnt = static_cast<int>(std::round(valid_feature_indices_.size() * config_->feature_fraction_bynode));
used_feature_cnt = std::max(used_feature_cnt, 1); used_feature_cnt = std::max(used_feature_cnt, min_used_features);
auto sampled_indices = random_.Sample(static_cast<int>(valid_feature_indices_.size()), used_feature_cnt); auto sampled_indices = random_.Sample(static_cast<int>(valid_feature_indices_.size()), used_feature_cnt);
int omp_loop_size = static_cast<int>(sampled_indices.size()); int omp_loop_size = static_cast<int>(sampled_indices.size());
#pragma omp parallel for schedule(static, 512) if (omp_loop_size >= 1024) #pragma omp parallel for schedule(static, 512) if (omp_loop_size >= 1024)
...@@ -302,8 +303,8 @@ std::vector<int8_t> SerialTreeLearner::GetUsedFeatures(bool is_tree_level) { ...@@ -302,8 +303,8 @@ std::vector<int8_t> SerialTreeLearner::GetUsedFeatures(bool is_tree_level) {
ret[inner_feature_index] = 1; ret[inner_feature_index] = 1;
} }
} else { } else {
int used_feature_cnt = static_cast<int>(used_feature_indices_.size() * config_->feature_fraction_bynode); int used_feature_cnt = static_cast<int>(std::round(used_feature_indices_.size() * config_->feature_fraction_bynode));
used_feature_cnt = std::max(used_feature_cnt, 1); used_feature_cnt = std::max(used_feature_cnt, min_used_features);
auto sampled_indices = random_.Sample(static_cast<int>(used_feature_indices_.size()), used_feature_cnt); auto sampled_indices = random_.Sample(static_cast<int>(used_feature_indices_.size()), used_feature_cnt);
int omp_loop_size = static_cast<int>(sampled_indices.size()); int omp_loop_size = static_cast<int>(sampled_indices.size());
#pragma omp parallel for schedule(static, 512) if (omp_loop_size >= 1024) #pragma omp parallel for schedule(static, 512) if (omp_loop_size >= 1024)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment