/*!
 * Copyright (c) 2020 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for
 * license information.
 */
#ifndef LIGHTGBM_SRC_TREELEARNER_MONOTONE_CONSTRAINTS_HPP_
#define LIGHTGBM_SRC_TREELEARNER_MONOTONE_CONSTRAINTS_HPP_
8

9
10
#include <LightGBM/tree.h>

11
12
#include <algorithm>
#include <cstdint>
13
#include <limits>
14
#include <memory>
15
#include <utility>
Nikita Titov's avatar
Nikita Titov committed
16
#include <vector>
17

18
19
#include "split_info.hpp"

20
21
namespace LightGBM {

22
23
24
class LeafConstraintsBase;

// A (min, max) pair of bounds. The in-class defaults span the whole finite
// range of double, i.e. a default-constructed value restricts nothing.
struct BasicConstraint {
  // Lower bound; defaults to the lowest finite double.
  double min = -std::numeric_limits<double>::max();
  // Upper bound; defaults to the largest finite double.
  double max = std::numeric_limits<double>::max();

  // Builds a constraint from explicit bounds (no ordering check is done).
  BasicConstraint(double min, double max) : min(min), max(max) {}

  // Builds the unrestricted pair using the member defaults above.
  BasicConstraint() = default;
};

// Interface exposing a "left" and a "right" BasicConstraint for one feature.
// The two non-pure hooks do nothing unless a subclass overrides them.
struct FeatureConstraint {
  // Optional hook receiving a bool; the base version is a no-op.
  virtual void InitCumulativeConstraints(bool) const {}
  // Optional hook receiving an int; the base version is a no-op.
  virtual void Update(int) const {}
  // Bounds reported for the left side.
  virtual BasicConstraint LeftToBasicConstraint() const = 0;
  // Bounds reported for the right side.
  virtual BasicConstraint RightToBasicConstraint() const = 0;
  // Tells whether the reported bounds can vary with the threshold.
  virtual bool ConstraintDifferentDependingOnThreshold() const = 0;
  virtual ~FeatureConstraint() = default;
};

// Abstract holder of min/max bounds. Concrete entries in this file are
// BasicConstraintEntry and AdvancedConstraintEntry.
struct ConstraintEntry {
  virtual ~ConstraintEntry() {}
  // Puts the entry back into its initial state.
  virtual void Reset() = 0;
  // Applies a new lower bound.
  virtual void UpdateMin(double new_min) = 0;
  // Applies a new upper bound.
  virtual void UpdateMax(double new_max) = 0;
  // Same as UpdateMin, and additionally reports whether the caller should
  // treat the entry as changed.
  virtual bool UpdateMinAndReturnBoolIfChanged(double new_min) = 0;
  // Same as UpdateMax, and additionally reports whether the caller should
  // treat the entry as changed.
  virtual bool UpdateMaxAndReturnBoolIfChanged(double new_max) = 0;
  // Returns a pointer to a copy of this entry. The implementations in this
  // file allocate it with `new`, so the caller takes ownership.
  virtual ConstraintEntry *clone() const = 0;

  // Optional hook; the base version does nothing.
  virtual void RecomputeConstraintsIfNeeded(LeafConstraintsBase *, int, int,
                                            uint32_t) {}

  // Returns the per-feature view of this entry for `feature_index`.
  virtual FeatureConstraint *GetFeatureConstraint(int feature_index) = 0;
};

// used by both BasicLeafConstraints and IntermediateLeafConstraints
struct BasicConstraintEntry : ConstraintEntry,
                              FeatureConstraint,
                              BasicConstraint {
  bool ConstraintDifferentDependingOnThreshold() const final { return false; }

  BasicConstraintEntry *clone() const final {
    return new BasicConstraintEntry(*this);
  };

  void Reset() final {
68
69
70
71
    min = -std::numeric_limits<double>::max();
    max = std::numeric_limits<double>::max();
  }

72
  void UpdateMin(double new_min) final { min = std::max(new_min, min); }
73

74
  void UpdateMax(double new_max) final { max = std::min(new_max, max); }
75

76
  bool UpdateMinAndReturnBoolIfChanged(double new_min) final {
77
78
79
80
81
82
83
    if (new_min > min) {
      min = new_min;
      return true;
    }
    return false;
  }

84
  bool UpdateMaxAndReturnBoolIfChanged(double new_max) final {
85
86
87
88
89
90
    if (new_max < max) {
      max = new_max;
      return true;
    }
    return false;
  }
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326

  BasicConstraint LeftToBasicConstraint() const final { return *this; }

  BasicConstraint RightToBasicConstraint() const final { return *this; }

  FeatureConstraint *GetFeatureConstraint(int) final { return this; }
};

// Piecewise list of bounds along one feature.
// constraints[i] is valid on the slice [thresholds[i], thresholds[i + 1]);
// when thresholds[i + 1] does not exist, it is valid for every threshold
// following thresholds[i].
struct FeatureMinOrMaxConstraints {
  std::vector<double> constraints;
  std::vector<uint32_t> thresholds;

  // Starts empty, with room pre-allocated for 32 pieces in each vector.
  FeatureMinOrMaxConstraints() {
    constraints.reserve(32);
    thresholds.reserve(32);
  }

  // Starts with a single piece of value `extremum` beginning at threshold 0.
  explicit FeatureMinOrMaxConstraints(double extremum)
      : FeatureMinOrMaxConstraints() {
    constraints.push_back(extremum);
    thresholds.push_back(0);
  }

  // Number of pieces, measured on the thresholds vector.
  size_t Size() const { return thresholds.size(); }

  // Collapses the list to one piece of value `extremum` starting at 0.
  void Reset(double extremum) {
    constraints.resize(1);
    thresholds.resize(1);
    constraints.front() = extremum;
    thresholds.front() = 0;
  }

  // Raises every stored value that is below `min` up to `min`.
  void UpdateMin(double min) {
    for (double &value : constraints) {
      value = std::max(value, min);
    }
  }

  // Lowers every stored value that is above `max` down to `max`.
  void UpdateMax(double max) {
    for (double &value : constraints) {
      value = std::min(value, max);
    }
  }
};

struct CumulativeFeatureConstraint {
  std::vector<uint32_t> thresholds_min_constraints;
  std::vector<uint32_t> thresholds_max_constraints;
  std::vector<double> cumulative_min_constraints_left_to_right;
  std::vector<double> cumulative_min_constraints_right_to_left;
  std::vector<double> cumulative_max_constraints_left_to_right;
  std::vector<double> cumulative_max_constraints_right_to_left;
  size_t index_min_constraints_left_to_right;
  size_t index_min_constraints_right_to_left;
  size_t index_max_constraints_left_to_right;
  size_t index_max_constraints_right_to_left;

  static void CumulativeExtremum(
      const double &(*extremum_function)(const double &, const double &),
      bool is_direction_from_left_to_right,
      std::vector<double>* cumulative_extremum) {
    if (cumulative_extremum->size() == 1) {
      return;
    }

#ifdef DEBUG
    CHECK_NE(cumulative_extremum->size(), 0);
#endif

    size_t n_exts = cumulative_extremum->size();
    int step = is_direction_from_left_to_right ? 1 : -1;
    size_t start = is_direction_from_left_to_right ? 0 : n_exts - 1;
    size_t end = is_direction_from_left_to_right ? n_exts - 1 : 0;

    for (auto i = start; i != end; i = i + step) {
      (*cumulative_extremum)[i + step] = extremum_function(
          (*cumulative_extremum)[i + step], (*cumulative_extremum)[i]);
    }
  }

  CumulativeFeatureConstraint() = default;

  CumulativeFeatureConstraint(FeatureMinOrMaxConstraints min_constraints,
                              FeatureMinOrMaxConstraints max_constraints,
                              bool REVERSE) {
    thresholds_min_constraints = min_constraints.thresholds;
    thresholds_max_constraints = max_constraints.thresholds;
    cumulative_min_constraints_left_to_right = min_constraints.constraints;
    cumulative_min_constraints_right_to_left = min_constraints.constraints;
    cumulative_max_constraints_left_to_right = max_constraints.constraints;
    cumulative_max_constraints_right_to_left = max_constraints.constraints;

    const double &(*min)(const double &, const double &) = std::min<double>;
    const double &(*max)(const double &, const double &) = std::max<double>;
    CumulativeExtremum(max, true, &cumulative_min_constraints_left_to_right);
    CumulativeExtremum(max, false, &cumulative_min_constraints_right_to_left);
    CumulativeExtremum(min, true, &cumulative_max_constraints_left_to_right);
    CumulativeExtremum(min, false, &cumulative_max_constraints_right_to_left);

    if (REVERSE) {
      index_min_constraints_left_to_right =
          thresholds_min_constraints.size() - 1;
      index_min_constraints_right_to_left =
          thresholds_min_constraints.size() - 1;
      index_max_constraints_left_to_right =
          thresholds_max_constraints.size() - 1;
      index_max_constraints_right_to_left =
          thresholds_max_constraints.size() - 1;
    } else {
      index_min_constraints_left_to_right = 0;
      index_min_constraints_right_to_left = 0;
      index_max_constraints_left_to_right = 0;
      index_max_constraints_right_to_left = 0;
    }
  }

  void Update(int threshold) {
    while (
        static_cast<int>(
            thresholds_min_constraints[index_min_constraints_left_to_right]) >
        threshold - 1) {
      index_min_constraints_left_to_right -= 1;
    }
    while (
        static_cast<int>(
            thresholds_min_constraints[index_min_constraints_right_to_left]) >
        threshold) {
      index_min_constraints_right_to_left -= 1;
    }
    while (
        static_cast<int>(
            thresholds_max_constraints[index_max_constraints_left_to_right]) >
        threshold - 1) {
      index_max_constraints_left_to_right -= 1;
    }
    while (
        static_cast<int>(
            thresholds_max_constraints[index_max_constraints_right_to_left]) >
        threshold) {
      index_max_constraints_right_to_left -= 1;
    }
  }

  double GetRightMin() const {
    return cumulative_min_constraints_right_to_left
        [index_min_constraints_right_to_left];
  }
  double GetRightMax() const {
    return cumulative_max_constraints_right_to_left
        [index_max_constraints_right_to_left];
  }
  double GetLeftMin() const {
    return cumulative_min_constraints_left_to_right
        [index_min_constraints_left_to_right];
  }
  double GetLeftMax() const {
    return cumulative_max_constraints_left_to_right
        [index_max_constraints_left_to_right];
  }
};

// Per-feature constraints made of a piecewise min list and a piecewise max
// list, plus two flags recording which list has been marked for
// recomputation. The cumulative view is `mutable` because it is rebuilt and
// advanced through the const FeatureConstraint interface.
struct AdvancedFeatureConstraints : FeatureConstraint {
  FeatureMinOrMaxConstraints min_constraints;
  FeatureMinOrMaxConstraints max_constraints;
  mutable CumulativeFeatureConstraint cumulative_feature_constraint;
  bool min_constraints_to_be_recomputed = false;
  bool max_constraints_to_be_recomputed = false;

  // Rebuilds the cumulative view from the current min/max lists.
  void InitCumulativeConstraints(bool REVERSE) const final {
    cumulative_feature_constraint =
        CumulativeFeatureConstraint(min_constraints, max_constraints, REVERSE);
  }

  // Advances the cursors of the cumulative view to `threshold`.
  void Update(int threshold) const final {
    cumulative_feature_constraint.Update(threshold);
  }

  FeatureMinOrMaxConstraints &GetMinConstraints() { return min_constraints; }

  FeatureMinOrMaxConstraints &GetMaxConstraints() { return max_constraints; }

  // True as soon as either list holds more than one piece.
  bool ConstraintDifferentDependingOnThreshold() const final {
    return min_constraints.Size() > 1 || max_constraints.Size() > 1;
  }

  // Bounds for the right side, read from the cumulative view.
  BasicConstraint RightToBasicConstraint() const final {
    return BasicConstraint(cumulative_feature_constraint.GetRightMin(),
                           cumulative_feature_constraint.GetRightMax());
  }

  // Bounds for the left side, read from the cumulative view.
  BasicConstraint LeftToBasicConstraint() const final {
    return BasicConstraint(cumulative_feature_constraint.GetLeftMin(),
                           cumulative_feature_constraint.GetLeftMax());
  }

  // Collapses both lists to a single unrestricted piece. The recompute flags
  // and the cumulative view are not touched here.
  void Reset() {
    min_constraints.Reset(-std::numeric_limits<double>::max());
    max_constraints.Reset(std::numeric_limits<double>::max());
  }

  // Caps every max piece at new_max; optionally marks the max list for
  // recomputation.
  void UpdateMax(double new_max, bool trigger_a_recompute) {
    if (trigger_a_recompute) {
      max_constraints_to_be_recomputed = true;
    }
    max_constraints.UpdateMax(new_max);
  }

  bool FeatureMaxConstraintsToBeUpdated() const {
    return max_constraints_to_be_recomputed;
  }

  bool FeatureMinConstraintsToBeUpdated() const {
    return min_constraints_to_be_recomputed;
  }

  // Clears both recompute flags.
  void ResetUpdates() {
    min_constraints_to_be_recomputed = false;
    max_constraints_to_be_recomputed = false;
  }

  // Floors every min piece at new_min; optionally marks the min list for
  // recomputation.
  void UpdateMin(double new_min, bool trigger_a_recompute) {
    if (trigger_a_recompute) {
      min_constraints_to_be_recomputed = true;
    }
    min_constraints.UpdateMin(new_min);
  }
};

class LeafConstraintsBase {
 public:
  virtual ~LeafConstraintsBase() {}
332
333
  virtual const ConstraintEntry* Get(int leaf_idx) = 0;
  virtual FeatureConstraint* GetFeatureConstraint(int leaf_idx, int feature_index) = 0;
334
  virtual void Reset() = 0;
335
  virtual void BeforeSplit(int leaf, int new_leaf,
336
337
                           int8_t monotone_type) = 0;
  virtual std::vector<int> Update(
338
      bool is_numerical_split,
339
340
341
342
      int leaf, int new_leaf, int8_t monotone_type, double right_output,
      double left_output, int split_feature, const SplitInfo& split_info,
      const std::vector<SplitInfo>& best_split_per_leaf) = 0;

343
344
345
346
347
348
349
350
351
352
353
354
355
  virtual void GoUpToFindConstrainingLeaves(
      int, int,
      std::vector<int>*,
      std::vector<uint32_t>*,
      std::vector<bool>*,
      FeatureMinOrMaxConstraints*, bool ,
      uint32_t, uint32_t, uint32_t) {}

  virtual void RecomputeConstraintsIfNeeded(
      LeafConstraintsBase *constraints_,
      int feature_for_constraint, int leaf_idx, uint32_t it_end) = 0;

  inline static LeafConstraintsBase* Create(const Config* config, int num_leaves, int num_features);
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371

  double ComputeMonotoneSplitGainPenalty(int leaf_index, double penalization) {
    int depth = tree_->leaf_depth(leaf_index);
    if (penalization >= depth + 1.) {
      return kEpsilon;
    }
    if (penalization <= 1.) {
      return 1. - penalization / pow(2., depth) + kEpsilon;
    }
    return 1. - pow(2, penalization - 1. - depth) + kEpsilon;
  }

  void ShareTreePointer(const Tree* tree) {
    tree_ = tree;
  }

372
 protected:
373
  const Tree* tree_;
374
375
};

376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
// used by AdvancedLeafConstraints
struct AdvancedConstraintEntry : ConstraintEntry {
  std::vector<AdvancedFeatureConstraints> constraints;

  AdvancedConstraintEntry *clone() const final {
    return new AdvancedConstraintEntry(*this);
  };

  void RecomputeConstraintsIfNeeded(LeafConstraintsBase *constraints_,
                                    int feature_for_constraint, int leaf_idx,
                                    uint32_t it_end) final {
    if (constraints[feature_for_constraint]
            .FeatureMinConstraintsToBeUpdated() ||
        constraints[feature_for_constraint]
            .FeatureMaxConstraintsToBeUpdated()) {
      FeatureMinOrMaxConstraints &constraints_to_be_updated =
          constraints[feature_for_constraint].FeatureMinConstraintsToBeUpdated()
              ? constraints[feature_for_constraint].GetMinConstraints()
              : constraints[feature_for_constraint].GetMaxConstraints();

      constraints_to_be_updated.Reset(
          constraints[feature_for_constraint].FeatureMinConstraintsToBeUpdated()
              ? -std::numeric_limits<double>::max()
              : std::numeric_limits<double>::max());

      std::vector<int> features_of_splits_going_up_from_original_leaf =
          std::vector<int>();
      std::vector<uint32_t> thresholds_of_splits_going_up_from_original_leaf =
          std::vector<uint32_t>();
      std::vector<bool> was_original_leaf_right_child_of_split =
          std::vector<bool>();
      constraints_->GoUpToFindConstrainingLeaves(
          feature_for_constraint, leaf_idx,
          &features_of_splits_going_up_from_original_leaf,
          &thresholds_of_splits_going_up_from_original_leaf,
          &was_original_leaf_right_child_of_split, &constraints_to_be_updated,
          constraints[feature_for_constraint]
              .FeatureMinConstraintsToBeUpdated(),
          0, it_end, it_end);
      constraints[feature_for_constraint].ResetUpdates();
    }
  }

  // for each feature, an array of constraints needs to be stored
  explicit AdvancedConstraintEntry(int num_features) {
    constraints.resize(num_features);
  }

  void Reset() final {
    for (size_t i = 0; i < constraints.size(); ++i) {
      constraints[i].Reset();
    }
  }

  void UpdateMin(double new_min) final {
    for (size_t i = 0; i < constraints.size(); ++i) {
      constraints[i].UpdateMin(new_min, false);
    }
  }

  void UpdateMax(double new_max) final {
    for (size_t i = 0; i < constraints.size(); ++i) {
      constraints[i].UpdateMax(new_max, false);
    }
  }

  bool UpdateMinAndReturnBoolIfChanged(double new_min) final {
    for (size_t i = 0; i < constraints.size(); ++i) {
      constraints[i].UpdateMin(new_min, true);
    }
    // even if nothing changed, this could have been unconstrained so it needs
    // to be recomputed from the beginning
    return true;
  }

  bool UpdateMaxAndReturnBoolIfChanged(double new_max) final {
    for (size_t i = 0; i < constraints.size(); ++i) {
      constraints[i].UpdateMax(new_max, true);
    }
    // even if nothing changed, this could have been unconstrained so it needs
    // to be recomputed from the beginning
    return true;
  }

  FeatureConstraint *GetFeatureConstraint(int feature_index) final {
    return &constraints[feature_index];
  }
};

465
class BasicLeafConstraints : public LeafConstraintsBase {
466
 public:
467
  explicit BasicLeafConstraints(int num_leaves) : num_leaves_(num_leaves) {
468
    for (int i = 0; i < num_leaves; ++i) {
469
      entries_.emplace_back(new BasicConstraintEntry());
470
    }
471
  }
Nikita Titov's avatar
Nikita Titov committed
472

473
  void Reset() override {
474
    for (auto& entry : entries_) {
475
      entry->Reset();
476
477
    }
  }
Nikita Titov's avatar
Nikita Titov committed
478

479
480
481
482
483
  void RecomputeConstraintsIfNeeded(
      LeafConstraintsBase* constraints_,
      int feature_for_constraint, int leaf_idx, uint32_t it_end) override {
    entries_[~leaf_idx]->RecomputeConstraintsIfNeeded(constraints_, feature_for_constraint, leaf_idx, it_end);
  }
484

485
486
487
  void BeforeSplit(int, int, int8_t) override {}

