"include/vscode:/vscode.git/clone" did not exist on "060c681dbf79ab684c66a358156dfd96e47a16b9"
// gradient_discretizer.hpp
/*!
 * Copyright (c) 2022 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for
 * license information.
 */
#ifndef LIGHTGBM_TREE_LEARNER_GRADIENT_DISCRETIZER_HPP_
#define LIGHTGBM_TREE_LEARNER_GRADIENT_DISCRETIZER_HPP_

#include <LightGBM/bin.h>
#include <LightGBM/meta.h>
#include <LightGBM/tree.h>
#include <LightGBM/utils/threading.h>

#include <random>
#include <vector>

#include "data_partition.hpp"
#include "feature_histogram.hpp"

namespace LightGBM {

/*!
 * \brief Interface and state for converting floating point gradients and
 *        hessians (score_t) into int8_t values, plus per-leaf / per-node
 *        bookkeeping of histogram bin bit widths.
 *
 * Only the trivial accessors are defined in this header; DiscretizeGradients,
 * Init, SetNumBitsInHistogramBin and RenewIntGradTreeOutput are defined
 * elsewhere. The virtual members allow a derived class to replace them.
 */
class GradientDiscretizer {
 public:
  /*!
   * \brief Stores the configuration values and resets the iteration counter to 0.
   *
   * All buffers start empty and all scales start at 0.0; they are expected to
   * be filled in by Init / DiscretizeGradients (defined outside this header).
   *
   * \param num_grad_quant_bins Stored in num_grad_quant_bins_ (presumably the number of gradient quantization bins)
   * \param num_trees Stored in num_trees_
   * \param random_seed Stored in random_seed_
   * \param is_constant_hessian Stored in is_constant_hessian_
   * \param stochastic_rounding Stored in stochastic_rounding_
   */
  GradientDiscretizer(int num_grad_quant_bins, int num_trees, int random_seed, bool is_constant_hessian, const bool stochastic_rounding)
      : num_grad_quant_bins_(num_grad_quant_bins),
        iter_(0),
        num_trees_(num_trees),
        random_seed_(random_seed),
        stochastic_rounding_(stochastic_rounding),
        is_constant_hessian_(is_constant_hessian) {}

  virtual ~GradientDiscretizer() = default;

  /*!
   * \brief Discretizes the given gradients and hessians (defined outside this header).
   * \param num_data Number of entries in the input arrays (assumed; confirm against the definition)
   * \param input_gradients Floating point gradients
   * \param input_hessians Floating point hessians
   */
  virtual void DiscretizeGradients(
    const data_size_t num_data,
    const score_t* input_gradients,
    const score_t* input_hessians);

  /*!
   * \brief Read-only pointer to the start of the discretized gradient/hessian buffer.
   * \note The pointer refers to storage owned by this object; the element
   *       layout is determined by DiscretizeGradients and is not visible here.
   */
  virtual const int8_t* discretized_gradients_and_hessians() const {
    return discretized_gradients_and_hessians_vector_.data();
  }

  /*! \brief Current value of gradient_scale_ (0.0 until assigned elsewhere) */
  virtual double grad_scale() const {
    return gradient_scale_;
  }

  /*! \brief Current value of hessian_scale_ (0.0 until assigned elsewhere) */
  virtual double hess_scale() const {
    return hessian_scale_;
  }

  /*!
   * \brief Prepares internal state for a training data set (defined outside this header).
   * \param num_data Number of data points
   * \param num_leaves Number of leaves
   * \param num_features Number of features
   * \param train_data Training data set
   */
  virtual void Init(
    const data_size_t num_data, const int num_leaves,
    const int num_features, const Dataset* train_data);

  /*!
   * \brief Updates the stored histogram bit widths for a pair of leaves
   *        (defined outside this header).
   * \tparam IS_GLOBAL Presumably selects the global_* vectors instead of the
   *         local ones, mirroring GetHistBitsInLeaf -- confirm against the definition
   */
  template <bool IS_GLOBAL>
  void SetNumBitsInHistogramBin(
    const int left_leaf_index, const int right_leaf_index,
    const data_size_t num_data_in_left_leaf, const data_size_t num_data_in_right_leaf);

