Unverified commit 49ea824f, authored by Guolin Ke and committed by GitHub
Browse files

fix forced split (#2838)

parent 53137e25
...@@ -546,7 +546,13 @@ int32_t SerialTreeLearner::ForceSplits(Tree* tree, const Json& forced_split_json ...@@ -546,7 +546,13 @@ int32_t SerialTreeLearner::ForceSplits(Tree* tree, const Json& forced_split_json
// split tree, will return right leaf // split tree, will return right leaf
*left_leaf = current_leaf; *left_leaf = current_leaf;
auto next_leaf_id = tree->NextLeafId();
if (train_data_->FeatureBinMapper(inner_feature_index)->bin_type() == BinType::NumericalBin) { if (train_data_->FeatureBinMapper(inner_feature_index)->bin_type() == BinType::NumericalBin) {
data_partition_->Split(current_leaf, train_data_, inner_feature_index,
&current_split_info.threshold, 1,
current_split_info.default_left, next_leaf_id);
current_split_info.left_count = data_partition_->leaf_count(*left_leaf);
current_split_info.right_count = data_partition_->leaf_count(next_leaf_id);
*right_leaf = tree->Split(current_leaf, *right_leaf = tree->Split(current_leaf,
inner_feature_index, inner_feature_index,
current_split_info.feature, current_split_info.feature,
...@@ -561,9 +567,6 @@ int32_t SerialTreeLearner::ForceSplits(Tree* tree, const Json& forced_split_json ...@@ -561,9 +567,6 @@ int32_t SerialTreeLearner::ForceSplits(Tree* tree, const Json& forced_split_json
static_cast<float>(current_split_info.gain), static_cast<float>(current_split_info.gain),
train_data_->FeatureBinMapper(inner_feature_index)->missing_type(), train_data_->FeatureBinMapper(inner_feature_index)->missing_type(),
current_split_info.default_left); current_split_info.default_left);
data_partition_->Split(current_leaf, train_data_, inner_feature_index,
&current_split_info.threshold, 1,
current_split_info.default_left, *right_leaf);
} else { } else {
std::vector<uint32_t> cat_bitset_inner = Common::ConstructBitset( std::vector<uint32_t> cat_bitset_inner = Common::ConstructBitset(
current_split_info.cat_threshold.data(), current_split_info.num_cat_threshold); current_split_info.cat_threshold.data(), current_split_info.num_cat_threshold);
...@@ -574,6 +577,11 @@ int32_t SerialTreeLearner::ForceSplits(Tree* tree, const Json& forced_split_json ...@@ -574,6 +577,11 @@ int32_t SerialTreeLearner::ForceSplits(Tree* tree, const Json& forced_split_json
} }
std::vector<uint32_t> cat_bitset = Common::ConstructBitset( std::vector<uint32_t> cat_bitset = Common::ConstructBitset(
threshold_int.data(), current_split_info.num_cat_threshold); threshold_int.data(), current_split_info.num_cat_threshold);
data_partition_->Split(current_leaf, train_data_, inner_feature_index,
cat_bitset_inner.data(), static_cast<int>(cat_bitset_inner.size()),
current_split_info.default_left, next_leaf_id);
current_split_info.left_count = data_partition_->leaf_count(*left_leaf);
current_split_info.right_count = data_partition_->leaf_count(next_leaf_id);
*right_leaf = tree->SplitCategorical(current_leaf, *right_leaf = tree->SplitCategorical(current_leaf,
inner_feature_index, inner_feature_index,
current_split_info.feature, current_split_info.feature,
...@@ -589,11 +597,10 @@ int32_t SerialTreeLearner::ForceSplits(Tree* tree, const Json& forced_split_json ...@@ -589,11 +597,10 @@ int32_t SerialTreeLearner::ForceSplits(Tree* tree, const Json& forced_split_json
static_cast<double>(current_split_info.right_sum_hessian), static_cast<double>(current_split_info.right_sum_hessian),
static_cast<float>(current_split_info.gain), static_cast<float>(current_split_info.gain),
train_data_->FeatureBinMapper(inner_feature_index)->missing_type()); train_data_->FeatureBinMapper(inner_feature_index)->missing_type());
data_partition_->Split(current_leaf, train_data_, inner_feature_index,
cat_bitset_inner.data(), static_cast<int>(cat_bitset_inner.size()),
current_split_info.default_left, *right_leaf);
} }
#ifdef DEBUG
CHECK(*right_leaf == next_leaf_id);
#endif
if (current_split_info.left_count < current_split_info.right_count) { if (current_split_info.left_count < current_split_info.right_count) {
left_smaller = true; left_smaller = true;
smaller_leaf_splits_->Init(*left_leaf, data_partition_.get(), smaller_leaf_splits_->Init(*left_leaf, data_partition_.get(),
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment