boosting.h 6.17 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
#ifndef LIGHTGBM_BOOSTING_H_
#define LIGHTGBM_BOOSTING_H_

#include <LightGBM/meta.h>
#include <LightGBM/config.h>

#include <vector>
#include <string>

namespace LightGBM {

/*! \brief forward declaration */
class Dataset;
class ObjectiveFunction;
class Metric;

/*!
* \brief The interface for Boosting
*/
class Boosting {
public:
  /*! \brief virtual destructor */
  virtual ~Boosting() {}

  /*!
Qiwei Ye's avatar
Qiwei Ye committed
26
27
  * \brief Initialization logic
  * \param config Configs for boosting
Guolin Ke's avatar
Guolin Ke committed
28
29
30
31
  * \param train_data Training data
  * \param object_function Training objective function
  * \param training_metrics Training metric
  */
32
33
34
  virtual void Init(
    const BoostingConfig* config,
    const Dataset* train_data,
Guolin Ke's avatar
Guolin Ke committed
35
    const ObjectiveFunction* object_function,
36
    const std::vector<const Metric*>& training_metrics) = 0;
Guolin Ke's avatar
Guolin Ke committed
37

wxchan's avatar
wxchan committed
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
  /*!
  * \brief Merge model from other boosting object
           Will insert to the front of current boosting object
  * \param other
  */
  virtual void MergeFrom(const Boosting* other) = 0;

  /*!
  * \brief Reset training data for current boosting
  * \param config Configs for boosting
  * \param train_data Training data
  * \param object_function Training objective function
  * \param training_metrics Training metric
  */
  virtual void ResetTrainingData(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* object_function, const std::vector<const Metric*>& training_metrics) = 0;

  /*!
  * \brief Reset shrinkage_rate data for current boosting
  * \param shrinkage_rate Configs for boosting
  */
  virtual void ResetShrinkageRate(double shrinkage_rate) = 0;

Guolin Ke's avatar
Guolin Ke committed
60
61
62
63
64
  /*!
  * \brief Add a validation data
  * \param valid_data Validation data
  * \param valid_metrics Metric for validation data
  */
wxchan's avatar
wxchan committed
65
  virtual void AddValidDataset(const Dataset* valid_data,
Guolin Ke's avatar
Guolin Ke committed
66
67
    const std::vector<const Metric*>& valid_metrics) = 0;

Guolin Ke's avatar
Guolin Ke committed
68
69
70
71
  /*!
  * \brief Training logic
  * \param gradient nullptr for using default objective, otherwise use self-defined boosting
  * \param hessian nullptr for using default objective, otherwise use self-defined boosting
Guolin Ke's avatar
Guolin Ke committed
72
  * \param is_eval true if need evaluation or early stop
Guolin Ke's avatar
Guolin Ke committed
73
74
  * \return True if meet early stopping or cannot boosting
  */
75
76
  virtual bool TrainOneIter(const score_t* gradient, const score_t* hessian, bool is_eval) = 0;

wxchan's avatar
wxchan committed
77
78
79
80
81
82
83
84
85
86
87
88
89
  /*!
  * \brief Rollback one iteration
  */
  virtual void RollbackOneIter() = 0;

  /*!
  * \brief return current iteration
  */
  virtual int GetCurrentIteration() const = 0;

  /*!
  * \brief Eval metrics and check is met early stopping or not
  */
Guolin Ke's avatar
Guolin Ke committed
90
  virtual bool EvalAndCheckEarlyStopping() = 0;
Guolin Ke's avatar
Guolin Ke committed
91
92
93
94
95
  /*!
  * \brief Get evaluation result at data_idx data
  * \param data_idx 0: training data, 1: 1st validation data
  * \return evaluation result
  */
96
  virtual std::vector<double> GetEvalAt(int data_idx) const = 0;
97

Guolin Ke's avatar
Guolin Ke committed
98
99
  /*!
  * \brief Get current training score
Guolin Ke's avatar
Guolin Ke committed
100
  * \param out_len length of returned score
Guolin Ke's avatar
Guolin Ke committed
101
102
  * \return training score
  */
103
  virtual const score_t* GetTrainingScore(data_size_t* out_len) = 0;
Guolin Ke's avatar
Guolin Ke committed
104

Guolin Ke's avatar
Guolin Ke committed
105
106
107
108
109
110
  /*!
  * \brief Get prediction result at data_idx data
  * \param data_idx 0: training data, 1: 1st validation data
  * \param result used to store prediction result, should allocate memory before call this function
  * \param out_len lenght of returned score
  */
wxchan's avatar
wxchan committed
111
  virtual void GetPredictAt(int data_idx, score_t* result, data_size_t* out_len) = 0;
Guolin Ke's avatar
Guolin Ke committed
112

Guolin Ke's avatar
Guolin Ke committed
113
  /*!
Hui Xue's avatar
Hui Xue committed
114
  * \brief Prediction for one record, not sigmoid transform
Guolin Ke's avatar
Guolin Ke committed
115
116
117
  * \param feature_values Feature value on this record
  * \return Prediction result for this record
  */
118
  virtual std::vector<double> PredictRaw(const double* feature_values) const = 0;
Guolin Ke's avatar
Guolin Ke committed
119
120

  /*!
Qiwei Ye's avatar
Qiwei Ye committed
121
  * \brief Prediction for one record, sigmoid transformation will be used if needed
Guolin Ke's avatar
Guolin Ke committed
122
123
124
  * \param feature_values Feature value on this record
  * \return Prediction result for this record
  */
125
  virtual std::vector<double> Predict(const double* feature_values) const = 0;
wxchan's avatar
wxchan committed
126
127
128
129
130
131
  
  /*!
  * \brief Predtion for one record with leaf index
  * \param feature_values Feature value on this record
  * \return Predicted leaf index for this record
  */
132
  virtual std::vector<int> PredictLeafIndex(
133
    const double* feature_values) const = 0;
134

Guolin Ke's avatar
Guolin Ke committed
135
  /*!
wxchan's avatar
wxchan committed
136
137
138
139
140
141
142
143
144
145
  * \brief Dump model to json format string
  * \return Json format string of model
  */
  virtual std::string DumpModel() const = 0;

  /*!
  * \brief Save model to file
  * \param num_used_model Number of model that want to save, -1 means save all
  * \param is_finish Is training finished or not
  * \param filename Filename that want to save to
Guolin Ke's avatar
Guolin Ke committed
146
  */
wxchan's avatar
wxchan committed
147
  virtual void SaveModelToFile(int num_iterations, const char* filename) const = 0;
Guolin Ke's avatar
Guolin Ke committed
148
149
150
151
152

  /*!
  * \brief Restore from a serialized string
  * \param model_str The string of model
  */
Guolin Ke's avatar
Guolin Ke committed
153
  virtual void LoadModelFromString(const std::string& model_str) = 0;
Guolin Ke's avatar
Guolin Ke committed
154
155
156
157
158
159
160

  /*!
  * \brief Get max feature index of this model
  * \return Max feature index of this model
  */
  virtual int MaxFeatureIdx() const = 0;

Guolin Ke's avatar
Guolin Ke committed
161
162
163
164
165
166
  /*!
  * \brief Get index of label column
  * \return index of label column
  */
  virtual int LabelIdx() const = 0;

Guolin Ke's avatar
Guolin Ke committed
167
168
169
170
  /*!
  * \brief Get number of weak sub-models
  * \return Number of weak sub-models
  */
wxchan's avatar
wxchan committed
171
  virtual int NumberOfTotalModel() const = 0;
172
173
174
175
176
  
  /*!
  * \brief Get number of classes
  * \return Number of classes
  */
Guolin Ke's avatar
Guolin Ke committed
177
  virtual int NumberOfClasses() const = 0;
178
179
180
181

  /*!
  * \brief Set number of used model for prediction
  */
wxchan's avatar
wxchan committed
182
  virtual void SetNumIterationForPred(int num_iteration) = 0;
183
  
184
185
186
187
188
  /*!
  * \brief Get Type name of this boosting object
  */
  virtual const char* Name() const = 0;

Guolin Ke's avatar
Guolin Ke committed
189
190
191
192
193
194
  Boosting() = default;
  /*! \brief Disable copy */
  Boosting& operator=(const Boosting&) = delete;
  /*! \brief Disable copy */
  Boosting(const Boosting&) = delete;

wxchan's avatar
wxchan committed
195
196
  static void LoadFileToBoosting(Boosting* boosting, const char* filename);

Guolin Ke's avatar
Guolin Ke committed
197
198
199
  /*!
  * \brief Create boosting object
  * \param type Type of boosting
200
201
  * \param config config for boosting
  * \param filename name of model file, if existing will continue to train from this model
Guolin Ke's avatar
Guolin Ke committed
202
203
  * \return The boosting object
  */
204
205
206
207
208
209
210
211
212
  static Boosting* CreateBoosting(BoostingType type, const char* filename);

  /*!
  * \brief Create boosting object from model file
  * \param filename name of model file
  * \return The boosting object
  */
  static Boosting* CreateBoosting(const char* filename);

Guolin Ke's avatar
Guolin Ke committed
213
214
215
216
};

}  // namespace LightGBM

Guolin Ke's avatar
Guolin Ke committed
217
#endif   // LIGHTGBM_BOOSTING_H_