/*!
 * Copyright (c) 2020 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for
 * license information.
 */
#ifndef LIGHTGBM_TREELEARNER_MONOTONE_CONSTRAINTS_HPP_
#define LIGHTGBM_TREELEARNER_MONOTONE_CONSTRAINTS_HPP_

#include <algorithm>
#include <cstdint>
#include <limits>
#include <utility>
#include <vector>

#include "split_info.hpp"

namespace LightGBM {

// Closed interval [min, max] that a single leaf's output must stay inside.
// A freshly constructed (or Reset) entry spans the whole finite double range,
// i.e. it imposes no constraint.
struct ConstraintEntry {
  // lowest output value the leaf may take
  double min = -std::numeric_limits<double>::max();
  // highest output value the leaf may take
  double max = std::numeric_limits<double>::max();

  ConstraintEntry() {}

  // Widen the interval back to the unconstrained default by re-running the
  // default member initializers.
  void Reset() { *this = ConstraintEntry(); }

  // Raise the lower bound to new_min if that is tighter; never loosens it.
  void UpdateMin(double new_min) { min = std::max(new_min, min); }

  // Lower the upper bound to new_max if that is tighter; never loosens it.
  void UpdateMax(double new_max) { max = std::min(new_max, max); }

  // Same as UpdateMin, but reports whether the lower bound strictly increased.
  bool UpdateMinAndReturnBoolIfChanged(double new_min) {
    const bool tightened = new_min > min;
    if (tightened) {
      min = new_min;
    }
    return tightened;
  }

  // Same as UpdateMax, but reports whether the upper bound strictly decreased.
  bool UpdateMaxAndReturnBoolIfChanged(double new_max) {
    const bool tightened = new_max < max;
    if (tightened) {
      max = new_max;
    }
    return tightened;
  }
};

// Abstract interface for objects that track, per leaf, the interval its
// output must stay within so that monotone constraints hold while a tree
// is being grown.
class LeafConstraintsBase {
 public:
  virtual ~LeafConstraintsBase() = default;

  // Current output interval of leaf `leaf_idx`.
  virtual const ConstraintEntry& Get(int leaf_idx) const = 0;

  // Drop all accumulated constraints.
  virtual void Reset() = 0;

  // Hook for the split of `leaf` into `leaf` and `new_leaf`; presumably
  // invoked before the tree itself is modified (TODO confirm at call sites).
  virtual void BeforeSplit(const Tree* tree, int leaf, int new_leaf,
                           int8_t monotone_type) = 0;

  // Record the constraints implied by splitting `leaf` into `leaf` and
  // `new_leaf`. Returns the indices of any other leaves whose constraints
  // were tightened as a consequence (may be empty).
  virtual std::vector<int> Update(
      const Tree* tree, bool is_numerical_split,
      int leaf, int new_leaf, int8_t monotone_type, double right_output,
      double left_output, int split_feature, const SplitInfo& split_info,
      const std::vector<SplitInfo>& best_split_per_leaf) = 0;

  // Factory selecting the implementation from
  // config->monotone_constraints_method. The result is allocated with `new`;
  // the caller must delete it.
  static inline LeafConstraintsBase* Create(const Config* config, int num_leaves);
};

class BasicLeafConstraints : public LeafConstraintsBase {
68
 public:
69
  explicit BasicLeafConstraints(int num_leaves) : num_leaves_(num_leaves) {
70
71
    entries_.resize(num_leaves_);
  }
Nikita Titov's avatar
Nikita Titov committed
72

73
  void Reset() override {
74
75
76
77
    for (auto& entry : entries_) {
      entry.Reset();
    }
  }
Nikita Titov's avatar
Nikita Titov committed
78

79
80
81
82
83
84
85
  void BeforeSplit(const Tree*, int, int, int8_t) override {}