  std::vector<int> Update(bool is_numerical_split, int leaf, int new_leaf,
488
489
490
                          int8_t monotone_type, double right_output,
                          double left_output, int, const SplitInfo& ,
                          const std::vector<SplitInfo>&) override {
491
    entries_[new_leaf].reset(entries_[leaf]->clone());
492
493
494
    if (is_numerical_split) {
      double mid = (left_output + right_output) / 2.0f;
      if (monotone_type < 0) {
495
496
        entries_[leaf]->UpdateMin(mid);
        entries_[new_leaf]->UpdateMax(mid);
497
      } else if (monotone_type > 0) {
498
499
        entries_[leaf]->UpdateMax(mid);
        entries_[new_leaf]->UpdateMin(mid);
500
501
      }
    }
502
    return std::vector<int>();
503
504
  }

505
  const ConstraintEntry* Get(int leaf_idx) override { return entries_[leaf_idx].get(); }
506
507
508
509

  FeatureConstraint* GetFeatureConstraint(int leaf_idx, int feature_index) final {
    return entries_[leaf_idx]->GetFeatureConstraint(feature_index);
  }
510

511
 protected:
512
  int num_leaves_;
513
  std::vector<std::unique_ptr<ConstraintEntry>> entries_;
514
515
};

516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
class IntermediateLeafConstraints : public BasicLeafConstraints {
 public:
  // Builds the basic per-leaf entries, keeps the config pointer, and sizes the
  // bookkeeping vectors:
  //  - leaves_to_update_            : capacity for num_leaves_ indices
  //  - node_parent_                 : num_leaves_ - 1 slots, all set to -1
  //  - leaf_is_in_monotone_subtree_ : num_leaves_ flags, all false
  explicit IntermediateLeafConstraints(const Config* config, int num_leaves)
      : BasicLeafConstraints(num_leaves), config_(config) {
    leaves_to_update_.reserve(num_leaves_);
    node_parent_.resize(num_leaves_ - 1, -1);
    leaf_is_in_monotone_subtree_.resize(num_leaves_, false);
  }

  // Resets the per-leaf entries through the base class, then puts the
  // bookkeeping back to its initial values: no pending leaves, every parent
  // slot at -1, no leaf flagged as being in a monotone subtree.
  void Reset() override {
    BasicLeafConstraints::Reset();
    leaves_to_update_.clear();
    std::fill_n(node_parent_.begin(), num_leaves_ - 1, -1);
    std::fill_n(leaf_is_in_monotone_subtree_.begin(), num_leaves_, false);
  }

532
  void BeforeSplit(int leaf, int new_leaf,
533
534
535
536
537
538
539
540
541
                   int8_t monotone_type) override {
    if (monotone_type != 0 || leaf_is_in_monotone_subtree_[leaf]) {
      leaf_is_in_monotone_subtree_[leaf] = true;
      leaf_is_in_monotone_subtree_[new_leaf] = true;
    }
#ifdef DEBUG
    CHECK_GE(new_leaf - 1, 0);
    CHECK_LT(static_cast<size_t>(new_leaf - 1), node_parent_.size());
#endif
542
    node_parent_[new_leaf - 1] = tree_->leaf_parent(leaf);
543
544
545
546
547
  }

  void UpdateConstraintsWithOutputs(bool is_numerical_split, int leaf,
                                    int new_leaf, int8_t monotone_type,
                                    double right_output, double left_output) {
548
    entries_[new_leaf].reset(entries_[leaf]->clone());
549
550
    if (is_numerical_split) {
      if (monotone_type < 0) {
551
552
        entries_[leaf]->UpdateMin(right_output);
        entries_[new_leaf]->UpdateMax(left_output);
553
      } else if (monotone_type > 0) {
554
555
        entries_[leaf]->UpdateMax(right_output);
        entries_[new_leaf]->UpdateMin(left_output);
556
557
558
559
      }
    }
  }

560
  std::vector<int> Update(bool is_numerical_split, int leaf,
561
562
563
                          int new_leaf, int8_t monotone_type,
                          double right_output, double left_output,
                          int split_feature, const SplitInfo& split_info,
564
                          const std::vector<SplitInfo>& best_split_per_leaf) final {
565
566
567
568
569
570
    leaves_to_update_.clear();
    if (leaf_is_in_monotone_subtree_[leaf]) {
      UpdateConstraintsWithOutputs(is_numerical_split, leaf, new_leaf,
                                   monotone_type, right_output, left_output);

      // Initialize variables to store information while going up the tree
571
      int depth = tree_->leaf_depth(new_leaf) - 1;
572
573
574
575
576
577
578
579
580

      std::vector<int> features_of_splits_going_up_from_original_leaf;
      std::vector<uint32_t> thresholds_of_splits_going_up_from_original_leaf;
      std::vector<bool> was_original_leaf_right_child_of_split;

      features_of_splits_going_up_from_original_leaf.reserve(depth);
      thresholds_of_splits_going_up_from_original_leaf.reserve(depth);
      was_original_leaf_right_child_of_split.reserve(depth);

581
      GoUpToFindLeavesToUpdate(tree_->leaf_parent(new_leaf),
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
                               &features_of_splits_going_up_from_original_leaf,
                               &thresholds_of_splits_going_up_from_original_leaf,
                               &was_original_leaf_right_child_of_split,
                               split_feature, split_info, split_info.threshold,
                               best_split_per_leaf);
    }
    return leaves_to_update_;
  }

  // Decides whether the split met while going up should be taken into
  // account, i.e. whether its opposite child is worth exploring and the split
  // worth recording.
  //
  // Returns false for a categorical split: such splits are not handled here.
  // For a numerical split, returns false when a split on the same inner
  // feature was already recorded with the original leaf on the same side
  // (right/left) as now, and true otherwise.
  //
  // The two vectors are parallel: entry i of each describes the same recorded
  // split.
  bool OppositeChildShouldBeUpdated(
      bool is_split_numerical,
      const std::vector<int>& features_of_splits_going_up_from_original_leaf,
      int inner_feature,
      const std::vector<bool>& was_original_leaf_right_child_of_split,
      bool is_in_right_child) {
    if (!is_split_numerical) {
      return false;
    }
    // only branches containing leaves that are contiguous to the original
    // leaf need to be updated
    // therefore, for the same feature, there is no use going down from the
    // second time going up on the right (or on the left)
    for (size_t split_idx = 0;
         split_idx < features_of_splits_going_up_from_original_leaf.size();
         ++split_idx) {
      if (features_of_splits_going_up_from_original_leaf[split_idx] ==
              inner_feature &&
          (was_original_leaf_right_child_of_split[split_idx] ==
           is_in_right_child)) {
        return false;
      }
    }
    return true;
  }

  // Recursive function that goes up the tree, and then down to find leaves that
  // have constraints to be updated
  void GoUpToFindLeavesToUpdate(
625
      int node_idx,
626
627
628
629
630
631
632
633
634
635
636
637
      std::vector<int>* features_of_splits_going_up_from_original_leaf,
      std::vector<uint32_t>* thresholds_of_splits_going_up_from_original_leaf,
      std::vector<bool>* was_original_leaf_right_child_of_split,
      int split_feature, const SplitInfo& split_info, uint32_t split_threshold,
      const std::vector<SplitInfo>& best_split_per_leaf) {
#ifdef DEBUG
    CHECK_GE(node_idx, 0);
    CHECK_LT(static_cast<size_t>(node_idx), node_parent_.size());
#endif
    int parent_idx = node_parent_[node_idx];
    // if not at the root
    if (parent_idx != -1) {
638
639
      int inner_feature = tree_->split_feature_inner(parent_idx);
      int feature = tree_->split_feature(parent_idx);
640
      int8_t monotone_type = config_->monotone_constraints[feature];
641
642
      bool is_in_right_child = tree_->right_child(parent_idx) == node_idx;
      bool is_split_numerical = tree_->IsNumericalSplit(parent_idx);
643
644
645
646
647
648
649
650
651
652
653
654
655
656

      // this is just an optimisation not to waste time going down in subtrees
      // where there won't be any leaf to update
      bool opposite_child_should_be_updated = OppositeChildShouldBeUpdated(
          is_split_numerical, *features_of_splits_going_up_from_original_leaf,
          inner_feature, *was_original_leaf_right_child_of_split,
          is_in_right_child);

      if (opposite_child_should_be_updated) {
        // if there is no monotone constraint on a split,
        // then there is no relationship between its left and right leaves' values
        if (monotone_type != 0) {
          // these variables correspond to the current split we encounter going
          // up the tree
657
658
          int left_child_idx = tree_->left_child(parent_idx);
          int right_child_idx = tree_->right_child(parent_idx);
659
660
661
662
663
664
665
666
667
668
669
          bool left_child_is_curr_idx = (left_child_idx == node_idx);
          int opposite_child_idx =
              (left_child_is_curr_idx) ? right_child_idx : left_child_idx;
          bool update_max_constraints_in_opposite_child_leaves =
              (monotone_type < 0) ? left_child_is_curr_idx
                                  : !left_child_is_curr_idx;

          // the opposite child needs to be updated
          // so the code needs to go down in the the opposite child
          // to see which leaves' constraints need to be updated
          GoDownToFindLeavesToUpdate(
670
              opposite_child_idx,
671
672
673
674
675
676
677
678
679
680
681
682
              *features_of_splits_going_up_from_original_leaf,
              *thresholds_of_splits_going_up_from_original_leaf,
              *was_original_leaf_right_child_of_split,
              update_max_constraints_in_opposite_child_leaves, split_feature,
              split_info, true, true, split_threshold, best_split_per_leaf);
        }

        // if opposite_child_should_be_updated, then it means the path to come up there was relevant,
        // i.e. that it will be helpful going down to determine which leaf
        // is actually contiguous to the original 2 leaves and should be updated
        // so the variables associated with the split need to be recorded
        was_original_leaf_right_child_of_split->push_back(
683
            tree_->right_child(parent_idx) == node_idx);
684
        thresholds_of_splits_going_up_from_original_leaf->push_back(
685
            tree_->threshold_in_bin(parent_idx));
686
        features_of_splits_going_up_from_original_leaf->push_back(
687
            tree_->split_feature_inner(parent_idx));
688
689
690
691
      }

      // since current node is not the root, keep going up
      GoUpToFindLeavesToUpdate(
692
          parent_idx, features_of_splits_going_up_from_original_leaf,
693
694
695
696
697
698
699
          thresholds_of_splits_going_up_from_original_leaf,
          was_original_leaf_right_child_of_split, split_feature, split_info,
          split_threshold, best_split_per_leaf);
    }
  }

  void GoDownToFindLeavesToUpdate(
700
      int node_idx,
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
      const std::vector<int>& features_of_splits_going_up_from_original_leaf,
      const std::vector<uint32_t>&
          thresholds_of_splits_going_up_from_original_leaf,
      const std::vector<bool>& was_original_leaf_right_child_of_split,
      bool update_max_constraints, int split_feature,
      const SplitInfo& split_info, bool use_left_leaf, bool use_right_leaf,
      uint32_t split_threshold,
      const std::vector<SplitInfo>& best_split_per_leaf) {
    // if leaf
    if (node_idx < 0) {
      int leaf_idx = ~node_idx;

      // splits that are not to be used shall not be updated,
      // included leaf at max depth
      if (best_split_per_leaf[leaf_idx].gain == kMinScore) {
        return;
      }

      std::pair<double, double> min_max_constraints;
      bool something_changed = false;
      // if the current leaf is contiguous with both the new right leaf and the new left leaf
      // then it may need to be greater than the max of the 2 or smaller than the min of the 2
      // otherwise, if the current leaf is contiguous with only one of the 2 new leaves,
      // then it may need to be greater or smaller than it
      if (use_right_leaf && use_left_leaf) {
        min_max_constraints =
            std::minmax(split_info.right_output, split_info.left_output);
      } else if (use_right_leaf && !use_left_leaf) {
        min_max_constraints = std::pair<double, double>(
            split_info.right_output, split_info.right_output);
      } else {
        min_max_constraints = std::pair<double, double>(split_info.left_output,
                                                        split_info.left_output);
      }

#ifdef DEBUG
      if (update_max_constraints) {
738
        CHECK_GE(min_max_constraints.first, tree_->LeafOutput(leaf_idx));
739
      } else {
740
        CHECK_LE(min_max_constraints.second, tree_->LeafOutput(leaf_idx));
741
742
743
744
745
      }
#endif
      // depending on which split made the current leaf and the original leaves contiguous,
      // either the min constraint or the max constraint of the current leaf need to be updated
      if (!update_max_constraints) {
746
        something_changed = entries_[leaf_idx]->UpdateMinAndReturnBoolIfChanged(
747
748
            min_max_constraints.second);
      } else {
749
        something_changed = entries_[leaf_idx]->UpdateMaxAndReturnBoolIfChanged(
750
751
752
753
754
755
756
757
758
759
760
            min_max_constraints.first);
      }
      // If constraints were not updated, then there is no need to update the leaf
      if (!something_changed) {
        return;
      }
      leaves_to_update_.push_back(leaf_idx);

    } else {  // if node
      // check if the children are contiguous with the original leaf
      std::pair<bool, bool> keep_going_left_right = ShouldKeepGoingLeftRight(
761
          node_idx, features_of_splits_going_up_from_original_leaf,
762
763
          thresholds_of_splits_going_up_from_original_leaf,
          was_original_leaf_right_child_of_split);
764
765
766
      int inner_feature = tree_->split_feature_inner(node_idx);
      uint32_t threshold = tree_->threshold_in_bin(node_idx);
      bool is_split_numerical = tree_->IsNumericalSplit(node_idx);
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
      bool use_left_leaf_for_update_right = true;
      bool use_right_leaf_for_update_left = true;
      // if the split is on the same feature (categorical variables not supported)
      // then depending on the threshold,
      // the current left child may not be contiguous with the original right leaf,
      // or the current right child may not be contiguous with the original left leaf
      if (is_split_numerical && inner_feature == split_feature) {
        if (threshold >= split_threshold) {
          use_left_leaf_for_update_right = false;
        }
        if (threshold <= split_threshold) {
          use_right_leaf_for_update_left = false;
        }
      }

      // go down left
      if (keep_going_left_right.first) {
        GoDownToFindLeavesToUpdate(
785
            tree_->left_child(node_idx),
786
787
788
789
790
791
792
793
794
795
            features_of_splits_going_up_from_original_leaf,
            thresholds_of_splits_going_up_from_original_leaf,
            was_original_leaf_right_child_of_split, update_max_constraints,
            split_feature, split_info, use_left_leaf,
            use_right_leaf_for_update_left && use_right_leaf, split_threshold,
            best_split_per_leaf);
      }
      // go down right
      if (keep_going_left_right.second) {
        GoDownToFindLeavesToUpdate(
796
            tree_->right_child(node_idx),
797
798
799
800
801
802
803
804
805
806
807
            features_of_splits_going_up_from_original_leaf,
            thresholds_of_splits_going_up_from_original_leaf,
            was_original_leaf_right_child_of_split, update_max_constraints,
            split_feature, split_info,
            use_left_leaf_for_update_right && use_left_leaf, use_right_leaf,
            split_threshold, best_split_per_leaf);
      }
    }
  }

  std::pair<bool, bool> ShouldKeepGoingLeftRight(
808
      int node_idx,
809
810
811
812
      const std::vector<int>& features_of_splits_going_up_from_original_leaf,
      const std::vector<uint32_t>&
          thresholds_of_splits_going_up_from_original_leaf,
      const std::vector<bool>& was_original_leaf_right_child_of_split) {
813
814
815
    int inner_feature = tree_->split_feature_inner(node_idx);
    uint32_t threshold = tree_->threshold_in_bin(node_idx);
    bool is_split_numerical = tree_->IsNumericalSplit(node_idx);
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848

    bool keep_going_right = true;
    bool keep_going_left = true;
    // left and right nodes are checked to find out if they are contiguous with
    // the original leaves if so the algorithm should keep going down these nodes
    // to update constraints
    if (is_split_numerical) {
      for (size_t i = 0;
           i < features_of_splits_going_up_from_original_leaf.size(); ++i) {
        if (features_of_splits_going_up_from_original_leaf[i] ==
            inner_feature) {
          if (threshold >=
                  thresholds_of_splits_going_up_from_original_leaf[i] &&
              !was_original_leaf_right_child_of_split[i]) {
            keep_going_right = false;
            if (!keep_going_left) {
              break;
            }
          }
          if (threshold <=
                  thresholds_of_splits_going_up_from_original_leaf[i] &&
              was_original_leaf_right_child_of_split[i]) {
            keep_going_left = false;
            if (!keep_going_right) {
              break;
            }
          }
        }
      }
    }
    return std::pair<bool, bool>(keep_going_left, keep_going_right);
  }

 protected:
  const Config* config_;
  std::vector<int> leaves_to_update_;
  // add parent node information
  std::vector<int> node_parent_;
  // Keeps track of the monotone splits above the leaf
  std::vector<bool> leaf_is_in_monotone_subtree_;
};

858
859
860
861
862
863
class AdvancedLeafConstraints : public IntermediateLeafConstraints {
 public:
  AdvancedLeafConstraints(const Config *config, int num_leaves,
                          int num_features)
      : IntermediateLeafConstraints(config, num_leaves) {
    for (int i = 0; i < num_leaves; ++i) {
864
      entries_[i].reset(new AdvancedConstraintEntry(num_features));
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
    }
  }

  // Merges the value `extremum`, which applies on the threshold range
  // [it_start, it_end), into the piecewise-constant constraint stored in
  // `feature_constraint` (parallel vectors `thresholds` and `constraints`).
  //
  // Invariant: at any point in time, for an index i, constraints[i] has to be
  // valid on [thresholds[i], thresholds[i + 1]) (or [thresholds[i], +inf) if
  // i is the last index of the array).
  //
  // use_max_operator: when true existing values are combined with `extremum`
  //   through std::max, otherwise through std::min.
  // last_threshold: when it_end equals this value, no closing breakpoint is
  //   appended after the loop.
  // NOTE(review): constraints.back() is read after the loop, so the vectors
  // are presumably never empty when this is called - confirm with callers.
  void UpdateConstraints(FeatureMinOrMaxConstraints* feature_constraint,
                         double extremum, uint32_t it_start, uint32_t it_end,
                         bool use_max_operator, uint32_t last_threshold) {
    bool start_done = false;
    bool end_done = false;
    // previous constraint have to be tracked
    // for example when adding a constraints cstr2 on thresholds [1:2),
    // on an existing constraints cstr1 on thresholds [0, +inf),
    // the thresholds and constraints must become
    // [0, 1, 2] and [cstr1, cstr2, cstr1]
    // so since we loop through thresholds only once,
    // the previous constraint that still applies needs to be recorded
    double previous_constraint = use_max_operator
      ? -std::numeric_limits<double>::max()
      : std::numeric_limits<double>::max();
    double current_constraint;
    for (size_t i = 0; i < feature_constraint->thresholds.size(); ++i) {
      // value of this segment before any modification made below
      current_constraint = feature_constraint->constraints[i];
      // easy case when the thresholds match
      if (feature_constraint->thresholds[i] == it_start) {
        feature_constraint->constraints[i] =
            (use_max_operator)
                ? std::max(extremum, feature_constraint->constraints[i])
                : std::min(extremum, feature_constraint->constraints[i]);
        start_done = true;
      }
      if (feature_constraint->thresholds[i] > it_start) {
        // existing constraint is updated if there is a need for it
        if (feature_constraint->thresholds[i] < it_end) {
          feature_constraint->constraints[i] =
              (use_max_operator)
                  ? std::max(extremum, feature_constraint->constraints[i])
                  : std::min(extremum, feature_constraint->constraints[i]);
        }
        // when thresholds don't match, a new threshold
        // and a new constraint may need to be inserted
        if (!start_done) {
          start_done = true;
          if ((use_max_operator && extremum > previous_constraint) ||
              (!use_max_operator && extremum < previous_constraint)) {
            feature_constraint->constraints.insert(
                feature_constraint->constraints.begin() + i, extremum);
            feature_constraint->thresholds.insert(
                feature_constraint->thresholds.begin() + i, it_start);
            // step over the inserted element so that index i designates the
            // same segment as before the insertion
            ++i;
          }
        }
      }
      // easy case when the end thresholds match
      if (feature_constraint->thresholds[i] == it_end) {
        end_done = true;
        break;
      }
      // if they don't then, the previous constraint needs to be added back
      // where the current one ends
      if (feature_constraint->thresholds[i] > it_end) {
        if (i != 0 &&
            previous_constraint != feature_constraint->constraints[i - 1]) {
          feature_constraint->constraints.insert(
              feature_constraint->constraints.begin() + i, previous_constraint);
          feature_constraint->thresholds.insert(
              feature_constraint->thresholds.begin() + i, it_end);
        }
        end_done = true;
        break;
      }
      // If 2 successive constraints are the same then the second one may as
      // well be deleted
      if (i != 0 && feature_constraint->constraints[i] ==
                        feature_constraint->constraints[i - 1]) {
        feature_constraint->constraints.erase(
            feature_constraint->constraints.begin() + i);
        feature_constraint->thresholds.erase(
            feature_constraint->thresholds.begin() + i);
        // the element at i was removed, so the next iteration must look at
        // index i again
        --i;
      }
      previous_constraint = current_constraint;
    }
    // if the loop didn't get to an index greater than it_start, it needs to be
    // added at the end
    if (!start_done) {
      if ((use_max_operator &&
           extremum > feature_constraint->constraints.back()) ||
          (!use_max_operator &&
           extremum < feature_constraint->constraints.back())) {
        feature_constraint->constraints.push_back(extremum);
        feature_constraint->thresholds.push_back(it_start);
      } else {
        // nothing was inserted, so there is no range to close either
        end_done = true;
      }
    }
    // if we didn't get to an index after it_end, then the previous constraint
    // needs to be set back, unless it_end goes up to the last bin of the feature
    if (!end_done && it_end != last_threshold &&
        previous_constraint != feature_constraint->constraints.back()) {
      feature_constraint->constraints.push_back(previous_constraint);
      feature_constraint->thresholds.push_back(it_end);
    }
  }

