Unverified commit f84bfcf9, authored by Belinda Trotta, committed by GitHub
Browse files

Check feature indexes in forced split file (fixes #5517) (#5653)

parent 51edbda7
......@@ -14,6 +14,7 @@
#include <chrono>
#include <ctime>
#include <queue>
#include <sstream>
namespace LightGBM {
......@@ -138,6 +139,9 @@ void GBDT::Init(const Config* config, const Dataset* train_data, const Objective
// get parser config file content
parser_config_str_ = train_data_->parser_config_str();
// check that forced splits do not use feature indices larger than the dataset's maximum feature index
CheckForcedSplitFeatures();
// if need bagging, create buffer
data_sample_strategy_->ResetSampleConfig(config_.get(), true);
ResetGradientBuffers();
......@@ -155,6 +159,26 @@ void GBDT::Init(const Config* config, const Dataset* train_data, const Objective
}
}
void GBDT::CheckForcedSplitFeatures() {
std::queue<Json> forced_split_nodes;
forced_split_nodes.push(forced_splits_json_);
while (!forced_split_nodes.empty()) {
Json node = forced_split_nodes.front();
forced_split_nodes.pop();
const int feature_index = node["feature"].int_value();
if (feature_index > max_feature_idx_) {
Log::Fatal("Forced splits file includes feature index %d, but maximum feature index in dataset is %d",
feature_index, max_feature_idx_);
}
if (node.object_items().count("left") > 0) {
forced_split_nodes.push(node["left"]);
}
if (node.object_items().count("right") > 0) {
forced_split_nodes.push(node["right"]);
}
}
}
void GBDT::AddValidDataset(const Dataset* valid_data,
const std::vector<const Metric*>& valid_metrics) {
if (!train_data_->CheckAlign(*valid_data)) {
......
......@@ -58,6 +58,11 @@ class GBDT : public GBDTBase {
const ObjectiveFunction* objective_function,
const std::vector<const Metric*>& training_metrics) override;
/*!
* \brief Traverse the tree of forced splits and check that no node references a feature index
*        greater than the maximum feature index of the dataset. A violation is reported with Log::Fatal.
*/
void CheckForcedSplitFeatures();
/*!
* \brief Merge model from other boosting object. Will insert to the front of current boosting object
* \param other
......
......@@ -2887,6 +2887,25 @@ def test_node_level_subcol():
assert ret != ret2
def test_forced_split_feature_indices(tmp_path):
    """Training must fail when the forced splits file names a feature index past the last column."""
    features, target = make_synthetic_regression()
    # shape[1] is the column count, so used as a zero-based index it is one past the last column.
    out_of_range_index = features.shape[1]
    split_spec = {
        "feature": 0,
        "threshold": 0.5,
        "left": {"feature": out_of_range_index, "threshold": 0.5},
    }
    split_path = tmp_path / "forced_split.json"
    split_path.write_text(json.dumps(split_spec))
    train_set = lgb.Dataset(features, target)
    train_params = {
        "objective": "regression",
        "forcedsplits_filename": split_path
    }
    # The error is expected to surface from lgb.train, so its return value is not needed.
    with pytest.raises(lgb.basic.LightGBMError, match="Forced splits file includes feature index"):
        lgb.train(train_params, train_set)
def test_forced_bins():
x = np.empty((100, 2))
x[:, 0] = np.arange(0, 1, 0.01)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment