Unverified Commit 9e1d7fa1 authored by Christoph Aymanns's avatar Christoph Aymanns Committed by GitHub
Browse files

enforce interaction constraints with monotone_constraints_method = intermediate/advanced (#4043)



* add test for interaction constraints and monotone constraints

* enforce interaction constraints in RecomputeBestSplitForLeaf

* code formatting

* code formatting

* move interaction constraint test to test_engine

* Apply suggestions from code review
Co-authored-by: default avatarNikita Titov <nekit94-08@mail.ru>
parent 55a31bfe
......@@ -677,7 +677,7 @@ void SerialTreeLearner::SplitInner(Tree* tree, int best_leaf, int* left_leaf,
best_split_per_leaf_);
// update leave outputs if needed
for (auto leaf : leaves_need_update) {
RecomputeBestSplitForLeaf(leaf, &best_split_per_leaf_[leaf]);
RecomputeBestSplitForLeaf(tree, leaf, &best_split_per_leaf_[leaf]);
}
}
......@@ -768,7 +768,7 @@ double SerialTreeLearner::GetParentOutput(const Tree* tree, const LeafSplits* le
return parent_output;
}
void SerialTreeLearner::RecomputeBestSplitForLeaf(int leaf, SplitInfo* split) {
void SerialTreeLearner::RecomputeBestSplitForLeaf(Tree* tree, int leaf, SplitInfo* split) {
FeatureHistogram* histogram_array_;
if (!histogram_pool_.Get(leaf, &histogram_array_)) {
Log::Warning(
......@@ -795,6 +795,7 @@ void SerialTreeLearner::RecomputeBestSplitForLeaf(int leaf, SplitInfo* split) {
OMP_INIT_EX();
// find splits
std::vector<int8_t> node_used_features = col_sampler_.GetByNode(tree, leaf);
#pragma omp parallel for schedule(static) num_threads(share_state_->num_threads)
for (int feature_index = 0; feature_index < num_features_; ++feature_index) {
OMP_LOOP_EX_BEGIN();
......@@ -804,7 +805,7 @@ void SerialTreeLearner::RecomputeBestSplitForLeaf(int leaf, SplitInfo* split) {
}
const int tid = omp_get_thread_num();
int real_fidx = train_data_->RealFeatureIndex(feature_index);
ComputeBestSplitForFeature(histogram_array_, feature_index, real_fidx, true,
ComputeBestSplitForFeature(histogram_array_, feature_index, real_fidx, node_used_features[feature_index],
num_data, &leaf_splits, &bests[tid], parent_output);
OMP_LOOP_EX_END();
......
......@@ -128,7 +128,7 @@ class SerialTreeLearner: public TreeLearner {
void GetShareStates(const Dataset* dataset, bool is_constant_hessian, bool is_first_time);
void RecomputeBestSplitForLeaf(int leaf, SplitInfo* split);
void RecomputeBestSplitForLeaf(Tree* tree, int leaf, SplitInfo* split);
/*!
* \brief Some initial works before training
......
......@@ -1252,7 +1252,8 @@ def generate_trainset_for_monotone_constraints_tests(x3_to_category=True):
return trainset
def test_monotone_constraints():
@pytest.mark.parametrize("test_with_interaction_constraints", [True, False])
def test_monotone_constraints(test_with_interaction_constraints):
def is_increasing(y):
    """Return True when the sequence `y` is monotonically non-decreasing."""
    deltas = np.diff(y)
    return (deltas >= 0.0).all()
......@@ -1273,28 +1274,69 @@ def test_monotone_constraints():
monotonically_increasing_y = learner.predict(monotonically_increasing_x)
monotonically_decreasing_x = np.column_stack((fixed_x, variable_x, fixed_x))
monotonically_decreasing_y = learner.predict(monotonically_decreasing_x)
non_monotone_x = np.column_stack((fixed_x,
non_monotone_x = np.column_stack(
(
fixed_x,
categorize(variable_x) if x3_to_category else variable_x))
fixed_x,
categorize(variable_x) if x3_to_category else variable_x,
)
)
non_monotone_y = learner.predict(non_monotone_x)
if not (is_increasing(monotonically_increasing_y)
if not (
is_increasing(monotonically_increasing_y)
and is_decreasing(monotonically_decreasing_y)
and is_non_monotone(non_monotone_y)):
and is_non_monotone(non_monotone_y)
):
return False
return True
def are_interactions_enforced(gbm, feature_sets):
    """Check that interaction constraints hold in a trained booster.

    gbm          : trained model exposing ``model_to_string()``.
    feature_sets : iterable of allowed feature groups, each a collection of
                   ``"Column_<i>"`` names.

    Returns True when no single tree uses features from more than one of the
    allowed groups.
    """

    def features_per_tree(booster):
        # Each tree section of the model dump begins with "Tree"; the first
        # element of the split is the pre-tree header, hence the [1:] slice.
        sections = booster.model_to_string().split("Tree")[1:]
        per_tree = []
        for section in sections:
            # The split_feature list sits on the 4th line of each tree section.
            raw_ids = section.splitlines()[3].split("=")[1].split(" ")
            per_tree.append({f"Column_{idx}" for idx in raw_ids})
        return per_tree

    def mixes_groups(tree_feats):
        # A tree violates the constraints when it touches features from
        # more than one allowed group.
        touched = sum(
            1 for group in feature_sets if tree_feats.intersection(group)
        )
        return touched > 1

    return not any(mixes_groups(tf) for tf in features_per_tree(gbm))
for test_with_categorical_variable in [True, False]:
trainset = generate_trainset_for_monotone_constraints_tests(test_with_categorical_variable)
trainset = generate_trainset_for_monotone_constraints_tests(
test_with_categorical_variable
)
for monotone_constraints_method in ["basic", "intermediate", "advanced"]:
params = {
'min_data': 20,
'num_leaves': 20,
'monotone_constraints': [1, -1, 0],
"min_data": 20,
"num_leaves": 20,
"monotone_constraints": [1, -1, 0],
"monotone_constraints_method": monotone_constraints_method,
"use_missing": False,
}
if test_with_interaction_constraints:
params["interaction_constraints"] = [[0], [1], [2]]
constrained_model = lgb.train(params, trainset)
assert is_correctly_constrained(constrained_model, test_with_categorical_variable)
assert is_correctly_constrained(
constrained_model, test_with_categorical_variable
)
if test_with_interaction_constraints:
feature_sets = [["Column_0"], ["Column_1"], "Column_2"]
assert are_interactions_enforced(constrained_model, feature_sets)
def test_monotone_penalty():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment