Unverified Commit 9e1d7fa1 authored by Christoph Aymanns's avatar Christoph Aymanns Committed by GitHub
Browse files

enforce interaction constraints with monotone_constraints_method = intermediate/advanced (#4043)



* add test for interaction constraints and monotone constraints

* enforce interaction constraints in RecomputeBestSplitForLeaf

* code formatting

* code formatting

* move interaction constraint test to test_engine

* Apply suggestions from code review
Co-authored-by: Nikita Titov <nekit94-08@mail.ru>
parent 55a31bfe
...@@ -677,7 +677,7 @@ void SerialTreeLearner::SplitInner(Tree* tree, int best_leaf, int* left_leaf, ...@@ -677,7 +677,7 @@ void SerialTreeLearner::SplitInner(Tree* tree, int best_leaf, int* left_leaf,
best_split_per_leaf_); best_split_per_leaf_);
// update leave outputs if needed // update leave outputs if needed
for (auto leaf : leaves_need_update) { for (auto leaf : leaves_need_update) {
RecomputeBestSplitForLeaf(leaf, &best_split_per_leaf_[leaf]); RecomputeBestSplitForLeaf(tree, leaf, &best_split_per_leaf_[leaf]);
} }
} }
...@@ -768,7 +768,7 @@ double SerialTreeLearner::GetParentOutput(const Tree* tree, const LeafSplits* le ...@@ -768,7 +768,7 @@ double SerialTreeLearner::GetParentOutput(const Tree* tree, const LeafSplits* le
return parent_output; return parent_output;
} }
void SerialTreeLearner::RecomputeBestSplitForLeaf(int leaf, SplitInfo* split) { void SerialTreeLearner::RecomputeBestSplitForLeaf(Tree* tree, int leaf, SplitInfo* split) {
FeatureHistogram* histogram_array_; FeatureHistogram* histogram_array_;
if (!histogram_pool_.Get(leaf, &histogram_array_)) { if (!histogram_pool_.Get(leaf, &histogram_array_)) {
Log::Warning( Log::Warning(
...@@ -795,6 +795,7 @@ void SerialTreeLearner::RecomputeBestSplitForLeaf(int leaf, SplitInfo* split) { ...@@ -795,6 +795,7 @@ void SerialTreeLearner::RecomputeBestSplitForLeaf(int leaf, SplitInfo* split) {
OMP_INIT_EX(); OMP_INIT_EX();
// find splits // find splits
std::vector<int8_t> node_used_features = col_sampler_.GetByNode(tree, leaf);
#pragma omp parallel for schedule(static) num_threads(share_state_->num_threads) #pragma omp parallel for schedule(static) num_threads(share_state_->num_threads)
for (int feature_index = 0; feature_index < num_features_; ++feature_index) { for (int feature_index = 0; feature_index < num_features_; ++feature_index) {
OMP_LOOP_EX_BEGIN(); OMP_LOOP_EX_BEGIN();
...@@ -804,7 +805,7 @@ void SerialTreeLearner::RecomputeBestSplitForLeaf(int leaf, SplitInfo* split) { ...@@ -804,7 +805,7 @@ void SerialTreeLearner::RecomputeBestSplitForLeaf(int leaf, SplitInfo* split) {
} }
const int tid = omp_get_thread_num(); const int tid = omp_get_thread_num();
int real_fidx = train_data_->RealFeatureIndex(feature_index); int real_fidx = train_data_->RealFeatureIndex(feature_index);
ComputeBestSplitForFeature(histogram_array_, feature_index, real_fidx, true, ComputeBestSplitForFeature(histogram_array_, feature_index, real_fidx, node_used_features[feature_index],
num_data, &leaf_splits, &bests[tid], parent_output); num_data, &leaf_splits, &bests[tid], parent_output);
OMP_LOOP_EX_END(); OMP_LOOP_EX_END();
......
...@@ -128,7 +128,7 @@ class SerialTreeLearner: public TreeLearner { ...@@ -128,7 +128,7 @@ class SerialTreeLearner: public TreeLearner {
void GetShareStates(const Dataset* dataset, bool is_constant_hessian, bool is_first_time); void GetShareStates(const Dataset* dataset, bool is_constant_hessian, bool is_first_time);
void RecomputeBestSplitForLeaf(int leaf, SplitInfo* split); void RecomputeBestSplitForLeaf(Tree* tree, int leaf, SplitInfo* split);
/*! /*!
* \brief Some initial works before training * \brief Some initial works before training
......
...@@ -1252,7 +1252,8 @@ def generate_trainset_for_monotone_constraints_tests(x3_to_category=True): ...@@ -1252,7 +1252,8 @@ def generate_trainset_for_monotone_constraints_tests(x3_to_category=True):
return trainset return trainset
def test_monotone_constraints(): @pytest.mark.parametrize("test_with_interaction_constraints", [True, False])
def test_monotone_constraints(test_with_interaction_constraints):
def is_increasing(y): def is_increasing(y):
return (np.diff(y) >= 0.0).all() return (np.diff(y) >= 0.0).all()
...@@ -1273,28 +1274,69 @@ def test_monotone_constraints(): ...@@ -1273,28 +1274,69 @@ def test_monotone_constraints():
monotonically_increasing_y = learner.predict(monotonically_increasing_x) monotonically_increasing_y = learner.predict(monotonically_increasing_x)
monotonically_decreasing_x = np.column_stack((fixed_x, variable_x, fixed_x)) monotonically_decreasing_x = np.column_stack((fixed_x, variable_x, fixed_x))
monotonically_decreasing_y = learner.predict(monotonically_decreasing_x) monotonically_decreasing_y = learner.predict(monotonically_decreasing_x)
non_monotone_x = np.column_stack((fixed_x, non_monotone_x = np.column_stack(
fixed_x, (
categorize(variable_x) if x3_to_category else variable_x)) fixed_x,
fixed_x,
categorize(variable_x) if x3_to_category else variable_x,
)
)
non_monotone_y = learner.predict(non_monotone_x) non_monotone_y = learner.predict(non_monotone_x)
if not (is_increasing(monotonically_increasing_y) if not (
and is_decreasing(monotonically_decreasing_y) is_increasing(monotonically_increasing_y)
and is_non_monotone(non_monotone_y)): and is_decreasing(monotonically_decreasing_y)
and is_non_monotone(non_monotone_y)
):
return False return False
return True return True
def are_interactions_enforced(gbm, feature_sets):
def parse_tree_features(gbm):
# trees start at position 1.
tree_str = gbm.model_to_string().split("Tree")[1:]
feature_sets = []
for tree in tree_str:
# split_features are in 4th line.
features = tree.splitlines()[3].split("=")[1].split(" ")
features = set(f"Column_{f}" for f in features)
feature_sets.append(features)
return np.array(feature_sets)
def has_interaction(treef):
n = 0
for fs in feature_sets:
if len(treef.intersection(fs)) > 0:
n += 1
return n > 1
tree_features = parse_tree_features(gbm)
has_interaction_flag = np.array(
[has_interaction(treef) for treef in tree_features]
)
return not has_interaction_flag.any()
for test_with_categorical_variable in [True, False]: for test_with_categorical_variable in [True, False]:
trainset = generate_trainset_for_monotone_constraints_tests(test_with_categorical_variable) trainset = generate_trainset_for_monotone_constraints_tests(
test_with_categorical_variable
)
for monotone_constraints_method in ["basic", "intermediate", "advanced"]: for monotone_constraints_method in ["basic", "intermediate", "advanced"]:
params = { params = {
'min_data': 20, "min_data": 20,
'num_leaves': 20, "num_leaves": 20,
'monotone_constraints': [1, -1, 0], "monotone_constraints": [1, -1, 0],
"monotone_constraints_method": monotone_constraints_method, "monotone_constraints_method": monotone_constraints_method,
"use_missing": False, "use_missing": False,
} }
if test_with_interaction_constraints:
params["interaction_constraints"] = [[0], [1], [2]]
constrained_model = lgb.train(params, trainset) constrained_model = lgb.train(params, trainset)
assert is_correctly_constrained(constrained_model, test_with_categorical_variable) assert is_correctly_constrained(
constrained_model, test_with_categorical_variable
)
if test_with_interaction_constraints:
feature_sets = [["Column_0"], ["Column_1"], "Column_2"]
assert are_interactions_enforced(constrained_model, feature_sets)
def test_monotone_penalty(): def test_monotone_penalty():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment