boosting.h 8.89 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
#ifndef LIGHTGBM_BOOSTING_H_
#define LIGHTGBM_BOOSTING_H_

#include <LightGBM/meta.h>
#include <LightGBM/config.h>

#include <vector>
#include <string>

namespace LightGBM {

/*! \brief forward declaration */
class Dataset;
class ObjectiveFunction;
class Metric;
16
struct PredictionEarlyStopInstance;
Guolin Ke's avatar
Guolin Ke committed
17
18
19
20

/*!
* \brief The interface for Boosting
*/
21
class LIGHTGBM_EXPORT Boosting {
Guolin Ke's avatar
Guolin Ke committed
22
23
24
25
26
public:
  /*! \brief virtual destructor */
  virtual ~Boosting() {}

  /*!
Qiwei Ye's avatar
Qiwei Ye committed
27
28
  * \brief Initialization logic
  * \param config Configs for boosting
Guolin Ke's avatar
Guolin Ke committed
29
  * \param train_data Training data
30
  * \param objective_function Training objective function
Guolin Ke's avatar
Guolin Ke committed
31
32
  * \param training_metrics Training metric
  */
33
34
35
  virtual void Init(
    const BoostingConfig* config,
    const Dataset* train_data,
36
    const ObjectiveFunction* objective_function,
37
    const std::vector<const Metric*>& training_metrics) = 0;
Guolin Ke's avatar
Guolin Ke committed
38

wxchan's avatar
wxchan committed
39
40
  /*!
  * \brief Merge model from other boosting object
Guolin Ke's avatar
Guolin Ke committed
41
  Will insert to the front of current boosting object
wxchan's avatar
wxchan committed
42
43
44
45
  * \param other
  */
  virtual void MergeFrom(const Boosting* other) = 0;

46
47
48
49
  virtual void ResetTrainingData(const Dataset* train_data, const ObjectiveFunction* objective_function,
                                 const std::vector<const Metric*>& training_metrics) = 0;

  virtual void ResetConfig(const BoostingConfig* config) = 0;
wxchan's avatar
wxchan committed
50

Guolin Ke's avatar
Guolin Ke committed
51
52
53
54
55
  /*!
  * \brief Add a validation data
  * \param valid_data Validation data
  * \param valid_metrics Metric for validation data
  */
wxchan's avatar
wxchan committed
56
  virtual void AddValidDataset(const Dataset* valid_data,
Guolin Ke's avatar
Guolin Ke committed
57
                               const std::vector<const Metric*>& valid_metrics) = 0;
Guolin Ke's avatar
Guolin Ke committed
58

Guolin Ke's avatar
Guolin Ke committed
59
60
  virtual void Train(int snapshot_freq, const std::string& model_output_path) = 0;

Guolin Ke's avatar
Guolin Ke committed
61
62
  /*!
  * \brief Training logic
Guolin Ke's avatar
Guolin Ke committed
63
64
65
  * \param gradients nullptr for using default objective, otherwise use self-defined boosting
  * \param hessians nullptr for using default objective, otherwise use self-defined boosting
  * \return True if cannot train anymore
Guolin Ke's avatar
Guolin Ke committed
66
  */
Guolin Ke's avatar
Guolin Ke committed
67
  virtual bool TrainOneIter(const score_t* gradients, const score_t* hessians) = 0;
68

wxchan's avatar
wxchan committed
69
70
71
72
73
74
75
76
77
78
  /*!
  * \brief Rollback one iteration
  */
  virtual void RollbackOneIter() = 0;

  /*!
  * \brief return current iteration
  */
  virtual int GetCurrentIteration() const = 0;

Guolin Ke's avatar
Guolin Ke committed
79
80
81
82
83
  /*!
  * \brief Get evaluation result at data_idx data
  * \param data_idx 0: training data, 1: 1st validation data
  * \return evaluation result
  */
84
  virtual std::vector<double> GetEvalAt(int data_idx) const = 0;
85

Guolin Ke's avatar
Guolin Ke committed
86
87
  /*!
  * \brief Get current training score
Guolin Ke's avatar
Guolin Ke committed
88
  * \param out_len length of returned score
Guolin Ke's avatar
Guolin Ke committed
89
90
  * \return training score
  */
91
  virtual const double* GetTrainingScore(int64_t* out_len) = 0;
Guolin Ke's avatar
Guolin Ke committed
92

Guolin Ke's avatar
Guolin Ke committed
93
94
95
  /*!
  * \brief Get prediction result at data_idx data
  * \param data_idx 0: training data, 1: 1st validation data
96
  * \return out_len length of returned score
Guolin Ke's avatar
Guolin Ke committed
97
98
  */
  virtual int64_t GetNumPredictAt(int data_idx) const = 0;
Guolin Ke's avatar
Guolin Ke committed
99

Guolin Ke's avatar
Guolin Ke committed
100
101
102
103
  /*!
  * \brief Get prediction result at data_idx data
  * \param data_idx 0: training data, 1: 1st validation data
  * \param result used to store prediction result, should allocate memory before call this function
104
  * \param out_len length of returned score
Guolin Ke's avatar
Guolin Ke committed
105
  */
Guolin Ke's avatar
Guolin Ke committed
106
  virtual void GetPredictAt(int data_idx, double* result, int64_t* out_len) = 0;
Guolin Ke's avatar
Guolin Ke committed
107

108
  virtual int NumPredictOneRow(int num_iteration, bool is_pred_leaf, bool is_pred_contrib) const = 0;
Guolin Ke's avatar
Guolin Ke committed
109

Guolin Ke's avatar
Guolin Ke committed
110
  /*!
Hui Xue's avatar
Hui Xue committed
111
  * \brief Prediction for one record, not sigmoid transform
Guolin Ke's avatar
Guolin Ke committed
112
  * \param feature_values Feature value on this record
Guolin Ke's avatar
Guolin Ke committed
113
  * \param output Prediction result for this record
Guolin Ke's avatar
Guolin Ke committed
114
  * \param early_stop Early stopping instance. If nullptr, no early stopping is applied and all models are evaluated.
Guolin Ke's avatar
Guolin Ke committed
115
  */
cbecker's avatar
cbecker committed
116
  virtual void PredictRaw(const double* features, double* output,
117
                          const PredictionEarlyStopInstance* early_stop) const = 0;
Guolin Ke's avatar
Guolin Ke committed
118
119

  /*!
Qiwei Ye's avatar
Qiwei Ye committed
120
  * \brief Prediction for one record, sigmoid transformation will be used if needed
Guolin Ke's avatar
Guolin Ke committed
121
  * \param feature_values Feature value on this record
Guolin Ke's avatar
Guolin Ke committed
122
  * \param output Prediction result for this record
Guolin Ke's avatar
Guolin Ke committed
123
  * \param early_stop Early stopping instance. If nullptr, no early stopping is applied and all models are evaluated.
Guolin Ke's avatar
Guolin Ke committed
124
  */
cbecker's avatar
cbecker committed
125
  virtual void Predict(const double* features, double* output,
126
                       const PredictionEarlyStopInstance* early_stop) const = 0;
127

wxchan's avatar
wxchan committed
128
  /*!
129
  * \brief Prediction for one record with leaf index
wxchan's avatar
wxchan committed
130
  * \param feature_values Feature value on this record
Guolin Ke's avatar
Guolin Ke committed
131
  * \param output Prediction result for this record
wxchan's avatar
wxchan committed
132
  */
Guolin Ke's avatar
Guolin Ke committed
133
  virtual void PredictLeafIndex(
134
    const double* features, double* output) const = 0;
135

Guolin Ke's avatar
Guolin Ke committed
136
  /*!
137
138
139
  * \brief Feature contributions for the model's prediction of one record
  * \param feature_values Feature value on this record
  * \param output Prediction result for this record
Guolin Ke's avatar
Guolin Ke committed
140
  * \param early_stop Early stopping instance. If nullptr, no early stopping is applied and all models are evaluated.
141
142
  */
  virtual void PredictContrib(const double* features, double* output,
Guolin Ke's avatar
Guolin Ke committed
143
                              const PredictionEarlyStopInstance* early_stop) const = 0;
144

Guolin Ke's avatar
Guolin Ke committed
145
  /*!
wxchan's avatar
wxchan committed
146
  * \brief Dump model to json format string
147
  * \param num_iteration Number of iterations that want to dump, -1 means dump all
wxchan's avatar
wxchan committed
148
149
  * \return Json format string of model
  */
150
  virtual std::string DumpModel(int num_iteration) const = 0;
wxchan's avatar
wxchan committed
151

152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
  /*!
  * \brief Translate model to if-else statement
  * \param num_iteration Number of iterations that want to translate, -1 means translate all
  * \return if-else format codes of model
  */
  virtual std::string ModelToIfElse(int num_iteration) const = 0;

  /*!
  * \brief Translate model to if-else statement
  * \param num_iteration Number of iterations that want to translate, -1 means translate all
  * \param filename Filename that want to save to
  * \return is_finish Is training finished or not
  */
  virtual bool SaveModelToIfElse(int num_iteration, const char* filename) const = 0;

wxchan's avatar
wxchan committed
167
168
  /*!
  * \brief Save model to file
169
  * \param num_used_model Number of model that want to save, -1 means save all
wxchan's avatar
wxchan committed
170
171
  * \param is_finish Is training finished or not
  * \param filename Filename that want to save to
172
  * \return true if succeeded
Guolin Ke's avatar
Guolin Ke committed
173
  */
174
  virtual bool SaveModelToFile(int num_iterations, const char* filename) const = 0;
Guolin Ke's avatar
Guolin Ke committed
175

176
177
  /*!
  * \brief Save model to string
178
  * \param num_used_model Number of model that want to save, -1 means save all
179
180
181
182
  * \return Non-empty string if succeeded
  */
  virtual std::string SaveModelToString(int num_iterations) const = 0;

Guolin Ke's avatar
Guolin Ke committed
183
184
185
  /*!
  * \brief Restore from a serialized string
  * \param model_str The string of model
186
  * \return true if succeeded
Guolin Ke's avatar
Guolin Ke committed
187
  */
188
  virtual bool LoadModelFromString(const std::string& model_str) = 0;
189
190
191
192
193
194
195
196