  // Used while computing constraints in the monotone precise mode.
  // Tells, for a split on `feature`, which of its two children are worth
  // visiting: a branch may hold no relevant constraint (for example, if a
  // branch with bigger values is also constraining the original leaf, then
  // visiting the branch with smaller values is useless).
  //
  // Returns (visit_left, visit_right):
  //   - (true, true) when split_feature_is_inner_feature is set, or when
  //     `feature` carries no monotone constraint (type 0);
  //   - (true, false) for type -1 with min constraints being updated, or
  //     type 1 with max constraints being updated;
  //   - (false, true) in every other case.
  std::pair<bool, bool>
  LeftRightContainsRelevantInformation(bool min_constraints_to_be_updated,
                                       int feature,
                                       bool split_feature_is_inner_feature) {
    const std::pair<bool, bool> both_children(true, true);
    if (split_feature_is_inner_feature) {
      return both_children;
    }
    // the configuration is only consulted once the shortcut above is ruled out
    const int8_t direction = config_->monotone_constraints[feature];
    if (direction == 0) {
      return both_children;
    }
    const bool decreasing_and_min =
        (direction == -1) && min_constraints_to_be_updated;
    const bool increasing_and_max =
        (direction == 1) && !min_constraints_to_be_updated;
    const bool only_left_is_relevant = decreasing_and_min || increasing_and_max;
    // exactly one side is kept when the feature is monotone
    return std::pair<bool, bool>(only_left_is_relevant, !only_left_is_relevant);
  }

  // this function goes down in a subtree to find the
  // constraints that would apply on the original leaf
  void GoDownToFindConstrainingLeaves(
      int feature_for_constraint, int root_monotone_feature, int node_idx,
      bool min_constraints_to_be_updated, uint32_t it_start, uint32_t it_end,
      const std::vector<int> &features_of_splits_going_up_from_original_leaf,
      const std::vector<uint32_t> &
          thresholds_of_splits_going_up_from_original_leaf,
      const std::vector<bool> &was_original_leaf_right_child_of_split,
      FeatureMinOrMaxConstraints* feature_constraint, uint32_t last_threshold) {
    double extremum;
    // if leaf, then constraints need to be updated according to its value
    if (node_idx < 0) {
      extremum = tree_->LeafOutput(~node_idx);
#ifdef DEBUG
      CHECK(it_start < it_end);
#endif
      UpdateConstraints(feature_constraint, extremum, it_start, it_end,
                        min_constraints_to_be_updated, last_threshold);
    } else {  // if node, keep going down the tree
      // check if the children are contiguous to the original leaf and therefore
      // potentially constraining
      std::pair<bool, bool> keep_going_left_right = ShouldKeepGoingLeftRight(
          node_idx, features_of_splits_going_up_from_original_leaf,
          thresholds_of_splits_going_up_from_original_leaf,
          was_original_leaf_right_child_of_split);
      int inner_feature = tree_->split_feature_inner(node_idx);
      int feature = tree_->split_feature(node_idx);
      uint32_t threshold = tree_->threshold_in_bin(node_idx);

      bool split_feature_is_inner_feature =
          (inner_feature == feature_for_constraint);
      bool split_feature_is_monotone_feature =
          (root_monotone_feature == feature_for_constraint);
      // make sure that both children contain values that could
      // potentially help determine the true constraints for the original leaf
      std::pair<bool, bool> left_right_contain_relevant_information =
          LeftRightContainsRelevantInformation(
              min_constraints_to_be_updated, feature,
              split_feature_is_inner_feature &&
                  !split_feature_is_monotone_feature);
      // if both children are contiguous to the original leaf
      // but one contains values greater than the other
      // then no need to go down in both
      if (keep_going_left_right.first &&
          (left_right_contain_relevant_information.first ||
           !keep_going_left_right.second)) {
        // update thresholds based on going left
        uint32_t new_it_end = split_feature_is_inner_feature
                                  ? std::min(threshold + 1, it_end)
                                  : it_end;
        GoDownToFindConstrainingLeaves(
            feature_for_constraint, root_monotone_feature,
            tree_->left_child(node_idx), min_constraints_to_be_updated,
            it_start, new_it_end,
            features_of_splits_going_up_from_original_leaf,
            thresholds_of_splits_going_up_from_original_leaf,
            was_original_leaf_right_child_of_split, feature_constraint,
            last_threshold);
      }
      if (keep_going_left_right.second &&
          (left_right_contain_relevant_information.second ||
           !keep_going_left_right.first)) {
        // update thresholds based on going right
        uint32_t new_it_start = split_feature_is_inner_feature
                                    ? std::max(threshold + 1, it_start)
                                    : it_start;
        GoDownToFindConstrainingLeaves(
            feature_for_constraint, root_monotone_feature,
            tree_->right_child(node_idx), min_constraints_to_be_updated,
            new_it_start, it_end,
            features_of_splits_going_up_from_original_leaf,
            thresholds_of_splits_going_up_from_original_leaf,
            was_original_leaf_right_child_of_split, feature_constraint,
            last_threshold);
      }
    }
  }

