boosting.h 9.47 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
3
4
5
6
7
8
#ifndef LIGHTGBM_BOOSTING_H_
#define LIGHTGBM_BOOSTING_H_

#include <LightGBM/meta.h>
#include <LightGBM/config.h>

#include <cstddef>
#include <cstdint>
#include <map>
#include <string>
#include <unordered_map>
#include <vector>
Guolin Ke's avatar
Guolin Ke committed
10
11
12
13
14
15
16

namespace LightGBM {

/*! \brief forward declaration */
class Dataset;
class ObjectiveFunction;
class Metric;
17
struct PredictionEarlyStopInstance;
Guolin Ke's avatar
Guolin Ke committed
18
19
20
21

/*!
* \brief The interface for Boosting
*/
22
class LIGHTGBM_EXPORT Boosting {
Guolin Ke's avatar
Guolin Ke committed
23
24
25
26
27
public:
  /*! \brief virtual destructor */
  virtual ~Boosting() {}

  /*!
Qiwei Ye's avatar
Qiwei Ye committed
28
29
  * \brief Initialization logic
  * \param config Configs for boosting
Guolin Ke's avatar
Guolin Ke committed
30
  * \param train_data Training data
31
  * \param objective_function Training objective function
Guolin Ke's avatar
Guolin Ke committed
32
33
  * \param training_metrics Training metric
  */
34
35
36
  virtual void Init(
    const BoostingConfig* config,
    const Dataset* train_data,
37
    const ObjectiveFunction* objective_function,
38
    const std::vector<const Metric*>& training_metrics) = 0;
Guolin Ke's avatar
Guolin Ke committed
39

wxchan's avatar
wxchan committed
40
41
  /*!
  * \brief Merge model from other boosting object
Guolin Ke's avatar
Guolin Ke committed
42
  Will insert to the front of current boosting object
wxchan's avatar
wxchan committed
43
44
45
46
  * \param other
  */
  virtual void MergeFrom(const Boosting* other) = 0;

47
48
49
50
  virtual void ResetTrainingData(const Dataset* train_data, const ObjectiveFunction* objective_function,
                                 const std::vector<const Metric*>& training_metrics) = 0;

  virtual void ResetConfig(const BoostingConfig* config) = 0;
wxchan's avatar
wxchan committed
51

52
53


Guolin Ke's avatar
Guolin Ke committed
54
55
56
57
58
  /*!
  * \brief Add a validation data
  * \param valid_data Validation data
  * \param valid_metrics Metric for validation data
  */
wxchan's avatar
wxchan committed
59
  virtual void AddValidDataset(const Dataset* valid_data,
Guolin Ke's avatar
Guolin Ke committed
60
                               const std::vector<const Metric*>& valid_metrics) = 0;
Guolin Ke's avatar
Guolin Ke committed
61

Guolin Ke's avatar
Guolin Ke committed
62
63
  virtual void Train(int snapshot_freq, const std::string& model_output_path) = 0;

64
65
66
67
68
  /*!
  * \brief Update the tree output by new training data
  */
  virtual void RefitTree(const std::vector<std::vector<int>>& tree_leaf_prediction) = 0;

Guolin Ke's avatar
Guolin Ke committed
69
70
  /*!
  * \brief Training logic
Guolin Ke's avatar
Guolin Ke committed
71
72
73
  * \param gradients nullptr for using default objective, otherwise use self-defined boosting
  * \param hessians nullptr for using default objective, otherwise use self-defined boosting
  * \return True if cannot train anymore
Guolin Ke's avatar
Guolin Ke committed
74
  */
Guolin Ke's avatar
Guolin Ke committed
75
  virtual bool TrainOneIter(const score_t* gradients, const score_t* hessians) = 0;
76

wxchan's avatar
wxchan committed
77
78
79
80
81
82
83
84
85
86
  /*!
  * \brief Rollback one iteration
  */
  virtual void RollbackOneIter() = 0;

  /*!
  * \brief return current iteration
  */
  virtual int GetCurrentIteration() const = 0;

Guolin Ke's avatar
Guolin Ke committed
87
88
89
90
91
  /*!
  * \brief Get evaluation result at data_idx data
  * \param data_idx 0: training data, 1: 1st validation data
  * \return evaluation result
  */
92
  virtual std::vector<double> GetEvalAt(int data_idx) const = 0;
93

Guolin Ke's avatar
Guolin Ke committed
94
95
  /*!
  * \brief Get current training score
Guolin Ke's avatar
Guolin Ke committed
96
  * \param out_len length of returned score
Guolin Ke's avatar
Guolin Ke committed
97
98
  * \return training score
  */
99
  virtual const double* GetTrainingScore(int64_t* out_len) = 0;
Guolin Ke's avatar
Guolin Ke committed
100

Guolin Ke's avatar
Guolin Ke committed
101
102
103
  /*!
  * \brief Get prediction result at data_idx data
  * \param data_idx 0: training data, 1: 1st validation data
104
  * \return out_len length of returned score
Guolin Ke's avatar
Guolin Ke committed
105
106
  */
  virtual int64_t GetNumPredictAt(int data_idx) const = 0;
Guolin Ke's avatar
Guolin Ke committed
107

Guolin Ke's avatar
Guolin Ke committed
108
109
110
111
  /*!
  * \brief Get prediction result at data_idx data
  * \param data_idx 0: training data, 1: 1st validation data
  * \param result used to store prediction result, should allocate memory before call this function
112
  * \param out_len length of returned score
Guolin Ke's avatar
Guolin Ke committed
113
  */
Guolin Ke's avatar
Guolin Ke committed
114
  virtual void GetPredictAt(int data_idx, double* result, int64_t* out_len) = 0;
Guolin Ke's avatar
Guolin Ke committed
115

116
  virtual int NumPredictOneRow(int num_iteration, bool is_pred_leaf, bool is_pred_contrib) const = 0;
Guolin Ke's avatar
Guolin Ke committed
117

Guolin Ke's avatar
Guolin Ke committed
118
  /*!
Hui Xue's avatar
Hui Xue committed
119
  * \brief Prediction for one record, not sigmoid transform
Guolin Ke's avatar
Guolin Ke committed
120
  * \param feature_values Feature value on this record
Guolin Ke's avatar
Guolin Ke committed
121
  * \param output Prediction result for this record
Guolin Ke's avatar
Guolin Ke committed
122
  * \param early_stop Early stopping instance. If nullptr, no early stopping is applied and all models are evaluated.
Guolin Ke's avatar
Guolin Ke committed
123
  */
cbecker's avatar
cbecker committed
124
  virtual void PredictRaw(const double* features, double* output,
125
                          const PredictionEarlyStopInstance* early_stop) const = 0;
Guolin Ke's avatar
Guolin Ke committed
126

Guolin Ke's avatar
Guolin Ke committed
127
128
  virtual void PredictRawByMap(const std::unordered_map<int, double>& features, double* output,
                               const PredictionEarlyStopInstance* early_stop) const = 0;
129
130


Guolin Ke's avatar
Guolin Ke committed
131
  /*!
Qiwei Ye's avatar
Qiwei Ye committed
132
  * \brief Prediction for one record, sigmoid transformation will be used if needed
Guolin Ke's avatar
Guolin Ke committed
133
  * \param feature_values Feature value on this record
Guolin Ke's avatar
Guolin Ke committed
134
  * \param output Prediction result for this record
Guolin Ke's avatar
Guolin Ke committed
135
  * \param early_stop Early stopping instance. If nullptr, no early stopping is applied and all models are evaluated.
Guolin Ke's avatar
Guolin Ke committed
136
  */
cbecker's avatar
cbecker committed
137
  virtual void Predict(const double* features, double* output,
138
                       const PredictionEarlyStopInstance* early_stop) const = 0;
139

Guolin Ke's avatar
Guolin Ke committed
140
141
  virtual void PredictByMap(const std::unordered_map<int, double>& features, double* output,
                            const PredictionEarlyStopInstance* early_stop) const = 0;
142
143


wxchan's avatar
wxchan committed
144
  /*!
145
  * \brief Prediction for one record with leaf index
wxchan's avatar
wxchan committed
146
  * \param feature_values Feature value on this record
Guolin Ke's avatar
Guolin Ke committed
147
  * \param output Prediction result for this record
wxchan's avatar
wxchan committed
148
  */
Guolin Ke's avatar
Guolin Ke committed
149
  virtual void PredictLeafIndex(
150
    const double* features, double* output) const = 0;
151

152
153
154
  virtual void PredictLeafIndexByMap(
    const std::unordered_map<int, double>& features, double* output) const = 0;

Guolin Ke's avatar
Guolin Ke committed
155
  /*!
156
157
158
  * \brief Feature contributions for the model's prediction of one record
  * \param feature_values Feature value on this record
  * \param output Prediction result for this record
Guolin Ke's avatar
Guolin Ke committed
159
  * \param early_stop Early stopping instance. If nullptr, no early stopping is applied and all models are evaluated.
160
161
  */
  virtual void PredictContrib(const double* features, double* output,
Guolin Ke's avatar
Guolin Ke committed
162
                              const PredictionEarlyStopInstance* early_stop) const = 0;
163

Guolin Ke's avatar
Guolin Ke committed
164
  /*!
wxchan's avatar
wxchan committed
165
  * \brief Dump model to json format string
166
  * \param num_iteration Number of iterations that want to dump, -1 means dump all
wxchan's avatar
wxchan committed
167
168
  * \return Json format string of model
  */
169
  virtual std::string DumpModel(int num_iteration) const = 0;
wxchan's avatar
wxchan committed
170

171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
  /*!
  * \brief Translate model to if-else statement
  * \param num_iteration Number of iterations that want to translate, -1 means translate all
  * \return if-else format codes of model
  */
  virtual std::string ModelToIfElse(int num_iteration) const = 0;

  /*!
  * \brief Translate model to if-else statement
  * \param num_iteration Number of iterations that want to translate, -1 means translate all
  * \param filename Filename that want to save to
  * \return is_finish Is training finished or not
  */
  virtual bool SaveModelToIfElse(int num_iteration, const char* filename) const = 0;

wxchan's avatar
wxchan committed
186
187
  /*!
  * \brief Save model to file
wxchan's avatar
wxchan committed
188
  * \param num_iterations Number of model that want to save, -1 means save all
wxchan's avatar
wxchan committed
189
190
  * \param is_finish Is training finished or not
  * \param filename Filename that want to save to
191
  * \return true if succeeded
Guolin Ke's avatar
Guolin Ke committed
192
  */
193
  virtual bool SaveModelToFile(int num_iterations, const char* filename) const = 0;
Guolin Ke's avatar
Guolin Ke committed
194

195
196
  /*!
  * \brief Save model to string
wxchan's avatar
wxchan committed
197
  * \param num_iterations Number of model that want to save, -1 means save all
198
199
200
201
  * \return Non-empty string if succeeded
  */
  virtual std::string SaveModelToString(int num_iterations) const = 0;

Guolin Ke's avatar
Guolin Ke committed
202
203
  /*!
  * \brief Restore from a serialized string
204
205
  * \param buffer The content of model
  * \param len The length of buffer
wxchan's avatar
wxchan committed
206
207
  * \return true if succeeded
  */
208
  virtual bool LoadModelFromString(const char* buffer, size_t len) = 0;
wxchan's avatar
wxchan committed
209

210
211
212
213
214
215
216
  /*!
  * \brief Calculate feature importances
  * \param num_iteration Number of model that want to use for feature importance, -1 means use all
  * \param importance_type: 0 for split, 1 for gain
  * \return vector of feature_importance
  */
  virtual std::vector<double> FeatureImportance(int num_iteration, int importance_type) const = 0;
Guolin Ke's avatar
Guolin Ke committed
217
218
219
220
221
222
223

  /*!
  * \brief Get max feature index of this model
  * \return Max feature index of this model
  */
  virtual int MaxFeatureIdx() const = 0;

wxchan's avatar
wxchan committed
224
225
226
227
228
229
  /*!
  * \brief Get feature names of this model
  * \return Feature names of this model
  */
  virtual std::vector<std::string> FeatureNames() const = 0;

Guolin Ke's avatar
Guolin Ke committed
230
231
232
233
234
235
  /*!
  * \brief Get index of label column
  * \return index of label column
  */
  virtual int LabelIdx() const = 0;

Guolin Ke's avatar
Guolin Ke committed
236
237
238
239
  /*!
  * \brief Get number of weak sub-models
  * \return Number of weak sub-models
  */
wxchan's avatar
wxchan committed
240
  virtual int NumberOfTotalModel() const = 0;
241

Guolin Ke's avatar
Guolin Ke committed
242
  /*!
Guolin Ke's avatar
Guolin Ke committed
243
244
  * \brief Get number of models per iteration
  * \return Number of models per iteration
Guolin Ke's avatar
Guolin Ke committed
245
  */
Guolin Ke's avatar
Guolin Ke committed
246
  virtual int NumModelPerIteration() const = 0;
Guolin Ke's avatar
Guolin Ke committed
247

248
249
250
251
  /*!
  * \brief Get number of classes
  * \return Number of classes
  */
Guolin Ke's avatar
Guolin Ke committed
252
  virtual int NumberOfClasses() const = 0;
253

254
255
256
  /*! \brief The prediction should be accurate or not. True will disable early stopping for prediction. */
  virtual bool NeedAccuratePrediction() const = 0;

257
  /*!
Guolin Ke's avatar
Guolin Ke committed
258
259
  * \brief Initial work for the prediction
  * \param num_iteration number of used iteration
260
  * \param is_pred_contrib
261
  */
262
  virtual void InitPredict(int num_iteration, bool is_pred_contrib) = 0;
263

264
  /*!
Guolin Ke's avatar
Guolin Ke committed
265
  * \brief Name of submodel
266
  */
Guolin Ke's avatar
Guolin Ke committed
267
  virtual const char* SubModelName() const = 0;
268

Guolin Ke's avatar
Guolin Ke committed
269
270
271
272
273
274
  Boosting() = default;
  /*! \brief Disable copy */
  Boosting& operator=(const Boosting&) = delete;
  /*! \brief Disable copy */
  Boosting(const Boosting&) = delete;

275
  static bool LoadFileToBoosting(Boosting* boosting, const char* filename);
wxchan's avatar
wxchan committed
276

Guolin Ke's avatar
Guolin Ke committed
277
278
279
  /*!
  * \brief Create boosting object
  * \param type Type of boosting
wxchan's avatar
wxchan committed
280
  * \param format Format of model
281
282
  * \param config config for boosting
  * \param filename name of model file, if existing will continue to train from this model
Guolin Ke's avatar
Guolin Ke committed
283
284
  * \return The boosting object
  */
285
  static Boosting* CreateBoosting(const std::string& type, const char* filename);
286

Guolin Ke's avatar
Guolin Ke committed
287
288
};

Guolin Ke's avatar
Guolin Ke committed
289
290
291
292
293
294
/*!
* \brief Extension of the Boosting interface that adds read and write access
*        to the value of an individual leaf, addressed by tree and leaf index
*/
class GBDTBase : public Boosting {
public:
  /*!
  * \brief Get the value of one leaf
  * \param tree_idx Index of the tree
  * \param leaf_idx Index of the leaf inside that tree
  * \return Value of the addressed leaf
  */
  virtual double GetLeafValue(int tree_idx, int leaf_idx) const = 0;
  /*!
  * \brief Set the value of one leaf
  * \param tree_idx Index of the tree
  * \param leaf_idx Index of the leaf inside that tree
  * \param val New value for the addressed leaf
  */
  virtual void SetLeafValue(int tree_idx, int leaf_idx, double val) = 0;
};

Guolin Ke's avatar
Guolin Ke committed
295
296
}  // namespace LightGBM

Guolin Ke's avatar
Guolin Ke committed
297
#endif   // LIGHTGBM_BOOSTING_H_