  /*!
   * \brief Stored histogram bin bit width for a leaf.
   * \tparam IS_GLOBAL Read from global_leaf_num_bits_in_histogram_bin_ when
   *         true, from leaf_num_bits_in_histogram_bin_ otherwise
   * \param leaf_index Index into the selected vector; not bounds-checked
   */
  template <bool IS_GLOBAL>
  int8_t GetHistBitsInLeaf(const int leaf_index) const {
    if (IS_GLOBAL) {
      return global_leaf_num_bits_in_histogram_bin_[leaf_index];
    } else {
      return leaf_num_bits_in_histogram_bin_[leaf_index];
    }
  }

  /*!
   * \brief Stored histogram bin bit width for a tree node.
   * \tparam IS_GLOBAL Read from global_node_num_bits_in_histogram_bin_ when
   *         true, from node_num_bits_in_histogram_bin_ otherwise
   * \param node_index Index into the selected vector; not bounds-checked
   */
  template <bool IS_GLOBAL>
  int8_t GetHistBitsInNode(const int node_index) const {
    if (IS_GLOBAL) {
      return global_node_num_bits_in_histogram_bin_[node_index];
    } else {
      return node_num_bits_in_histogram_bin_[node_index];
    }
  }

  /*! \brief Writable pointer to the start of ordered_int_gradients_and_hessians_ */
  int8_t* ordered_int_gradients_and_hessians() {
    return ordered_int_gradients_and_hessians_.data();
  }

  /*!
   * \brief Recomputes tree outputs from the supplied gradients and hessians
   *        (defined outside this header).
   * \param leaf_index_to_global_num_data Callback mapping a leaf index to a data count
   */
  void RenewIntGradTreeOutput(
    Tree* tree, const Config* config, const DataPartition* data_partition,
    const score_t* gradients, const score_t* hessians,
    const std::function<data_size_t(int)>& leaf_index_to_global_num_data);

  /*!
   * \brief Writable pointer to the per-feature buffer in change_hist_bits_buffer_.
   * \param feature_index Index of the buffer; not bounds-checked
   */
  int32_t* GetChangeHistBitsBuffer(const int feature_index) {
    return change_hist_bits_buffer_[feature_index].data();
  }

 protected:
  // Configuration captured by the constructor.
  int num_grad_quant_bins_;
  /*! \brief Starts at 0; advanced outside this header */
  int iter_;
  int num_trees_;
  int random_seed_;
  bool stochastic_rounding_;

  // Random-number state and working buffers; populated outside this header.
  std::vector<double> gradient_random_values_;
  std::vector<double> hessian_random_values_;
  std::mt19937 random_values_use_start_eng_;
  std::uniform_int_distribution<data_size_t> random_values_use_start_dist_;
  /*! \brief Backing storage for discretized_gradients_and_hessians() */
  std::vector<int8_t> discretized_gradients_and_hessians_vector_;
  /*! \brief Backing storage for ordered_int_gradients_and_hessians() */
  std::vector<int8_t> ordered_int_gradients_and_hessians_;

  // The scalars below are not set by the constructor, so they are
  // zero-initialized here to keep the accessors well-defined before the first
  // out-of-line update.
  double max_gradient_abs_ = 0.0;
  double max_hessian_abs_ = 0.0;

  /*! \brief Returned by grad_scale() */
  double gradient_scale_ = 0.0;
  /*! \brief Returned by hess_scale() */
  double hessian_scale_ = 0.0;
  double inverse_gradient_scale_ = 0.0;
  double inverse_hessian_scale_ = 0.0;

  bool is_constant_hessian_;
  int num_leaves_ = 0;

  // Bit widths read by GetHistBitsInLeaf / GetHistBitsInNode.
  std::vector<int8_t> leaf_num_bits_in_histogram_bin_;
  std::vector<int8_t> node_num_bits_in_histogram_bin_;
  std::vector<int8_t> global_leaf_num_bits_in_histogram_bin_;
  std::vector<int8_t> global_node_num_bits_in_histogram_bin_;

  std::vector<double> leaf_grad_hess_stats_;
  /*! \brief One buffer per feature index; exposed by GetChangeHistBitsBuffer() */
  std::vector<std::vector<int32_t>> change_hist_bits_buffer_;
};

}  // namespace LightGBM

#endif  // LIGHTGBM_TREE_LEARNER_GRADIENT_DISCRETIZER_HPP_