monotone_constraints.hpp 1.81 KB
Newer Older
Nikita Titov's avatar
Nikita Titov committed
1
2
3
4
5
6
/*!
 * Copyright (c) 2020 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
#ifndef LIGHTGBM_TREELEARNER_MONOTONE_CONSTRAINTS_HPP_
#define LIGHTGBM_TREELEARNER_MONOTONE_CONSTRAINTS_HPP_
7

Nikita Titov's avatar
Nikita Titov committed
8
#include <limits>
9
10
#include <algorithm>
#include <cstdint>
Nikita Titov's avatar
Nikita Titov committed
11
#include <vector>
12
13
14
15
16
17
18

namespace LightGBM {

/*!
 * \brief Lower/upper bound on the output value a single leaf may take.
 *
 * Both bounds start fully open (-max double / +max double) and are only
 * ever tightened by UpdateMin / UpdateMax; Reset() reopens them.
 */
struct ConstraintEntry {
  /*! \brief Lower bound; initially the most negative finite double. */
  double min = -std::numeric_limits<double>::max();
  /*! \brief Upper bound; initially the largest finite double. */
  double max = std::numeric_limits<double>::max();

  ConstraintEntry() = default;

  /*! \brief Restore both bounds to their initial, fully open values. */
  void Reset() {
    min = -std::numeric_limits<double>::max();
    max = std::numeric_limits<double>::max();
  }

  /*!
   * \brief Raise the lower bound to new_min if that is tighter.
   * A new_min below the current bound leaves it unchanged.
   */
  void UpdateMin(double new_min) { min = std::max(new_min, min); }

  /*!
   * \brief Lower the upper bound to new_max if that is tighter.
   * A new_max above the current bound leaves it unchanged.
   */
  void UpdateMax(double new_max) { max = std::min(new_max, max); }
};

/*!
 * \brief Stores one bound entry per leaf and propagates bounds on splits.
 *
 * \tparam EntryType per-leaf bound type. It must be default constructible and
 *         copy assignable, and provide Reset(), UpdateMin(double) and
 *         UpdateMax(double). ConstraintEntry satisfies this.
 */
template <typename EntryType>
class LeafConstraints {
 public:
  /*!
   * \brief Allocate num_leaves default-constructed entries.
   * \param num_leaves Number of leaf slots. Leaf indices passed to the other
   *        methods are not range-checked; callers must keep them in
   *        [0, num_leaves).
   */
  explicit LeafConstraints(int num_leaves) : num_leaves_(num_leaves) {
    entries_.resize(num_leaves_);
  }

  /*! \brief Call Reset() on every leaf's entry. */
  void Reset() {
    for (auto& entry : entries_) {
      entry.Reset();
    }
  }

  /*!
   * \brief Update bounds after `leaf` has been split into `leaf` and `new_leaf`.
   *
   * `new_leaf` first inherits a copy of `leaf`'s bounds. Then, only for a
   * numerical split with a non-zero monotone_type, the midpoint of the two
   * child outputs becomes a bound separating the two leaves:
   *  - monotone_type > 0: `leaf` gets max <= mid, `new_leaf` gets min >= mid;
   *  - monotone_type < 0: `leaf` gets min >= mid, `new_leaf` gets max <= mid.
   * NOTE(review): this pairing implies `leaf` is the left child and `new_leaf`
   * the right child -- confirm against the caller.
   *
   * \param is_numerical_split Whether the split was numerical; otherwise only
   *        the copy is performed.
   * \param leaf Index of the leaf that was split.
   * \param new_leaf Index of the newly created leaf.
   * \param monotone_type Sign selects the direction; 0 adds no bound.
   * \param right_output Output value of the right child.
   * \param left_output Output value of the left child.
   */
  void UpdateConstraints(bool is_numerical_split, int leaf, int new_leaf,
                         int8_t monotone_type, double right_output,
                         double left_output) {
    entries_[new_leaf] = entries_[leaf];
    if (!is_numerical_split || monotone_type == 0) {
      return;
    }
    // Average of the two child outputs, computed entirely in double.
    const double mid = (left_output + right_output) / 2.0;
    if (monotone_type < 0) {
      entries_[leaf].UpdateMin(mid);
      entries_[new_leaf].UpdateMax(mid);
    } else {
      entries_[leaf].UpdateMax(mid);
      entries_[new_leaf].UpdateMin(mid);
    }
  }

  /*! \brief Read-only access to the bounds of leaf `leaf_idx` (unchecked). */
  const EntryType& Get(int leaf_idx) const { return entries_[leaf_idx]; }

 private:
  /*! \brief Number of leaf slots requested at construction. */
  int num_leaves_;
  /*! \brief One bound entry per leaf, indexed by leaf index. */
  std::vector<EntryType> entries_;
};

Nikita Titov's avatar
Nikita Titov committed
67
68
}  // namespace LightGBM
#endif  // LIGHTGBM_TREELEARNER_MONOTONE_CONSTRAINTS_HPP_