boosting.h 9.28 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
3
4
5
#ifndef LIGHTGBM_BOOSTING_H_
#define LIGHTGBM_BOOSTING_H_

#include <LightGBM/meta.h>
#include <LightGBM/config.h>
wxchan's avatar
wxchan committed
6
#include "model.pb.h"
Guolin Ke's avatar
Guolin Ke committed
7
8
9
10
11
12
13
14
15
16

#include <vector>
#include <string>

namespace LightGBM {

/*! \brief forward declaration */
class Dataset;
class ObjectiveFunction;
class Metric;
17
struct PredictionEarlyStopInstance;
Guolin Ke's avatar
Guolin Ke committed
18
19
20
21

/*!
* \brief The interface for Boosting
*/
22
class LIGHTGBM_EXPORT Boosting {
Guolin Ke's avatar
Guolin Ke committed
23
24
25
26
27
public:
  /*! \brief virtual destructor */
  virtual ~Boosting() {}

  /*!
Qiwei Ye's avatar
Qiwei Ye committed
28
29
  * \brief Initialization logic
  * \param config Configs for boosting
Guolin Ke's avatar
Guolin Ke committed
30
  * \param train_data Training data
31
  * \param objective_function Training objective function
Guolin Ke's avatar
Guolin Ke committed
32
33
  * \param training_metrics Training metric
  */
34
35
36
  virtual void Init(
    const BoostingConfig* config,
    const Dataset* train_data,
37
    const ObjectiveFunction* objective_function,
38
    const std::vector<const Metric*>& training_metrics) = 0;
Guolin Ke's avatar
Guolin Ke committed
39

wxchan's avatar
wxchan committed
40
41
  /*!
  * \brief Merge model from other boosting object
Guolin Ke's avatar
Guolin Ke committed
42
  Will insert to the front of current boosting object
wxchan's avatar
wxchan committed
43
44
45
46
  * \param other
  */
  virtual void MergeFrom(const Boosting* other) = 0;

47
48
49
50
  virtual void ResetTrainingData(const Dataset* train_data, const ObjectiveFunction* objective_function,
                                 const std::vector<const Metric*>& training_metrics) = 0;

  virtual void ResetConfig(const BoostingConfig* config) = 0;
wxchan's avatar
wxchan committed
51

Guolin Ke's avatar
Guolin Ke committed
52
53
54
55
56
  /*!
  * \brief Add a validation data
  * \param valid_data Validation data
  * \param valid_metrics Metric for validation data
  */
wxchan's avatar
wxchan committed
57
  virtual void AddValidDataset(const Dataset* valid_data,
Guolin Ke's avatar
Guolin Ke committed
58
                               const std::vector<const Metric*>& valid_metrics) = 0;
Guolin Ke's avatar
Guolin Ke committed
59

Guolin Ke's avatar
Guolin Ke committed
60
61
  virtual void Train(int snapshot_freq, const std::string& model_output_path) = 0;

Guolin Ke's avatar
Guolin Ke committed
62
63
  /*!
  * \brief Training logic
Guolin Ke's avatar
Guolin Ke committed
64
65
66
  * \param gradients nullptr for using default objective, otherwise use self-defined boosting
  * \param hessians nullptr for using default objective, otherwise use self-defined boosting
  * \return True if cannot train anymore
Guolin Ke's avatar
Guolin Ke committed
67
  */
Guolin Ke's avatar
Guolin Ke committed
68
  virtual bool TrainOneIter(const score_t* gradients, const score_t* hessians) = 0;
69

wxchan's avatar
wxchan committed
70
71
72
73
74
75
76
77
78
79
  /*!
  * \brief Rollback one iteration
  */
  virtual void RollbackOneIter() = 0;

  /*!
  * \brief return current iteration
  */
  virtual int GetCurrentIteration() const = 0;

Guolin Ke's avatar
Guolin Ke committed
80
81
82
83
84
  /*!
  * \brief Get evaluation result at data_idx data
  * \param data_idx 0: training data, 1: 1st validation data
  * \return evaluation result
  */
85
  virtual std::vector<double> GetEvalAt(int data_idx) const = 0;
86

Guolin Ke's avatar
Guolin Ke committed
87
88
  /*!
  * \brief Get current training score
Guolin Ke's avatar
Guolin Ke committed
89
  * \param out_len length of returned score
Guolin Ke's avatar
Guolin Ke committed
90
91
  * \return training score
  */
92
  virtual const double* GetTrainingScore(int64_t* out_len) = 0;
Guolin Ke's avatar
Guolin Ke committed
93

Guolin Ke's avatar
Guolin Ke committed
94
95
96
  /*!
  * \brief Get prediction result at data_idx data
  * \param data_idx 0: training data, 1: 1st validation data
97
  * \return out_len length of returned score
Guolin Ke's avatar
Guolin Ke committed
98
99
  */
  virtual int64_t GetNumPredictAt(int data_idx) const = 0;
Guolin Ke's avatar
Guolin Ke committed
100

Guolin Ke's avatar
Guolin Ke committed
101
102
103
104
  /*!
  * \brief Get prediction result at data_idx data
  * \param data_idx 0: training data, 1: 1st validation data
  * \param result used to store prediction result, should allocate memory before call this function
105
  * \param out_len length of returned score
Guolin Ke's avatar
Guolin Ke committed
106
  */
Guolin Ke's avatar
Guolin Ke committed
107
  virtual void GetPredictAt(int data_idx, double* result, int64_t* out_len) = 0;
Guolin Ke's avatar
Guolin Ke committed
108

109
  virtual int NumPredictOneRow(int num_iteration, bool is_pred_leaf, bool is_pred_contrib) const = 0;
Guolin Ke's avatar
Guolin Ke committed
110

Guolin Ke's avatar
Guolin Ke committed
111
  /*!
Hui Xue's avatar
Hui Xue committed
112
  * \brief Prediction for one record, not sigmoid transform
Guolin Ke's avatar
Guolin Ke committed
113
  * \param feature_values Feature value on this record
Guolin Ke's avatar
Guolin Ke committed
114
  * \param output Prediction result for this record
Guolin Ke's avatar
Guolin Ke committed
115
  * \param early_stop Early stopping instance. If nullptr, no early stopping is applied and all models are evaluated.
Guolin Ke's avatar
Guolin Ke committed
116
  */
cbecker's avatar
cbecker committed
117
  virtual void PredictRaw(const double* features, double* output,
118
                          const PredictionEarlyStopInstance* early_stop) const = 0;
Guolin Ke's avatar
Guolin Ke committed
119
120

  /*!
Qiwei Ye's avatar
Qiwei Ye committed
121
  * \brief Prediction for one record, sigmoid transformation will be used if needed
Guolin Ke's avatar
Guolin Ke committed
122
  * \param feature_values Feature value on this record
Guolin Ke's avatar
Guolin Ke committed
123
  * \param output Prediction result for this record
Guolin Ke's avatar
Guolin Ke committed
124
  * \param early_stop Early stopping instance. If nullptr, no early stopping is applied and all models are evaluated.
Guolin Ke's avatar
Guolin Ke committed
125
  */
cbecker's avatar
cbecker committed
126
  virtual void Predict(const double* features, double* output,
127
                       const PredictionEarlyStopInstance* early_stop) const = 0;
128

wxchan's avatar
wxchan committed
129
  /*!
130
  * \brief Prediction for one record with leaf index
wxchan's avatar
wxchan committed
131
  * \param feature_values Feature value on this record
Guolin Ke's avatar
Guolin Ke committed
132
  * \param output Prediction result for this record
wxchan's avatar
wxchan committed
133
  */
Guolin Ke's avatar
Guolin Ke committed
134
  virtual void PredictLeafIndex(
135
    const double* features, double* output) const = 0;
136

Guolin Ke's avatar
Guolin Ke committed
137
  /*!
138
139
140
  * \brief Feature contributions for the model's prediction of one record
  * \param feature_values Feature value on this record
  * \param output Prediction result for this record
Guolin Ke's avatar
Guolin Ke committed
141
  * \param early_stop Early stopping instance. If nullptr, no early stopping is applied and all models are evaluated.
142
143
  */
  virtual void PredictContrib(const double* features, double* output,
Guolin Ke's avatar
Guolin Ke committed
144
                              const PredictionEarlyStopInstance* early_stop) const = 0;
145

Guolin Ke's avatar
Guolin Ke committed
146
  /*!
wxchan's avatar
wxchan committed
147
  * \brief Dump model to json format string
148
  * \param num_iteration Number of iterations that want to dump, -1 means dump all
wxchan's avatar
wxchan committed
149
150
  * \return Json format string of model
  */
151
  virtual std::string DumpModel(int num_iteration) const = 0;
wxchan's avatar
wxchan committed
152

153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
  /*!
  * \brief Translate model to if-else statement
  * \param num_iteration Number of iterations that want to translate, -1 means translate all
  * \return if-else format codes of model
  */
  virtual std::string ModelToIfElse(int num_iteration) const = 0;

