// gbdt.h
#ifndef LIGHTGBM_BOOSTING_GBDT_H_
#define LIGHTGBM_BOOSTING_GBDT_H_

#include <LightGBM/boosting.h>
#include "score_updater.hpp"

#include <cstdio>
#include <vector>
#include <string>

namespace LightGBM {
/*!
* \brief GBDT algorithm implementation. including Training, prediction, bagging.
*/
class GBDT: public Boosting {
public:
  /*!
  * \brief Constructor
  * \param config Config of GBDT
  */
  explicit GBDT(const BoostingConfig* config);
  /*!
  * \brief Destructor
  */
  ~GBDT();
  /*!
Qiwei Ye's avatar
Qiwei Ye committed
27
  * \brief Initialization logic
Guolin Ke's avatar
Guolin Ke committed
28
29
30
31
32
33
34
35
36
37
38
  * \param config Config for boosting
  * \param train_data Training data
  * \param object_function Training objective function
  * \param training_metrics Training metrics
  * \param output_model_filename Filename of output model
  */
  void Init(const Dataset* train_data, const ObjectiveFunction* object_function,
                             const std::vector<const Metric*>& training_metrics,
                                              const char* output_model_filename)
                                                                       override;
  /*!
Qiwei Ye's avatar
Qiwei Ye committed
39
40
41
  * \brief Adding a validation dataset
  * \param valid_data Validation dataset
  * \param valid_metrics Metrics for validation dataset
Guolin Ke's avatar
Guolin Ke committed
42
43
44
45
46
47
48
49
  */
  void AddDataset(const Dataset* valid_data,
       const std::vector<const Metric*>& valid_metrics) override;
  /*!
  * \brief one training iteration
  */
  void Train() override;
  /*!
Qiwei Ye's avatar
Qiwei Ye committed
50
  * \brief Predtion for one record without sigmoid transformation
Guolin Ke's avatar
Guolin Ke committed
51
52
53
54
55
56
  * \param feature_values Feature value on this record
  * \return Prediction result for this record
  */
  double PredictRaw(const double * feature_values) const override;

  /*!
Qiwei Ye's avatar
Qiwei Ye committed
57
  * \brief Predtion for one record with sigmoid transformation if enabled
Guolin Ke's avatar
Guolin Ke committed
58
59
60
61
  * \param feature_values Feature value on this record
  * \return Prediction result for this record
  */
  double Predict(const double * feature_values) const override;
wxchan's avatar
wxchan committed
62
63
64
65
66
67
68
69
  
  /*!
  * \brief Predtion for one record with leaf index
  * \param feature_values Feature value on this record
  * \return Predicted leaf index for this record
  */
 std::vector<int> PredictLeafIndex(const double* value) const override;
  
Guolin Ke's avatar
Guolin Ke committed
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
  /*!
  * \brief Serialize models by string
  * \return String output of tranined model
  */
  std::string ModelsToString() const override;
  /*!
  * \brief Restore from a serialized string
  * \param model_str The string of model
  */
  void ModelsFromString(const std::string& model_str, int num_used_model) override;
  /*!
  * \brief Get max feature index of this model
  * \return Max feature index of this model
  */
  inline int MaxFeatureIdx() const override { return max_feature_idx_; }
Guolin Ke's avatar
Guolin Ke committed
85
86
87
88
89
90
91

  /*!
  * \brief Get index of label column
  * \return index of label column
  */
  inline int LabelIdx() const override { return label_idx_; }

Guolin Ke's avatar
Guolin Ke committed
92
93
94
95
96
97
98
99
100
101
102
103
104
  /*!
  * \brief Get number of weak sub-models
  * \return Number of weak sub-models
  */
  inline int NumberOfSubModels() const override { return static_cast<int>(models_.size()); }

private:
  /*!
  * \brief Implement bagging logic
  * \param iter Current interation
  */
  void Bagging(int iter);
  /*!
Qiwei Ye's avatar
Qiwei Ye committed
105
106
  * \brief updating score for out-of-bag data.
  *        Data should be update since we may re-bagging data on training
Guolin Ke's avatar
Guolin Ke committed
107
108
109
110
111
112
113
114
  * \param tree Trained tree of this iteration
  */
  void UpdateScoreOutOfBag(const Tree* tree);
  /*!
  * \brief calculate the object function
  */
  void Boosting();
  /*!
Qiwei Ye's avatar
Qiwei Ye committed
115
  * \brief training one tree
Guolin Ke's avatar
Guolin Ke committed
116
117
118
119
  * \return Trained tree of this iteration
  */
  Tree* TrainOneTree();
  /*!
Qiwei Ye's avatar
Qiwei Ye committed
120
  * \brief updating score after tree was trained
Guolin Ke's avatar
Guolin Ke committed
121
122
123
124
  * \param tree Trained tree of this iteration
  */
  void UpdateScore(const Tree* tree);
  /*!
Hui Xue's avatar
Hui Xue committed
125
  * \brief Print metric result of current iteration
Guolin Ke's avatar
Guolin Ke committed
126
127
  * \param iter Current interation
  */
wxchan's avatar
wxchan committed
128
  bool OutputMetric(int iter);
wxchan's avatar
wxchan committed
129
130
131
132
133
  /*!
  * \brief Calculate feature importances
  * \param last_iter Last tree use to calculate
  */
  void FeatureImportance(const int last_iter);
wxchan's avatar
wxchan committed
134
  
Guolin Ke's avatar
Guolin Ke committed
135
136
137
138
  /*! \brief Pointer to training data */
  const Dataset* train_data_;
  /*! \brief Config of gbdt */
  const GBDTConfig* gbdt_config_;
Hui Xue's avatar
Hui Xue committed
139
  /*! \brief Tree learner, will use this class to learn trees */
Guolin Ke's avatar
Guolin Ke committed
140
141
142
  TreeLearner* tree_learner_;
  /*! \brief Objective function */
  const ObjectiveFunction* object_function_;
Hui Xue's avatar
Hui Xue committed
143
  /*! \brief Store and update training data's score */
Guolin Ke's avatar
Guolin Ke committed
144
145
146
147
148
149
150
  ScoreUpdater* train_score_updater_;
  /*! \brief Metrics for training data */
  std::vector<const Metric*> training_metrics_;
  /*! \brief Store and update validation data's scores */
  std::vector<ScoreUpdater*> valid_score_updater_;
  /*! \brief Metric for validation data */
  std::vector<std::vector<const Metric*>> valid_metrics_;
wxchan's avatar
wxchan committed
151
152
  /*! \brief Number of rounds for early stopping */
  int early_stopping_round_;
wxchan's avatar
wxchan committed
153
154
155
  /*! \brief Best score(s) for early stopping */
  std::vector<std::vector<int>> best_iter_;
  std::vector<std::vector<score_t>> best_score_;
Guolin Ke's avatar
Guolin Ke committed
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
  /*! \brief Trained models(trees) */
  std::vector<Tree*> models_;
  /*! \brief Max feature index of training data*/
  int max_feature_idx_;
  /*! \brief First order derivative of training data */
  score_t* gradients_;
  /*! \brief Secend order derivative of training data */
  score_t* hessians_;
  /*! \brief Store the data indices of out-of-bag */
  data_size_t* out_of_bag_data_indices_;
  /*! \brief Number of out-of-bag data */
  data_size_t out_of_bag_data_cnt_;
  /*! \brief Store the indices of in-bag data */
  data_size_t* bag_data_indices_;
  /*! \brief Number of in-bag data */
  data_size_t bag_data_cnt_;
  /*! \brief Number of traning data */
  data_size_t num_data_;
  /*! \brief Random generator, used for bagging */
  Random random_;
  /*! \brief The filename that the models will save to */
  FILE * output_model_file;
  /*!
  *   \brief Sigmoid parameter, used for prediction.
  *          if > 0 meas output score will transform by sigmoid function
  */
  double sigmoid_;
Guolin Ke's avatar
Guolin Ke committed
183
184
185

  /*! \brief Index of label column */
  data_size_t label_idx_;
Guolin Ke's avatar
Guolin Ke committed
186
187
188
};

}  // namespace LightGBM
#endif   // LIGHTGBM_BOOSTING_GBDT_H_