  /*!
  * \brief Calculate feature importances
  * \param num_iteration Number of model that want to use for feature importance, -1 means use all
  * \param importance_type: 0 for split, 1 for gain
  * \return vector of feature_importance
  */
  virtual std::vector<double> FeatureImportance(int num_iteration, int importance_type) const = 0;
Guolin Ke's avatar
Guolin Ke committed
197
198
199
200
201
202
203

  /*!
  * \brief Get max feature index of this model
  * \return Max feature index of this model
  */
  virtual int MaxFeatureIdx() const = 0;

wxchan's avatar
wxchan committed
204
205
206
207
208
209
  /*!
  * \brief Get feature names of this model
  * \return Feature names of this model
  */
  virtual std::vector<std::string> FeatureNames() const = 0;

Guolin Ke's avatar
Guolin Ke committed
210
211
212
213
214
215
  /*!
  * \brief Get index of label column
  * \return index of label column
  */
  virtual int LabelIdx() const = 0;

Guolin Ke's avatar
Guolin Ke committed
216
217
218
219
  /*!
  * \brief Get number of weak sub-models
  * \return Number of weak sub-models
  */
wxchan's avatar
wxchan committed
220
  virtual int NumberOfTotalModel() const = 0;
221

Guolin Ke's avatar
Guolin Ke committed
222
  /*!
Guolin Ke's avatar
Guolin Ke committed
223
224
  * \brief Get number of models per iteration
  * \return Number of models per iteration
Guolin Ke's avatar
Guolin Ke committed
225
  */
Guolin Ke's avatar
Guolin Ke committed
226
  virtual int NumModelPerIteration() const = 0;
Guolin Ke's avatar
Guolin Ke committed
227

228
229
230
231
  /*!
  * \brief Get number of classes
  * \return Number of classes
  */
Guolin Ke's avatar
Guolin Ke committed
232
  virtual int NumberOfClasses() const = 0;
233

234
235
236
  /*! \brief The prediction should be accurate or not. True will disable early stopping for prediction. */
  virtual bool NeedAccuratePrediction() const = 0;

237
  /*!
Guolin Ke's avatar
Guolin Ke committed
238
239
  * \brief Initial work for the prediction
  * \param num_iteration number of used iteration
240
  */
241
  virtual void InitPredict(int num_iteration) = 0;
242

243
  /*!
Guolin Ke's avatar
Guolin Ke committed
244
  * \brief Name of submodel
245
  */
Guolin Ke's avatar
Guolin Ke committed
246
  virtual const char* SubModelName() const = 0;
247

Guolin Ke's avatar
Guolin Ke committed
248
249
250
251
252
253
  Boosting() = default;
  /*! \brief Disable copy */
  Boosting& operator=(const Boosting&) = delete;
  /*! \brief Disable copy */
  Boosting(const Boosting&) = delete;

254
  static bool LoadFileToBoosting(Boosting* boosting, const char* filename);
wxchan's avatar
wxchan committed
255

Guolin Ke's avatar
Guolin Ke committed
256
257
258
  /*!
  * \brief Create boosting object
  * \param type Type of boosting
259
260
  * \param config config for boosting
  * \param filename name of model file, if existing will continue to train from this model
Guolin Ke's avatar
Guolin Ke committed
261
262
  * \return The boosting object
  */
263
264
265
266
267
268
269
270
  static Boosting* CreateBoosting(const std::string& type, const char* filename);

  /*!
  * \brief Create boosting object from model file
  * \param filename name of model file
  * \return The boosting object
  */
  static Boosting* CreateBoosting(const char* filename);
271

Guolin Ke's avatar
Guolin Ke committed
272
273
};

Guolin Ke's avatar
Guolin Ke committed
274
275
276
277
278
279
/*!
* \brief Boosting interface extended with accessors for individual leaf values
*/
class GBDTBase : public Boosting {
public:
  /*!
  * \brief Read the value of one leaf
  * \param tree_idx Index of the tree (NOTE(review): indexing base and range not visible here - confirm)
  * \param leaf_idx Index of the leaf within that tree
  * \return The leaf's value
  */
  virtual double GetLeafValue(
    int tree_idx,
    int leaf_idx) const = 0;

  /*!
  * \brief Overwrite the value of one leaf
  * \param tree_idx Index of the tree (NOTE(review): indexing base and range not visible here - confirm)
  * \param leaf_idx Index of the leaf within that tree
  * \param val New value for the leaf
  */
  virtual void SetLeafValue(
    int tree_idx,
    int leaf_idx,
    double val) = 0;
};

Guolin Ke's avatar
Guolin Ke committed
280
281
}  // namespace LightGBM

Guolin Ke's avatar
Guolin Ke committed
282
#endif   // LIGHTGBM_BOOSTING_H_