boosting.h 8.54 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
#ifndef LIGHTGBM_BOOSTING_H_
#define LIGHTGBM_BOOSTING_H_

#include <LightGBM/meta.h>
#include <LightGBM/config.h>

#include <vector>
#include <string>

namespace LightGBM {

/*! \brief forward declaration */
class Dataset;
class ObjectiveFunction;
class Metric;
struct PredictionEarlyStopInstance;

/*!
* \brief The interface for Boosting
* \note The declaration order of the virtual functions below defines the
*       vtable layout; keep it stable to preserve binary compatibility.
*/
class LIGHTGBM_EXPORT Boosting {
public:
  /*! \brief virtual destructor */
  virtual ~Boosting() = default;

  /*!
  * \brief Initialization logic
  * \param config Configs for boosting
  * \param train_data Training data
  * \param objective_function Training objective function
  * \param training_metrics Training metrics
  */
  virtual void Init(
    const BoostingConfig* config,
    const Dataset* train_data,
    const ObjectiveFunction* objective_function,
    const std::vector<const Metric*>& training_metrics) = 0;

  /*!
  * \brief Merge model from other boosting object
  *        Will insert to the front of current boosting object
  * \param other Boosting object whose model is merged into this one
  */
  virtual void MergeFrom(const Boosting* other) = 0;

  /*!
  * \brief Reset the training data, objective function and training metrics
  * \param train_data New training data
  * \param objective_function New training objective function
  * \param training_metrics New training metrics
  */
  virtual void ResetTrainingData(const Dataset* train_data, const ObjectiveFunction* objective_function,
                                 const std::vector<const Metric*>& training_metrics) = 0;

  /*!
  * \brief Reset the configs for boosting
  * \param config New configs for boosting
  */
  virtual void ResetConfig(const BoostingConfig* config) = 0;

  /*!
  * \brief Add a validation data
  * \param valid_data Validation data
  * \param valid_metrics Metric for validation data
  */
  virtual void AddValidDataset(const Dataset* valid_data,
    const std::vector<const Metric*>& valid_metrics) = 0;

  /*!
  * \brief Train the model
  *        (presumably runs all configured iterations — confirm in implementation)
  * \param snapshot_freq Snapshot frequency in iterations
  *        (handling of non-positive values is implementation-defined — TODO confirm)
  * \param model_output_path Output path used when saving the model / snapshots
  */
  virtual void Train(int snapshot_freq, const std::string& model_output_path) = 0;

  /*!
  * \brief Training logic
  * \param gradient nullptr for using default objective, otherwise use self-defined boosting
  * \param hessian nullptr for using default objective, otherwise use self-defined boosting
  * \param is_eval true if need evaluation or early stop
  * \return True if early stopping is met or boosting cannot continue
  */
  virtual bool TrainOneIter(const score_t* gradient, const score_t* hessian, bool is_eval) = 0;

  /*!
  * \brief Rollback one iteration
  */
  virtual void RollbackOneIter() = 0;

  /*!
  * \brief return current iteration
  */
  virtual int GetCurrentIteration() const = 0;

  /*!
  * \brief Eval metrics and check whether early stopping is met
  * \return Whether early stopping is met (presumably true means stop — confirm)
  */
  virtual bool EvalAndCheckEarlyStopping() = 0;

  /*!
  * \brief Get evaluation result at data_idx data
  * \param data_idx 0: training data, 1: 1st validation data
  * \return evaluation result
  */
  virtual std::vector<double> GetEvalAt(int data_idx) const = 0;

  /*!
  * \brief Get current training score
  * \param out_len length of returned score
  * \return training score
  */
  virtual const double* GetTrainingScore(int64_t* out_len) = 0;

  /*!
  * \brief Get the number of prediction values at data_idx data
  * \param data_idx 0: training data, 1: 1st validation data
  * \return Number of prediction values
  *         (presumably the buffer length GetPredictAt needs — confirm)
  */
  virtual int64_t GetNumPredictAt(int data_idx) const = 0;

  /*!
  * \brief Get prediction result at data_idx data
  * \param data_idx 0: training data, 1: 1st validation data
  * \param result Buffer that receives the prediction result; the caller must
  *               allocate it before calling this function
  * \param out_len length of returned score
  */
  virtual void GetPredictAt(int data_idx, double* result, int64_t* out_len) = 0;

  /*!
  * \brief Get the number of prediction outputs for one row
  *        (presumably the required output buffer size for single-row prediction — confirm)
  * \param num_iteration Number of iterations used for prediction
  * \param is_pred_leaf Whether leaf indices are predicted
  * \param is_pred_contrib Whether feature contributions are predicted
  * \return Number of outputs for one row
  */
  virtual int NumPredictOneRow(int num_iteration, bool is_pred_leaf, bool is_pred_contrib) const = 0;

