#ifndef LIGHTGBM_BOOSTING_GBDT_H_
#define LIGHTGBM_BOOSTING_GBDT_H_

#include <LightGBM/boosting.h>
#include "score_updater.hpp"

#include <cstdio>
#include <vector>
#include <string>
#include <fstream>

namespace LightGBM {
/*!
* \brief GBDT algorithm implementation. including Training, prediction, bagging.
*/
class GBDT: public Boosting {
public:
  /*!
  * \brief Constructor
  */
21
  GBDT();
Guolin Ke's avatar
Guolin Ke committed
22
23
24
25
26
  /*!
  * \brief Destructor
  */
  ~GBDT();
  /*!
Qiwei Ye's avatar
Qiwei Ye committed
27
  * \brief Initialization logic
Guolin Ke's avatar
Guolin Ke committed
28
29
30
31
32
33
  * \param config Config for boosting
  * \param train_data Training data
  * \param object_function Training objective function
  * \param training_metrics Training metrics
  * \param output_model_filename Filename of output model
  */
34
35
  void Init(const BoostingConfig* gbdt_config, const Dataset* train_data, const ObjectiveFunction* object_function,
                             const std::vector<const Metric*>& training_metrics)
Guolin Ke's avatar
Guolin Ke committed
36
37
                                                                       override;
  /*!
Qiwei Ye's avatar
Qiwei Ye committed
38
39
40
  * \brief Adding a validation dataset
  * \param valid_data Validation dataset
  * \param valid_metrics Metrics for validation dataset
Guolin Ke's avatar
Guolin Ke committed
41
42
43
44
45
46
  */
  void AddDataset(const Dataset* valid_data,
       const std::vector<const Metric*>& valid_metrics) override;
  /*!
  * \brief one training iteration
  */
47
48
49
50
51
52
53
54
  bool TrainOneIter(const score_t* gradient, const score_t* hessian, bool is_eval) override;

  /*! \brief Get eval result */
  std::vector<std::string> EvalCurrent(bool is_eval_train) const override;

  /*! \brief Get prediction result */
  const std::vector<const score_t*> PredictCurrent(bool is_predict_train) const override;

Guolin Ke's avatar
Guolin Ke committed
55
  /*!
Qiwei Ye's avatar
Qiwei Ye committed
56
  * \brief Predtion for one record without sigmoid transformation
Guolin Ke's avatar
Guolin Ke committed
57
  * \param feature_values Feature value on this record
58
  * \param num_used_model Number of used model
Guolin Ke's avatar
Guolin Ke committed
59
60
  * \return Prediction result for this record
  */
61
  float PredictRaw(const float* feature_values, int num_used_model) const override;
Guolin Ke's avatar
Guolin Ke committed
62
63

  /*!
Qiwei Ye's avatar
Qiwei Ye committed
64
  * \brief Predtion for one record with sigmoid transformation if enabled
Guolin Ke's avatar
Guolin Ke committed
65
  * \param feature_values Feature value on this record
66
  * \param num_used_model Number of used model
Guolin Ke's avatar
Guolin Ke committed
67
68
  * \return Prediction result for this record
  */
69
  float Predict(const float* feature_values, int num_used_model) const override;
wxchan's avatar
wxchan committed
70
  
71
72
73
74
75
76
77
  /*!
  * \brief Predtion for multiclass classification
  * \param feature_values Feature value on this record
  * \return Prediction result, num_class numbers per line
  */
  std::vector<float> PredictMulticlass(const float* value, int num_used_model) const override;
  
wxchan's avatar
wxchan committed
78
79
80
  /*!
  * \brief Predtion for one record with leaf index
  * \param feature_values Feature value on this record
81
  * \param num_used_model Number of used model
wxchan's avatar
wxchan committed
82
83
  * \return Predicted leaf index for this record
  */
84
  std::vector<int> PredictLeafIndex(const float* value, int num_used_model) const override;
wxchan's avatar
wxchan committed
85
  
Guolin Ke's avatar
Guolin Ke committed
86
87
88
89
  /*!
  * \brief Serialize models by string
  * \return String output of tranined model
  */
90
  void SaveModelToFile(bool is_finish, const char* filename) override;
Guolin Ke's avatar
Guolin Ke committed
91
92
93
  /*!
  * \brief Restore from a serialized string
  */
94
  void ModelsFromString(const std::string& model_str) override;
Guolin Ke's avatar
Guolin Ke committed
95
96
97
98
99
  /*!
  * \brief Get max feature index of this model
  * \return Max feature index of this model
  */
  inline int MaxFeatureIdx() const override { return max_feature_idx_; }
Guolin Ke's avatar
Guolin Ke committed
100
101
102
103
104
105
106

  /*!
  * \brief Get index of label column
  * \return index of label column
  */
  inline int LabelIdx() const override { return label_idx_; }

Guolin Ke's avatar
Guolin Ke committed
107
108
109
110
111
112
  /*!
  * \brief Get number of weak sub-models
  * \return Number of weak sub-models
  */
  inline int NumberOfSubModels() const override { return static_cast<int>(models_.size()); }

113
114
115
116
117
118
  /*!
  * \brief Get number of classes
  * \return Number of classes
  */
  inline int NumberOfClass() const override { return num_class_; }
  
119
120
121
122
123
  /*!
  * \brief Get Type name of this boosting object
  */
  const char* Name() const override { return "gbdt"; }

Guolin Ke's avatar
Guolin Ke committed
124
125
126
127
private:
  /*!
  * \brief Implement bagging logic
  * \param iter Current interation
128
  * \param curr_class Current class for multiclass training
Guolin Ke's avatar
Guolin Ke committed
129
  */
130
  void Bagging(int iter, const int curr_class);
Guolin Ke's avatar
Guolin Ke committed
131
  /*!
Qiwei Ye's avatar
Qiwei Ye committed
132
133
  * \brief updating score for out-of-bag data.
  *        Data should be update since we may re-bagging data on training
Guolin Ke's avatar
Guolin Ke committed
134
  * \param tree Trained tree of this iteration
135
  * \param curr_class Current class for multiclass training
Guolin Ke's avatar
Guolin Ke committed
136
  */
137
  void UpdateScoreOutOfBag(const Tree* tree, const int curr_class);
Guolin Ke's avatar
Guolin Ke committed
138
139
140
141
142
  /*!
  * \brief calculate the object function
  */
  void Boosting();
  /*!
Qiwei Ye's avatar
Qiwei Ye committed
143
  * \brief updating score after tree was trained
Guolin Ke's avatar
Guolin Ke committed
144
  * \param tree Trained tree of this iteration
145
  * \param curr_class Current class for multiclass training
Guolin Ke's avatar
Guolin Ke committed
146
  */
147
  void UpdateScore(const Tree* tree, const int curr_class);
Guolin Ke's avatar
Guolin Ke committed
148
  /*!
Hui Xue's avatar
Hui Xue committed
149
  * \brief Print metric result of current iteration
Guolin Ke's avatar
Guolin Ke committed
150
151
  * \param iter Current interation
  */
wxchan's avatar
wxchan committed
152
  bool OutputMetric(int iter);
wxchan's avatar
wxchan committed
153
154
155
156
  /*!
  * \brief Calculate feature importances
  * \param last_iter Last tree use to calculate
  */
157
158
159
  std::string FeatureImportance() const;
  /*! \brief current iteration */
  int iter_;
Guolin Ke's avatar
Guolin Ke committed
160
161
162
163
  /*! \brief Pointer to training data */
  const Dataset* train_data_;
  /*! \brief Config of gbdt */
  const GBDTConfig* gbdt_config_;
Hui Xue's avatar
Hui Xue committed
164
  /*! \brief Tree learner, will use this class to learn trees */
165
  std::vector<TreeLearner*> tree_learner_;
Guolin Ke's avatar
Guolin Ke committed
166
167
  /*! \brief Objective function */
  const ObjectiveFunction* object_function_;
Hui Xue's avatar
Hui Xue committed
168
  /*! \brief Store and update training data's score */
Guolin Ke's avatar
Guolin Ke committed
169
170
171
172
173
174
175
  ScoreUpdater* train_score_updater_;
  /*! \brief Metrics for training data */
  std::vector<const Metric*> training_metrics_;
  /*! \brief Store and update validation data's scores */
  std::vector<ScoreUpdater*> valid_score_updater_;
  /*! \brief Metric for validation data */
  std::vector<std::vector<const Metric*>> valid_metrics_;
wxchan's avatar
wxchan committed
176
177
  /*! \brief Number of rounds for early stopping */
  int early_stopping_round_;
wxchan's avatar
wxchan committed
178
179
180
  /*! \brief Best score(s) for early stopping */
  std::vector<std::vector<int>> best_iter_;
  std::vector<std::vector<score_t>> best_score_;
Guolin Ke's avatar
Guolin Ke committed
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
  /*! \brief Trained models(trees) */
  std::vector<Tree*> models_;
  /*! \brief Max feature index of training data*/
  int max_feature_idx_;
  /*! \brief First order derivative of training data */
  score_t* gradients_;
  /*! \brief Secend order derivative of training data */
  score_t* hessians_;
  /*! \brief Store the data indices of out-of-bag */
  data_size_t* out_of_bag_data_indices_;
  /*! \brief Number of out-of-bag data */
  data_size_t out_of_bag_data_cnt_;
  /*! \brief Store the indices of in-bag data */
  data_size_t* bag_data_indices_;
  /*! \brief Number of in-bag data */
  data_size_t bag_data_cnt_;
  /*! \brief Number of traning data */
  data_size_t num_data_;
199
200
  /*! \brief Number of classes */
  int num_class_;
Guolin Ke's avatar
Guolin Ke committed
201
202
203
204
205
206
  /*! \brief Random generator, used for bagging */
  Random random_;
  /*!
  *   \brief Sigmoid parameter, used for prediction.
  *          if > 0 meas output score will transform by sigmoid function
  */
207
  float sigmoid_;
Guolin Ke's avatar
Guolin Ke committed
208
209
  /*! \brief Index of label column */
  data_size_t label_idx_;
210
211
212
213
  /*! \brief Saved number of models */
  int saved_model_size_ = -1;
  /*! \brief File to write models */
  std::ofstream model_output_file_;
Guolin Ke's avatar
Guolin Ke committed
214
215
216
};

}  // namespace LightGBM
#endif   // LightGBM_BOOSTING_GBDT_H_