  // this function is only used if the monotone precise mode is enabled
  // it recursively goes up the tree then down to find leaves that
  // are constraining the current leaf
  //
  // node_idx: current position; a negative value denotes a leaf (index
  //   ~node_idx), otherwise an internal node.
  // features_/thresholds_/was_original_leaf_right_child_of_split: parallel
  //   output vectors; one entry is appended for every parent split that is
  //   found relevant on the way up.
  // feature_constraint: constraint being built for `feature_for_constraint`;
  //   it is updated by the downward visits triggered from here.
  // it_start / it_end: threshold range still of interest; it is narrowed
  //   whenever a parent split is numerical and made on
  //   `feature_for_constraint`.
  // Nothing is done when the parent index is -1.
  void GoUpToFindConstrainingLeaves(
      int feature_for_constraint, int node_idx,
      std::vector<int>* features_of_splits_going_up_from_original_leaf,
      std::vector<uint32_t>* thresholds_of_splits_going_up_from_original_leaf,
      std::vector<bool>* was_original_leaf_right_child_of_split,
      FeatureMinOrMaxConstraints* feature_constraint,
      bool min_constraints_to_be_updated, uint32_t it_start, uint32_t it_end,
      uint32_t last_threshold) final {
    // leaves and internal nodes store their parent in different places
    int parent_idx =
        (node_idx < 0) ? tree_->leaf_parent(~node_idx) : node_parent_[node_idx];
    // if not at the root
    if (parent_idx != -1) {
      int inner_feature = tree_->split_feature_inner(parent_idx);
      int feature = tree_->split_feature(parent_idx);
      int8_t monotone_type = config_->monotone_constraints[feature];
      bool is_in_right_child = tree_->right_child(parent_idx) == node_idx;
      bool is_split_numerical = tree_->IsNumericalSplit(parent_idx);
      uint32_t threshold = tree_->threshold_in_bin(parent_idx);

      // by going up, more information about the position of the
      // original leaf are gathered so the starting and ending
      // thresholds can be updated, which will save some time later
      if ((feature_for_constraint == inner_feature) && is_split_numerical) {
        if (is_in_right_child) {
          it_start = std::max(threshold, it_start);
        } else {
          it_end = std::min(threshold + 1, it_end);
        }
#ifdef DEBUG
        CHECK(it_start < it_end);
#endif
      }

      // this is just an optimisation not to waste time going down in subtrees
      // where there won't be any new constraining leaf
      bool opposite_child_necessary_to_update_constraints =
          OppositeChildShouldBeUpdated(
              is_split_numerical,
              *features_of_splits_going_up_from_original_leaf, inner_feature,
              *was_original_leaf_right_child_of_split, is_in_right_child);

      if (opposite_child_necessary_to_update_constraints) {
        // if there is no monotone constraint on a split,
        // then there is no relationship between its left and right leaves'
        // values
        if (monotone_type != 0) {
          int left_child_idx = tree_->left_child(parent_idx);
          int right_child_idx = tree_->right_child(parent_idx);
          bool left_child_is_curr_idx = (left_child_idx == node_idx);

          // for a negative monotone type the left child receives the min
          // constraint, otherwise the right child does
          bool update_min_constraints_in_curr_child_leaf =
              (monotone_type < 0) ? left_child_is_curr_idx
                                  : !left_child_is_curr_idx;
          // only the kind of constraint (min or max) currently being built
          // is collected from the opposite branch
          if (update_min_constraints_in_curr_child_leaf ==
              min_constraints_to_be_updated) {
            int opposite_child_idx =
                (left_child_is_curr_idx) ? right_child_idx : left_child_idx;

            // go down in the opposite branch to find potential
            // constraining leaves
            GoDownToFindConstrainingLeaves(
                feature_for_constraint, inner_feature, opposite_child_idx,
                min_constraints_to_be_updated, it_start, it_end,
                *features_of_splits_going_up_from_original_leaf,
                *thresholds_of_splits_going_up_from_original_leaf,
                *was_original_leaf_right_child_of_split, feature_constraint,
                last_threshold);
          }
        }
        // if the opposite child should be updated, then it means the path to
        // come up there was relevant,
        // i.e. that it will be helpful going down to determine which leaf
        // is actually contiguous to the original leaf and constraining
        // so the variables associated with the split need to be recorded
        was_original_leaf_right_child_of_split->push_back(is_in_right_child);
        thresholds_of_splits_going_up_from_original_leaf->push_back(threshold);
        features_of_splits_going_up_from_original_leaf->push_back(inner_feature);
      }

      // since the parent is not the root (index 0), keep going up
      if (parent_idx != 0) {
        GoUpToFindConstrainingLeaves(
            feature_for_constraint, parent_idx,
            features_of_splits_going_up_from_original_leaf,
            thresholds_of_splits_going_up_from_original_leaf,
            was_original_leaf_right_child_of_split, feature_constraint,
            min_constraints_to_be_updated, it_start, it_end, last_threshold);
      }
    }
  }
};

1174
// Factory for the leaf-constraint implementation selected by
// config->monotone_constraints_method:
//   "intermediate" -> IntermediateLeafConstraints(config, num_leaves)
//   "advanced"     -> AdvancedLeafConstraints(config, num_leaves, num_features)
//   anything else  -> BasicLeafConstraints(num_leaves)
// The object is allocated with `new` and returned as a raw pointer; releasing
// it is left to the caller.
LeafConstraintsBase* LeafConstraintsBase::Create(const Config* config,
                                                 int num_leaves, int num_features) {
  if (config->monotone_constraints_method == "intermediate") {
    return new IntermediateLeafConstraints(config, num_leaves);
  }
  if (config->monotone_constraints_method == "advanced") {
    return new AdvancedLeafConstraints(config, num_leaves, num_features);
  }
  return new BasicLeafConstraints(num_leaves);
}

}  // namespace LightGBM

#endif  // LIGHTGBM_SRC_TREELEARNER_MONOTONE_CONSTRAINTS_HPP_