  std::vector<int> Update(const Tree*,
                          bool is_numerical_split, int leaf, int new_leaf,
                          int8_t monotone_type, double right_output,
                          double left_output, int, const SplitInfo& ,
                          const std::vector<SplitInfo>&) override {
86
87
88
89
90
91
92
93
94
95
96
    entries_[new_leaf] = entries_[leaf];
    if (is_numerical_split) {
      double mid = (left_output + right_output) / 2.0f;
      if (monotone_type < 0) {
        entries_[leaf].UpdateMin(mid);
        entries_[new_leaf].UpdateMax(mid);
      } else if (monotone_type > 0) {
        entries_[leaf].UpdateMax(mid);
        entries_[new_leaf].UpdateMin(mid);
      }
    }
97
    return std::vector<int>();
98
99
  }

100
  const ConstraintEntry& Get(int leaf_idx) const override { return entries_[leaf_idx]; }
101

102
 protected:
103
104
105
106
  int num_leaves_;
  std::vector<ConstraintEntry> entries_;
};

// Constraint tracker used when monotone_constraints_method is "intermediate".
// Compared with BasicLeafConstraints it:
//  - bounds the two children of a monotone split by each other's actual
//    outputs instead of by their midpoint, and
//  - after each split, walks the tree to find other leaves that are
//    contiguous with the split leaf and tightens their intervals as well.
class IntermediateLeafConstraints : public BasicLeafConstraints {
 public:
  explicit IntermediateLeafConstraints(const Config* config, int num_leaves)
      : BasicLeafConstraints(num_leaves), config_(config) {
    leaf_is_in_monotone_subtree_.resize(num_leaves_, false);
    // one slot per internal node: a tree with n leaves has n - 1 splits
    node_parent_.resize(num_leaves_ - 1, -1);
    leaves_to_update_.reserve(num_leaves_);
  }

  // Clear all per-tree state so the object can be reused for another tree.
  void Reset() override {
    BasicLeafConstraints::Reset();
    std::fill_n(leaf_is_in_monotone_subtree_.begin(), num_leaves_, false);
    std::fill_n(node_parent_.begin(), num_leaves_ - 1, -1);
    leaves_to_update_.clear();
  }

  // Bookkeeping for the split of `leaf` into `leaf` and `new_leaf`:
  // both children are flagged as lying in a monotone subtree if this split is
  // monotone or `leaf` already was, and the current parent of `leaf` is
  // stored as the parent of internal node new_leaf - 1 (presumably the node
  // this split creates; TODO confirm against Tree::Split).
  void BeforeSplit(const Tree* tree, int leaf, int new_leaf,
                   int8_t monotone_type) override {
    if (monotone_type != 0 || leaf_is_in_monotone_subtree_[leaf]) {
      leaf_is_in_monotone_subtree_[leaf] = true;
      leaf_is_in_monotone_subtree_[new_leaf] = true;
    }
#ifdef DEBUG
    CHECK_GE(new_leaf - 1, 0);
    CHECK_LT(static_cast<size_t>(new_leaf - 1), node_parent_.size());
#endif
    node_parent_[new_leaf - 1] = tree->leaf_parent(leaf);
  }

  // Bound the two children of the split by each other's outputs.
  // `new_leaf` first inherits the interval of `leaf`. For a numerical split
  // with monotone_type < 0, `leaf` gets min >= right_output and `new_leaf`
  // gets max <= left_output; monotone_type > 0 mirrors this.
  void UpdateConstraintsWithOutputs(bool is_numerical_split, int leaf,
                                    int new_leaf, int8_t monotone_type,
                                    double right_output, double left_output) {
    entries_[new_leaf] = entries_[leaf];
    if (is_numerical_split) {
      if (monotone_type < 0) {
        entries_[leaf].UpdateMin(right_output);
        entries_[new_leaf].UpdateMax(left_output);
      } else if (monotone_type > 0) {
        entries_[leaf].UpdateMax(right_output);
        entries_[new_leaf].UpdateMin(left_output);
      }
    }
  }

  // Update constraints after `leaf` was split into `leaf` and `new_leaf`.
  // Only acts when `leaf` lies in a monotone subtree. In that case it bounds
  // the two children, then walks up from the new split and down into sibling
  // subtrees to tighten every other leaf whose interval is affected.
  // Returns the indices of those other leaves (empty if none changed).
  std::vector<int> Update(const Tree* tree, bool is_numerical_split, int leaf,
                          int new_leaf, int8_t monotone_type,
                          double right_output, double left_output,
                          int split_feature, const SplitInfo& split_info,
                          const std::vector<SplitInfo>& best_split_per_leaf) override {
    leaves_to_update_.clear();
    if (leaf_is_in_monotone_subtree_[leaf]) {
      UpdateConstraintsWithOutputs(is_numerical_split, leaf, new_leaf,
                                   monotone_type, right_output, left_output);

      // Path information collected while going up the tree: feature,
      // threshold and side of every relevant split above the original leaf.
      int depth = tree->leaf_depth(new_leaf) - 1;

      std::vector<int> features_of_splits_going_up_from_original_leaf;
      std::vector<uint32_t> thresholds_of_splits_going_up_from_original_leaf;
      std::vector<bool> was_original_leaf_right_child_of_split;

      features_of_splits_going_up_from_original_leaf.reserve(depth);
      thresholds_of_splits_going_up_from_original_leaf.reserve(depth);
      was_original_leaf_right_child_of_split.reserve(depth);

      GoUpToFindLeavesToUpdate(tree, tree->leaf_parent(new_leaf),
                               &features_of_splits_going_up_from_original_leaf,
                               &thresholds_of_splits_going_up_from_original_leaf,
                               &was_original_leaf_right_child_of_split,
                               split_feature, split_info, split_info.threshold,
                               best_split_per_leaf);
    }
    return leaves_to_update_;
  }

