"git@developer.sourcefind.cn:tianlh/lightgbm-dcu.git" did not exist on "e5eafad2ba71e48c6cac31741f32e897479870de"
Unverified Commit 49ea824f authored by Guolin Ke's avatar Guolin Ke Committed by GitHub
Browse files

fix forced split (#2838)

parent 53137e25
......@@ -546,7 +546,13 @@ int32_t SerialTreeLearner::ForceSplits(Tree* tree, const Json& forced_split_json
// split tree, will return right leaf
*left_leaf = current_leaf;
auto next_leaf_id = tree->NextLeafId();
if (train_data_->FeatureBinMapper(inner_feature_index)->bin_type() == BinType::NumericalBin) {
data_partition_->Split(current_leaf, train_data_, inner_feature_index,
&current_split_info.threshold, 1,
current_split_info.default_left, next_leaf_id);
current_split_info.left_count = data_partition_->leaf_count(*left_leaf);
current_split_info.right_count = data_partition_->leaf_count(next_leaf_id);
*right_leaf = tree->Split(current_leaf,
inner_feature_index,
current_split_info.feature,
......@@ -561,9 +567,6 @@ int32_t SerialTreeLearner::ForceSplits(Tree* tree, const Json& forced_split_json
static_cast<float>(current_split_info.gain),
train_data_->FeatureBinMapper(inner_feature_index)->missing_type(),
current_split_info.default_left);
data_partition_->Split(current_leaf, train_data_, inner_feature_index,
&current_split_info.threshold, 1,
current_split_info.default_left, *right_leaf);
} else {
std::vector<uint32_t> cat_bitset_inner = Common::ConstructBitset(
current_split_info.cat_threshold.data(), current_split_info.num_cat_threshold);
......@@ -574,6 +577,11 @@ int32_t SerialTreeLearner::ForceSplits(Tree* tree, const Json& forced_split_json
}
std::vector<uint32_t> cat_bitset = Common::ConstructBitset(
threshold_int.data(), current_split_info.num_cat_threshold);
data_partition_->Split(current_leaf, train_data_, inner_feature_index,
cat_bitset_inner.data(), static_cast<int>(cat_bitset_inner.size()),
current_split_info.default_left, next_leaf_id);
current_split_info.left_count = data_partition_->leaf_count(*left_leaf);
current_split_info.right_count = data_partition_->leaf_count(next_leaf_id);
*right_leaf = tree->SplitCategorical(current_leaf,
inner_feature_index,
current_split_info.feature,
......@@ -589,11 +597,10 @@ int32_t SerialTreeLearner::ForceSplits(Tree* tree, const Json& forced_split_json
static_cast<double>(current_split_info.right_sum_hessian),
static_cast<float>(current_split_info.gain),
train_data_->FeatureBinMapper(inner_feature_index)->missing_type());
data_partition_->Split(current_leaf, train_data_, inner_feature_index,
cat_bitset_inner.data(), static_cast<int>(cat_bitset_inner.size()),
current_split_info.default_left, *right_leaf);
}
#ifdef DEBUG
CHECK(*right_leaf == next_leaf_id);
#endif
if (current_split_info.left_count < current_split_info.right_count) {
left_smaller = true;
smaller_leaf_splits_->Init(*left_leaf, data_partition_.get(),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment