gbdt.h 4.99 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
#ifndef LIGHTGBM_BOOSTING_GBDT_H_
#define LIGHTGBM_BOOSTING_GBDT_H_

#include <LightGBM/boosting.h>
#include "score_updater.hpp"

#include <cstdio>
#include <vector>
#include <string>

namespace LightGBM {
/*!
* \brief GBDT algorithm implementation. including Training, prediction, bagging.
*/
class GBDT: public Boosting {
public:
  /*!
  * \brief Constructor
  * \param config Config of GBDT
  */
  explicit GBDT(const BoostingConfig* config);
  /*!
  * \brief Destructor
  */
  ~GBDT();
  /*!
  * \brief Initial logic
  * \param config Config for boosting
  * \param train_data Training data
  * \param object_function Training objective function
  * \param training_metrics Training metrics
  * \param output_model_filename Filename of output model
  */
  void Init(const Dataset* train_data, const ObjectiveFunction* object_function,
                             const std::vector<const Metric*>& training_metrics,
                                              const char* output_model_filename)
                                                                       override;
  /*!
  * \brief Add a validation data
  * \param valid_data Validation data
  * \param valid_metrics Metrics for validation data
  */
  void AddDataset(const Dataset* valid_data,
       const std::vector<const Metric*>& valid_metrics) override;
  /*!
  * \brief one training iteration
  */
  void Train() override;
  /*!
  * \brief Predtion for one record, not use sigmoid
  * \param feature_values Feature value on this record
  * \return Prediction result for this record
  */
  double PredictRaw(const double * feature_values) const override;

  /*!
  * \brief Predtion for one record, will use sigmoid transform if needed
  * \param feature_values Feature value on this record
  * \return Prediction result for this record
  */
  double Predict(const double * feature_values) const override;
  /*!
  * \brief Serialize models by string
  * \return String output of tranined model
  */
  std::string ModelsToString() const override;
  /*!
  * \brief Restore from a serialized string
  * \param model_str The string of model
  */
  void ModelsFromString(const std::string& model_str, int num_used_model) override;
  /*!
  * \brief Get max feature index of this model
  * \return Max feature index of this model
  */
  inline int MaxFeatureIdx() const override { return max_feature_idx_; }
  /*!
  * \brief Get number of weak sub-models
  * \return Number of weak sub-models
  */
  inline int NumberOfSubModels() const override { return static_cast<int>(models_.size()); }

private:
  /*!
  * \brief Implement bagging logic
  * \param iter Current interation
  */
  void Bagging(int iter);
  /*!
  * \brief update score for out-of-bag data.
  * It is necessary for this update, since we may re-bagging data on training
  * \param tree Trained tree of this iteration
  */
  void UpdateScoreOutOfBag(const Tree* tree);
  /*!
  * \brief calculate the object function
  */
  void Boosting();
  /*!
  * \brief train one tree
  * \return Trained tree of this iteration
  */
  Tree* TrainOneTree();
  /*!
  * \brief update score after tree trained
  * \param tree Trained tree of this iteration
  */
  void UpdateScore(const Tree* tree);
  /*!
  * \brief Print Metric result of current iteration
  * \param iter Current interation
  */
  void OutputMetric(int iter);

  /*! \brief Pointer to training data */
  const Dataset* train_data_;
  /*! \brief Config of gbdt */
  const GBDTConfig* gbdt_config_;
  /*! \brief Tree learner, will use tihs class to learn trees */
  TreeLearner* tree_learner_;
  /*! \brief Objective function */
  const ObjectiveFunction* object_function_;
  /*! \brief Store and update traning data's score */
  ScoreUpdater* train_score_updater_;
  /*! \brief Metrics for training data */
  std::vector<const Metric*> training_metrics_;
  /*! \brief Store and update validation data's scores */
  std::vector<ScoreUpdater*> valid_score_updater_;
  /*! \brief Metric for validation data */
  std::vector<std::vector<const Metric*>> valid_metrics_;
  /*! \brief Trained models(trees) */
  std::vector<Tree*> models_;
  /*! \brief Max feature index of training data*/
  int max_feature_idx_;
  /*! \brief First order derivative of training data */
  score_t* gradients_;
  /*! \brief Secend order derivative of training data */
  score_t* hessians_;
  /*! \brief Store the data indices of out-of-bag */
  data_size_t* out_of_bag_data_indices_;
  /*! \brief Number of out-of-bag data */
  data_size_t out_of_bag_data_cnt_;
  /*! \brief Store the indices of in-bag data */
  data_size_t* bag_data_indices_;
  /*! \brief Number of in-bag data */
  data_size_t bag_data_cnt_;
  /*! \brief Number of traning data */
  data_size_t num_data_;
  /*! \brief Random generator, used for bagging */
  Random random_;
  /*! \brief The filename that the models will save to */
  FILE * output_model_file;
  /*!
  *   \brief Sigmoid parameter, used for prediction.
  *          if > 0 meas output score will transform by sigmoid function
  */
  double sigmoid_;
};

}  // namespace LightGBM
Guolin Ke's avatar
Guolin Ke committed
161
#endif   // LIGHTGBM_BOOSTING_GBDT_H_