  // Pruning test for the walk up the tree. For a numerical split on
  // `inner_feature`, returns false if the path already went up through a
  // split on the same feature from the same side (`is_in_right_child`).
  // Leaves on the other side of this higher split cannot then be contiguous
  // with the original leaf. Categorical splits always return true.
  bool OppositeChildShouldBeUpdated(
      bool is_split_numerical,
      const std::vector<int>& features_of_splits_going_up_from_original_leaf,
      int inner_feature,
      const std::vector<bool>& was_original_leaf_right_child_of_split,
      bool is_in_right_child) const {
    bool opposite_child_should_be_updated = true;

    // if the split is categorical, it is not handled by this optimisation,
    // so the code will have to go down in the other child subtree to see if
    // there are leaves to update
    // even though it may sometimes be unnecessary
    if (is_split_numerical) {
      // only branches containing leaves that are contiguous to the original
      // leaf need to be updated
      // therefore, for the same feature, there is no use going down from the
      // second time going up on the right (or on the left)
      for (size_t split_idx = 0;
           split_idx < features_of_splits_going_up_from_original_leaf.size();
           ++split_idx) {
        if (features_of_splits_going_up_from_original_leaf[split_idx] ==
                inner_feature &&
            (was_original_leaf_right_child_of_split[split_idx] ==
             is_in_right_child)) {
          opposite_child_should_be_updated = false;
          break;
        }
      }
    }
    return opposite_child_should_be_updated;
  }

  // Recursively walk from internal node `node_idx` up to the root. At each
  // ancestor split that carries a monotone constraint (and is not pruned by
  // OppositeChildShouldBeUpdated), descend into the sibling subtree to update
  // the leaves there. Each non-pruned split is appended to the path vectors
  // so that later descents can tell which leaves are contiguous with the
  // original leaf.
  void GoUpToFindLeavesToUpdate(
      const Tree* tree, int node_idx,
      std::vector<int>* features_of_splits_going_up_from_original_leaf,
      std::vector<uint32_t>* thresholds_of_splits_going_up_from_original_leaf,
      std::vector<bool>* was_original_leaf_right_child_of_split,
      int split_feature, const SplitInfo& split_info, uint32_t split_threshold,
      const std::vector<SplitInfo>& best_split_per_leaf) {
#ifdef DEBUG
    CHECK_GE(node_idx, 0);
    CHECK_LT(static_cast<size_t>(node_idx), node_parent_.size());
#endif
    int parent_idx = node_parent_[node_idx];
    // if not at the root
    if (parent_idx != -1) {
      int inner_feature = tree->split_feature_inner(parent_idx);
      int feature = tree->split_feature(parent_idx);
      int8_t monotone_type = config_->monotone_constraints[feature];
      bool is_in_right_child = tree->right_child(parent_idx) == node_idx;
      // The split being examined is the one at parent_idx: its feature, side
      // and threshold are what get compared and recorded. Its own type (not
      // that of node_idx's split) therefore decides whether the numerical
      // same-feature/same-side pruning is valid.
      bool is_split_numerical = tree->IsNumericalSplit(parent_idx);

      // this is just an optimisation not to waste time going down in subtrees
      // where there won't be any leaf to update
      bool opposite_child_should_be_updated = OppositeChildShouldBeUpdated(
          is_split_numerical, *features_of_splits_going_up_from_original_leaf,
          inner_feature, *was_original_leaf_right_child_of_split,
          is_in_right_child);

      if (opposite_child_should_be_updated) {
        // if there is no monotone constraint on a split,
        // then there is no relationship between its left and right leaves' values
        if (monotone_type != 0) {
          // these variables correspond to the current split we encounter going
          // up the tree
          int left_child_idx = tree->left_child(parent_idx);
          int right_child_idx = tree->right_child(parent_idx);
          bool left_child_is_curr_idx = (left_child_idx == node_idx);
          int opposite_child_idx =
              (left_child_is_curr_idx) ? right_child_idx : left_child_idx;
          bool update_max_constraints_in_opposite_child_leaves =
              (monotone_type < 0) ? left_child_is_curr_idx
                                  : !left_child_is_curr_idx;

          // the opposite child needs to be updated
          // so the code needs to go down in the opposite child
          // to see which leaves' constraints need to be updated
          GoDownToFindLeavesToUpdate(
              tree, opposite_child_idx,
              *features_of_splits_going_up_from_original_leaf,
              *thresholds_of_splits_going_up_from_original_leaf,
              *was_original_leaf_right_child_of_split,
              update_max_constraints_in_opposite_child_leaves, split_feature,
              split_info, true, true, split_threshold, best_split_per_leaf);
        }

        // if opposite_child_should_be_updated, then it means the path to come up there was relevant,
        // i.e. that it will be helpful going down to determine which leaf
        // is actually contiguous to the original 2 leaves and should be updated
        // so the variables associated with the split need to be recorded
        was_original_leaf_right_child_of_split->push_back(is_in_right_child);
        thresholds_of_splits_going_up_from_original_leaf->push_back(
            tree->threshold_in_bin(parent_idx));
        features_of_splits_going_up_from_original_leaf->push_back(
            inner_feature);
      }

      // since current node is not the root, keep going up
      GoUpToFindLeavesToUpdate(
          tree, parent_idx, features_of_splits_going_up_from_original_leaf,
          thresholds_of_splits_going_up_from_original_leaf,
          was_original_leaf_right_child_of_split, split_feature, split_info,
          split_threshold, best_split_per_leaf);
    }
  }

  // Recursively descend from `node_idx` (negative values encode leaf ~node_idx).
  // Any leaf reached that is contiguous with the original split leaves gets its
  // max (if `update_max_constraints`) or its min tightened. The bound comes
  // from the new left/right outputs, restricted by `use_left_leaf` /
  // `use_right_leaf` to the new leaves it is actually contiguous with. Leaves
  // whose interval changed are appended to leaves_to_update_.
  void GoDownToFindLeavesToUpdate(
      const Tree* tree, int node_idx,
      const std::vector<int>& features_of_splits_going_up_from_original_leaf,
      const std::vector<uint32_t>&
          thresholds_of_splits_going_up_from_original_leaf,
      const std::vector<bool>& was_original_leaf_right_child_of_split,
      bool update_max_constraints, int split_feature,
      const SplitInfo& split_info, bool use_left_leaf, bool use_right_leaf,
      uint32_t split_threshold,
      const std::vector<SplitInfo>& best_split_per_leaf) {
    // if leaf
    if (node_idx < 0) {
      int leaf_idx = ~node_idx;

      // splits that are not to be used shall not be updated,
      // included leaf at max depth
      if (best_split_per_leaf[leaf_idx].gain == kMinScore) {
        return;
      }

      std::pair<double, double> min_max_constraints;
      bool something_changed = false;
      // if the current leaf is contiguous with both the new right leaf and the new left leaf
      // then it may need to be greater than the max of the 2 or smaller than the min of the 2
      // otherwise, if the current leaf is contiguous with only one of the 2 new leaves,
      // then it may need to be greater or smaller than it
      if (use_right_leaf && use_left_leaf) {
        min_max_constraints =
            std::minmax(split_info.right_output, split_info.left_output);
      } else if (use_right_leaf && !use_left_leaf) {
        min_max_constraints = std::pair<double, double>(
            split_info.right_output, split_info.right_output);
      } else {
        min_max_constraints = std::pair<double, double>(split_info.left_output,
                                                        split_info.left_output);
      }

#ifdef DEBUG
      if (update_max_constraints) {
        CHECK_GE(min_max_constraints.first, tree->LeafOutput(leaf_idx));
      } else {
        CHECK_LE(min_max_constraints.second, tree->LeafOutput(leaf_idx));
      }
#endif
      // depending on which split made the current leaf and the original leaves contiguous,
      // either the min constraint or the max constraint of the current leaf need to be updated
      if (!update_max_constraints) {
        something_changed = entries_[leaf_idx].UpdateMinAndReturnBoolIfChanged(
            min_max_constraints.second);
      } else {
        something_changed = entries_[leaf_idx].UpdateMaxAndReturnBoolIfChanged(
            min_max_constraints.first);
      }
      // If constraints were not updated, then there is no need to update the leaf
      if (!something_changed) {
        return;
      }
      leaves_to_update_.push_back(leaf_idx);

    } else {  // if node
      // check if the children are contiguous with the original leaf
      std::pair<bool, bool> keep_going_left_right = ShouldKeepGoingLeftRight(
          tree, node_idx, features_of_splits_going_up_from_original_leaf,
          thresholds_of_splits_going_up_from_original_leaf,
          was_original_leaf_right_child_of_split);
      int inner_feature = tree->split_feature_inner(node_idx);
      uint32_t threshold = tree->threshold_in_bin(node_idx);
      bool is_split_numerical = tree->IsNumericalSplit(node_idx);
      bool use_left_leaf_for_update_right = true;
      bool use_right_leaf_for_update_left = true;
      // if the split is on the same feature (categorical variables not supported)
      // then depending on the threshold,
      // the current left child may not be contiguous with the original right leaf,
      // or the current right child may not be contiguous with the original left leaf
      if (is_split_numerical && inner_feature == split_feature) {
        if (threshold >= split_threshold) {
          use_left_leaf_for_update_right = false;
        }
        if (threshold <= split_threshold) {
          use_right_leaf_for_update_left = false;
        }
      }

      // go down left
      if (keep_going_left_right.first) {
        GoDownToFindLeavesToUpdate(
            tree, tree->left_child(node_idx),
            features_of_splits_going_up_from_original_leaf,
            thresholds_of_splits_going_up_from_original_leaf,
            was_original_leaf_right_child_of_split, update_max_constraints,
            split_feature, split_info, use_left_leaf,
            use_right_leaf_for_update_left && use_right_leaf, split_threshold,
            best_split_per_leaf);
      }
      // go down right
      if (keep_going_left_right.second) {
        GoDownToFindLeavesToUpdate(
            tree, tree->right_child(node_idx),
            features_of_splits_going_up_from_original_leaf,
            thresholds_of_splits_going_up_from_original_leaf,
            was_original_leaf_right_child_of_split, update_max_constraints,
            split_feature, split_info,
            use_left_leaf_for_update_right && use_left_leaf, use_right_leaf,
            split_threshold, best_split_per_leaf);
      }
    }
  }

  // For internal node `node_idx`, decide whether its left / right child can
  // still be contiguous with the original leaf, given the numerical splits on
  // the same feature recorded while going up. Returns {go_left, go_right}.
  // Categorical splits never prune.
  std::pair<bool, bool> ShouldKeepGoingLeftRight(
      const Tree* tree, int node_idx,
      const std::vector<int>& features_of_splits_going_up_from_original_leaf,
      const std::vector<uint32_t>&
          thresholds_of_splits_going_up_from_original_leaf,
      const std::vector<bool>& was_original_leaf_right_child_of_split) const {
    int inner_feature = tree->split_feature_inner(node_idx);
    uint32_t threshold = tree->threshold_in_bin(node_idx);
    bool is_split_numerical = tree->IsNumericalSplit(node_idx);

    bool keep_going_right = true;
    bool keep_going_left = true;
    // left and right nodes are checked to find out if they are contiguous with
    // the original leaves if so the algorithm should keep going down these nodes
    // to update constraints
    if (is_split_numerical) {
      for (size_t i = 0;
           i < features_of_splits_going_up_from_original_leaf.size(); ++i) {
        if (features_of_splits_going_up_from_original_leaf[i] ==
            inner_feature) {
          if (threshold >=
                  thresholds_of_splits_going_up_from_original_leaf[i] &&
              !was_original_leaf_right_child_of_split[i]) {
            keep_going_right = false;
            if (!keep_going_left) {
              break;
            }
          }
          if (threshold <=
                  thresholds_of_splits_going_up_from_original_leaf[i] &&
              was_original_leaf_right_child_of_split[i]) {
            keep_going_left = false;
            if (!keep_going_right) {
              break;
            }
          }
        }
      }
    }
    return std::pair<bool, bool>(keep_going_left, keep_going_right);
  }

 private:
  // source of the per-feature monotone_constraints
  const Config* config_;
  // leaves (other than the two split ones) tightened by the last Update call
  std::vector<int> leaves_to_update_;
  // parent of each internal node, -1 for the root / unset
  std::vector<int> node_parent_;
  // Keeps track of the monotone splits above the leaf
  std::vector<bool> leaf_is_in_monotone_subtree_;
};

// Build the constraint tracker named by config->monotone_constraints_method:
// "intermediate" selects IntermediateLeafConstraints; any other value falls
// back to BasicLeafConstraints. The object is allocated with `new` and the
// caller is responsible for deleting it.
inline LeafConstraintsBase* LeafConstraintsBase::Create(const Config* config,
                                                        int num_leaves) {
  const bool use_intermediate =
      config->monotone_constraints_method == "intermediate";
  if (!use_intermediate) {
    return new BasicLeafConstraints(num_leaves);
  }
  return new IntermediateLeafConstraints(config, num_leaves);
}

}  // namespace LightGBM
#endif  // LIGHTGBM_TREELEARNER_MONOTONE_CONSTRAINTS_HPP_