  /*!
  * \brief Translate model to if-else statement
  * \param num_iteration Number of iterations that want to translate, -1 means translate all
  * \param filename Filename that want to save to
  * \return is_finish Is training finished or not
  */
  virtual bool SaveModelToIfElse(int num_iteration, const char* filename) const = 0;

wxchan's avatar
wxchan committed
168
169
  /*!
  * \brief Save model to file
wxchan's avatar
wxchan committed
170
  * \param num_iterations Number of model that want to save, -1 means save all
wxchan's avatar
wxchan committed
171
172
  * \param is_finish Is training finished or not
  * \param filename Filename that want to save to
173
  * \return true if succeeded
Guolin Ke's avatar
Guolin Ke committed
174
  */
175
  virtual bool SaveModelToFile(int num_iterations, const char* filename) const = 0;
Guolin Ke's avatar
Guolin Ke committed
176

177
178
  /*!
  * \brief Save model to string
wxchan's avatar
wxchan committed
179
  * \param num_iterations Number of model that want to save, -1 means save all
180
181
182
183
  * \return Non-empty string if succeeded
  */
  virtual std::string SaveModelToString(int num_iterations) const = 0;

Guolin Ke's avatar
Guolin Ke committed
184
185
186
  /*!
  * \brief Restore from a serialized string
  * \param model_str The string of model
187
  * \return true if succeeded
Guolin Ke's avatar
Guolin Ke committed
188
  */
189
  virtual bool LoadModelFromString(const std::string& model_str) = 0;
190

wxchan's avatar
wxchan committed
191
192
193
194
195
196
197
198
199
200
201
202
203
204
  /*!
  * \brief Save model with protobuf
  * \param num_iterations Number of model that want to save, -1 means save all
  * \param filename Filename that want to save to
  */
  virtual void SaveModelToProto(int num_iteration, const char* filename) const = 0;
  
  /*!
  * \brief Restore from a serialized protobuf file
  * \param filename Filename that want to restore from
  * \return true if succeeded
  */
  virtual bool LoadModelFromProto(const char* filename) = 0;

205
206
207
208
209
210
211
  /*!
  * \brief Calculate feature importances
  * \param num_iteration Number of model that want to use for feature importance, -1 means use all
  * \param importance_type: 0 for split, 1 for gain
  * \return vector of feature_importance
  */
  virtual std::vector<double> FeatureImportance(int num_iteration, int importance_type) const = 0;
Guolin Ke's avatar
Guolin Ke committed
212
213
214
215
216
217
218

  /*!
  * \brief Get max feature index of this model
  * \return Max feature index of this model
  */
  virtual int MaxFeatureIdx() const = 0;

wxchan's avatar
wxchan committed
219
220
221
222
223
224
  /*!
  * \brief Get feature names of this model
  * \return Feature names of this model
  */
  virtual std::vector<std::string> FeatureNames() const = 0;

Guolin Ke's avatar
Guolin Ke committed
225
226
227
228
229
230
  /*!
  * \brief Get index of label column
  * \return index of label column
  */
  virtual int LabelIdx() const = 0;

Guolin Ke's avatar
Guolin Ke committed
231
232
233
234
  /*!
  * \brief Get number of weak sub-models
  * \return Number of weak sub-models
  */
wxchan's avatar
wxchan committed
235
  virtual int NumberOfTotalModel() const = 0;
236

Guolin Ke's avatar
Guolin Ke committed
237
  /*!
Guolin Ke's avatar
Guolin Ke committed
238
239
  * \brief Get number of models per iteration
  * \return Number of models per iteration
Guolin Ke's avatar
Guolin Ke committed
240
  */
Guolin Ke's avatar
Guolin Ke committed
241
  virtual int NumModelPerIteration() const = 0;
Guolin Ke's avatar
Guolin Ke committed
242

243
244
245
246
  /*!
  * \brief Get number of classes
  * \return Number of classes
  */
Guolin Ke's avatar
Guolin Ke committed
247
  virtual int NumberOfClasses() const = 0;
248

249
250
251
  /*! \brief The prediction should be accurate or not. True will disable early stopping for prediction. */
  virtual bool NeedAccuratePrediction() const = 0;

252
  /*!
Guolin Ke's avatar
Guolin Ke committed
253
254
  * \brief Initial work for the prediction
  * \param num_iteration number of used iteration
255
  */
256
  virtual void InitPredict(int num_iteration) = 0;
257

258
  /*!
Guolin Ke's avatar
Guolin Ke committed
259
  * \brief Name of submodel
260
  */
Guolin Ke's avatar
Guolin Ke committed
261
  virtual const char* SubModelName() const = 0;
262

Guolin Ke's avatar
Guolin Ke committed
263
264
265
266
267
268
  Boosting() = default;
  /*! \brief Disable copy */
  Boosting& operator=(const Boosting&) = delete;
  /*! \brief Disable copy */
  Boosting(const Boosting&) = delete;

wxchan's avatar
wxchan committed
269
  static bool LoadFileToBoosting(Boosting* boosting, const std::string& format, const char* filename);
wxchan's avatar
wxchan committed
270

Guolin Ke's avatar
Guolin Ke committed
271
272
273
  /*!
  * \brief Create boosting object
  * \param type Type of boosting
wxchan's avatar
wxchan committed
274
  * \param format Format of model
275
276
  * \param config config for boosting
  * \param filename name of model file, if existing will continue to train from this model
Guolin Ke's avatar
Guolin Ke committed
277
278
  * \return The boosting object
  */
wxchan's avatar
wxchan committed
279
  static Boosting* CreateBoosting(const std::string& type, const std::string& format, const char* filename);
280

Guolin Ke's avatar
Guolin Ke committed
281
282
};

Guolin Ke's avatar
Guolin Ke committed
283
284
285
286
287
288
/*!
* \brief Boosting interface extended with per-leaf value access
*
* Adds two pure virtual accessors on top of Boosting; both address a leaf
* by a (tree_idx, leaf_idx) pair.
*/
class GBDTBase : public Boosting {
public:
  /*!
  * \brief Get the value of one leaf
  * \param tree_idx Index of the tree (NOTE(review): presumably an index into the sub-models - confirm)
  * \param leaf_idx Index of the leaf inside that tree
  * \return Value of the addressed leaf
  */
  virtual double GetLeafValue(int tree_idx, int leaf_idx) const = 0;
  /*!
  * \brief Set the value of one leaf
  * \param tree_idx Index of the tree (NOTE(review): presumably an index into the sub-models - confirm)
  * \param leaf_idx Index of the leaf inside that tree
  * \param val New value for the addressed leaf
  */
  virtual void SetLeafValue(int tree_idx, int leaf_idx, double val) = 0;
};

Guolin Ke's avatar
Guolin Ke committed
289
290
}  // namespace LightGBM

Guolin Ke's avatar
Guolin Ke committed
291
#endif   // LIGHTGBM_BOOSTING_H_