  /*!
  * \brief Prediction for one record, without sigmoid transformation
  * \param features Feature values of this record
  * \param output Prediction result for this record
  * \param early_stop Early stopping instance. If nullptr, no early stopping is applied and all trees are evaluated.
  */
  virtual void PredictRaw(const double* features, double* output,
                          const PredictionEarlyStopInstance* early_stop) const = 0;

  /*!
  * \brief Prediction for one record, sigmoid transformation will be used if needed
  * \param features Feature values of this record
  * \param output Prediction result for this record
  * \param early_stop Early stopping instance. If nullptr, no early stopping is applied and all trees are evaluated.
  */
  virtual void Predict(const double* features, double* output,
                       const PredictionEarlyStopInstance* early_stop) const = 0;

  /*!
  * \brief Prediction for one record with leaf index
  * \param features Feature values of this record
  * \param output Leaf index prediction result for this record
  */
  virtual void PredictLeafIndex(
    const double* features, double* output) const = 0;

  /*!
  * \brief Feature contributions for the model's prediction of one record
  * \param features Feature values of this record
  * \param output Feature contribution result for this record
  * \param early_stop Early stopping instance. If nullptr, no early stopping is applied and all trees are evaluated.
  */
  virtual void PredictContrib(const double* features, double* output,
                              const PredictionEarlyStopInstance* early_stop) const = 0;

  /*!
  * \brief Dump model to json format string
  * \param num_iteration Number of iterations that want to dump, -1 means dump all
  * \return Json format string of model
  */
  virtual std::string DumpModel(int num_iteration) const = 0;

  /*!
  * \brief Translate model to if-else statement
  * \param num_iteration Number of iterations that want to translate, -1 means translate all
  * \return if-else format codes of model
  */
  virtual std::string ModelToIfElse(int num_iteration) const = 0;

  /*!
  * \brief Translate model to if-else statement and save it to a file
  * \param num_iteration Number of iterations that want to translate, -1 means translate all
  * \param filename Filename that want to save to
  * \return true if succeeded (presumably, as for SaveModelToFile — confirm)
  */
  virtual bool SaveModelToIfElse(int num_iteration, const char* filename) const = 0;

  /*!
  * \brief Save model to file
  * \param num_iterations Number of iterations that want to save, -1 means save all
  * \param filename Filename that want to save to
  * \return true if succeeded
  */
  virtual bool SaveModelToFile(int num_iterations, const char* filename) const = 0;

  /*!
  * \brief Save model to string
  * \param num_iterations Number of iterations that want to save, -1 means save all
  * \return Non-empty string if succeeded
  */
  virtual std::string SaveModelToString(int num_iterations) const = 0;

  /*!
  * \brief Restore from a serialized string
  * \param model_str The string of model
  * \return true if succeeded
  */
  virtual bool LoadModelFromString(const std::string& model_str) = 0;

  /*!
  * \brief Get max feature index of this model
  * \return Max feature index of this model
  */
  virtual int MaxFeatureIdx() const = 0;

  /*!
  * \brief Get feature names of this model
  * \return Feature names of this model
  */
  virtual std::vector<std::string> FeatureNames() const = 0;

  /*!
  * \brief Get index of label column
  * \return index of label column
  */
  virtual int LabelIdx() const = 0;

  /*!
  * \brief Get number of weak sub-models
  * \return Number of weak sub-models
  */
  virtual int NumberOfTotalModel() const = 0;

  /*!
  * \brief Get number of trees per iteration
  * \return Number of trees per iteration
  */
  virtual int NumTreePerIteration() const = 0;

  /*!
  * \brief Get number of classes
  * \return Number of classes
  */
  virtual int NumberOfClasses() const = 0;

  /*! \brief The prediction should be accurate or not. True will disable early stopping for prediction. */
  virtual bool NeedAccuratePrediction() const = 0;

  /*!
  * \brief Initial work for the prediction
  * \param num_iteration number of used iteration
  */
  virtual void InitPredict(int num_iteration) = 0;

  /*!
  * \brief Name of submodel
  */
  virtual const char* SubModelName() const = 0;

  /*! \brief Default constructor */
  Boosting() = default;
  /*! \brief Disable copy */
  Boosting& operator=(const Boosting&) = delete;
  /*! \brief Disable copy */
  Boosting(const Boosting&) = delete;

  /*!
  * \brief Load a model file into an existing boosting object
  * \param boosting Boosting object that receives the model
  * \param filename Name of model file
  * \return true if succeeded (presumably — confirm in implementation)
  */
  static bool LoadFileToBoosting(Boosting* boosting, const char* filename);

  /*!
  * \brief Create boosting object
  * \param type Type of boosting
  * \param filename name of model file, if existing will continue to train from this model
  * \return The boosting object
  */
  static Boosting* CreateBoosting(const std::string& type, const char* filename);

  /*!
  * \brief Create boosting object from model file
  * \param filename name of model file
  * \return The boosting object
  */
  static Boosting* CreateBoosting(const char* filename);
};

}  // namespace LightGBM

#endif   // LIGHTGBM_BOOSTING_H_