Unverified Commit 33a2f9ec authored by tongwu-msft's avatar tongwu-msft Committed by GitHub
Browse files

Always respect forced splits, even when feature_fraction < 1.0 (fixes #4601) (#4725)

* issue fix #4601

* fix issue 4601 it2

* add tests for issue 4601

* fix warning

* fix warning

* add new line at end

* remove last line at end

* fix lint warning

* address comments

* address comments

* address comments

* fix address

* address comments

* revert seed

* fix recursive force split issue

* fix build error

* fix lint warning
parent b1facf50
......@@ -11,6 +11,7 @@
#include <algorithm>
#include <queue>
#include <set>
#include <unordered_map>
#include <utility>
......@@ -322,10 +323,14 @@ bool SerialTreeLearner::BeforeFindBestSplit(const Tree* tree, int left_leaf, int
}
// Convenience overload: searches for the best splits without any forced
// features by delegating to the two-argument overload with nullptr.
void SerialTreeLearner::FindBestSplits(const Tree* tree) {
FindBestSplits(tree, nullptr);
}
void SerialTreeLearner::FindBestSplits(const Tree* tree, const std::set<int>* force_features) {
std::vector<int8_t> is_feature_used(num_features_, 0);
#pragma omp parallel for schedule(static, 256) if (num_features_ >= 512)
for (int feature_index = 0; feature_index < num_features_; ++feature_index) {
if (!col_sampler_.is_feature_used_bytree()[feature_index]) continue;
if (!col_sampler_.is_feature_used_bytree()[feature_index] && (force_features == nullptr || force_features->find(feature_index) == force_features->end())) continue;
if (parent_leaf_histogram_array_ != nullptr
&& !parent_leaf_histogram_array_[feature_index].is_splittable()) {
smaller_leaf_histogram_array_[feature_index].set_is_splittable(false);
......@@ -462,12 +467,14 @@ int32_t SerialTreeLearner::ForceSplits(Tree* tree, int* left_leaf,
bool left_smaller = true;
std::unordered_map<int, SplitInfo> forceSplitMap;
q.push(std::make_pair(left, *left_leaf));
// Histogram construction requires parent features.
std::set<int> force_split_features = FindAllForceFeatures(*forced_split_json_);
while (!q.empty()) {
// before processing next node from queue, store info for current left/right leaf
// store "best split" for left and right, even if they might be overwritten by forced split
if (BeforeFindBestSplit(tree, *left_leaf, *right_leaf)) {
FindBestSplits(tree);
FindBestSplits(tree, &force_split_features);
}
// then, compute own splits
SplitInfo left_split;
SplitInfo right_split;
......@@ -561,6 +568,32 @@ int32_t SerialTreeLearner::ForceSplits(Tree* tree, int* left_leaf,
return result_count;
}
/*!
 * \brief Collect the inner feature index of every node in a forced-split specification
 * \param force_split_leaf_setting Root node of the forced-split JSON; each node is read for
 *        its "feature" entry and for optional "left" / "right" child nodes
 * \return Set of inner feature indices (as mapped by train_data_->InnerFeatureIndex)
 *         referenced anywhere in the specification
 */
std::set<int> SerialTreeLearner::FindAllForceFeatures(Json force_split_leaf_setting) {
  std::set<int> inner_indices;
  std::queue<Json> pending;

  // Queue a child node for later processing, but only when the key is actually present.
  const auto enqueue_child = [&pending](const Json& node, const char* key) {
    if (node.object_items().count(key) > 0) {
      pending.push(node[key]);
    }
  };

  // Breadth-first walk over the forced-split tree, starting from the root.
  for (pending.push(force_split_leaf_setting); !pending.empty();) {
    const Json node = pending.front();
    pending.pop();

    // The JSON stores the raw feature index; record its inner index.
    const int raw_index = node["feature"].int_value();
    inner_indices.insert(train_data_->InnerFeatureIndex(raw_index));

    enqueue_child(node, "left");
    enqueue_child(node, "right");
  }

  return inner_indices;
}
void SerialTreeLearner::SplitInner(Tree* tree, int best_leaf, int* left_leaf,
int* right_leaf, bool update_cnt) {
Common::FunctionTimer fun_timer("SerialTreeLearner::SplitInner", global_timer);
......
......@@ -19,6 +19,7 @@
#include <memory>
#include <random>
#include <vector>
#include <set>
#include "col_sampler.hpp"
#include "data_partition.hpp"
......@@ -142,6 +143,8 @@ class SerialTreeLearner: public TreeLearner {
/*!
* \brief Find the best splits without any forced features; forwards to FindBestSplits(tree, nullptr)
* \param tree The tree forwarded to the two-argument overload
*/
virtual void FindBestSplits(const Tree* tree);
/*!
* \brief Find the best splits, keeping the given forced features eligible
* \param tree The current tree
* \param force_features Feature indices that are not skipped even when
*        col_sampler_.is_feature_used_bytree() marks them as unused; nullptr means no forced features
*/
virtual void FindBestSplits(const Tree* tree, const std::set<int>* force_features);
virtual void ConstructHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract);
virtual void FindBestSplitsFromHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract, const Tree*);
......@@ -165,6 +168,8 @@ class SerialTreeLearner: public TreeLearner {
int32_t ForceSplits(Tree* tree, int* left_leaf, int* right_leaf,
int* cur_depth);
/*!
* \brief Collect the inner feature indices used by all nodes of a forced-split specification
* \param force_split_leaf_setting Root node of the forced-split JSON ("feature", optional "left" / "right")
* \return Set of inner feature indices referenced by the root and all of its descendants
*/
std::set<int> FindAllForceFeatures(Json force_split_leaf_setting);
/*!
* \brief Get the number of data in a leaf
* \param leaf_idx The index of leaf
......
# coding: utf-8
import copy
import itertools
import json
import math
import pickle
import platform
......@@ -2887,3 +2888,40 @@ def test_dump_model_hook():
dumped_model_str = str(bst.dump_model(5, 0, object_hook=hook))
assert "leaf_value" not in dumped_model_str
assert "LV" in dumped_model_str
def test_force_split_with_feature_fraction(tmp_path):
    """Forced splits must be respected even when feature_fraction < 1.0 (issue #4601)."""
    features, target = load_boston(return_X_y=True)
    X_train, X_test, y_train, y_test = train_test_split(features, target, test_size=0.1, random_state=42)
    train_set = lgb.Dataset(X_train, y_train)

    # Root is forced to split on feature 0; its right child is forced to split on feature 2.
    forced_split = {
        "feature": 0,
        "threshold": 0.5,
        "right": {
            "feature": 2,
            "threshold": 10.0
        }
    }

    tmp_split_file = tmp_path / "forced_split.json"
    with open(tmp_split_file, "w") as f:
        json.dump(forced_split, f)

    params = {
        "objective": "regression",
        "feature_fraction": 0.6,
        "force_col_wise": True,
        "feature_fraction_seed": 1,
        "forcedsplits_filename": tmp_split_file
    }

    booster = lgb.train(params, train_set)
    mae = mean_absolute_error(y_test, booster.predict(X_test))
    assert mae < 2.0

    tree_info = booster.dump_model()["tree_info"]
    assert len(tree_info) > 1
    # Every tree's root must use the forced feature, regardless of column sampling.
    for tree in tree_info:
        assert tree["tree_structure"]["split_feature"] == 0
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment