boosting.h 5.98 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
#ifndef LIGHTGBM_BOOSTING_H_
#define LIGHTGBM_BOOSTING_H_

#include <LightGBM/meta.h>
#include <LightGBM/config.h>

#include <vector>
#include <string>

namespace LightGBM {

/*! \brief forward declaration */
class Dataset;
class ObjectiveFunction;
class Metric;

/*!
* \brief The interface for Boosting
*/
class Boosting {
public:
  /*! \brief virtual destructor */
  virtual ~Boosting() {}

  /*!
Qiwei Ye's avatar
Qiwei Ye committed
26
27
  * \brief Initialization logic
  * \param config Configs for boosting
Guolin Ke's avatar
Guolin Ke committed
28
29
30
31
  * \param train_data Training data
  * \param object_function Training objective function
  * \param training_metrics Training metric
  */
32
33
34
  virtual void Init(
    const BoostingConfig* config,
    const Dataset* train_data,
Guolin Ke's avatar
Guolin Ke committed
35
    const ObjectiveFunction* object_function,
36
    const std::vector<const Metric*>& training_metrics) = 0;
Guolin Ke's avatar
Guolin Ke committed
37

Guolin Ke's avatar
Guolin Ke committed
38
39
  /*!
  * \brief Merge model from other boosting object
Guolin Ke's avatar
Guolin Ke committed
40
           Will insert to the front of current boosting object
Guolin Ke's avatar
Guolin Ke committed
41
42
43
  * \param other
  */
  virtual void MergeFrom(const Boosting* other) = 0;
44

45
46
  /*!
  * \brief Reset training data for current boosting
47
  * \param config Configs for boosting
48
49
50
51
  * \param train_data Training data
  * \param object_function Training objective function
  * \param training_metrics Training metric
  */
52
  virtual void ResetTrainingData(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* object_function, const std::vector<const Metric*>& training_metrics) = 0;
53

54
55
56
57
58
59
  /*!
  * \brief Reset shrinkage_rate data for current boosting
  * \param shrinkage_rate Configs for boosting
  */
  virtual void ResetShrinkageRate(double shrinkage_rate) = 0;

Guolin Ke's avatar
Guolin Ke committed
60
61
62
63
64
  /*!
  * \brief Add a validation data
  * \param valid_data Validation data
  * \param valid_metrics Metric for validation data
  */
65
  virtual void AddValidDataset(const Dataset* valid_data,
Guolin Ke's avatar
Guolin Ke committed
66
67
    const std::vector<const Metric*>& valid_metrics) = 0;

Guolin Ke's avatar
Guolin Ke committed
68
69
70
71
  /*!
  * \brief Training logic
  * \param gradient nullptr for using default objective, otherwise use self-defined boosting
  * \param hessian nullptr for using default objective, otherwise use self-defined boosting
Guolin Ke's avatar
Guolin Ke committed
72
  * \param is_eval true if need evaluation or early stop
Guolin Ke's avatar
Guolin Ke committed
73
74
  * \return True if meet early stopping or cannot boosting
  */
75
76
  virtual bool TrainOneIter(const score_t* gradient, const score_t* hessian, bool is_eval) = 0;

77
78
79
80
81
82
83
84
85
86
87
88
89
  /*!
  * \brief Rollback one iteration
  */
  virtual void RollbackOneIter() = 0;

  /*!
  * \brief return current iteration
  */
  virtual int GetCurrentIteration() const = 0;

  /*!
  * \brief Eval metrics and check is met early stopping or not
  */
Guolin Ke's avatar
Guolin Ke committed
90
  virtual bool EvalAndCheckEarlyStopping() = 0;
Guolin Ke's avatar
Guolin Ke committed
91
92
93
94
95
  /*!
  * \brief Get evaluation result at data_idx data
  * \param data_idx 0: training data, 1: 1st validation data
  * \return evaluation result
  */
96
  virtual std::vector<double> GetEvalAt(int data_idx) const = 0;
97

Guolin Ke's avatar
Guolin Ke committed
98
99
  /*!
  * \brief Get current training score
Guolin Ke's avatar
Guolin Ke committed
100
  * \param out_len length of returned score
Guolin Ke's avatar
Guolin Ke committed
101
102
  * \return training score
  */
103
  virtual const score_t* GetTrainingScore(data_size_t* out_len) = 0;
Guolin Ke's avatar
Guolin Ke committed
104

Guolin Ke's avatar
Guolin Ke committed
105
106
107
108
109
110
  /*!
  * \brief Get prediction result at data_idx data
  * \param data_idx 0: training data, 1: 1st validation data
  * \param result used to store prediction result, should allocate memory before call this function
  * \param out_len lenght of returned score
  */
Guolin Ke's avatar
Guolin Ke committed
111
  virtual void GetPredictAt(int data_idx, score_t* result, data_size_t* out_len) = 0;
Guolin Ke's avatar
Guolin Ke committed
112

Guolin Ke's avatar
Guolin Ke committed
113
  /*!
Hui Xue's avatar
Hui Xue committed
114
  * \brief Prediction for one record, not sigmoid transform
Guolin Ke's avatar
Guolin Ke committed
115
116
117
  * \param feature_values Feature value on this record
  * \return Prediction result for this record
  */
118
  virtual std::vector<double> PredictRaw(const double* feature_values) const = 0;
Guolin Ke's avatar
Guolin Ke committed
119
120

  /*!
Qiwei Ye's avatar
Qiwei Ye committed
121
  * \brief Prediction for one record, sigmoid transformation will be used if needed
Guolin Ke's avatar
Guolin Ke committed
122
123
124
  * \param feature_values Feature value on this record
  * \return Prediction result for this record
  */
125
  virtual std::vector<double> Predict(const double* feature_values) const = 0;
wxchan's avatar
wxchan committed
126
127
128
129
130
131
  
  /*!
  * \brief Predtion for one record with leaf index
  * \param feature_values Feature value on this record
  * \return Predicted leaf index for this record
  */
132
  virtual std::vector<int> PredictLeafIndex(
133
    const double* feature_values) const = 0;
134

Guolin Ke's avatar
Guolin Ke committed
135
  /*!
136
  * \brief save model to file
Guolin Ke's avatar
Guolin Ke committed
137
  * \param num_iterations Iterations that want to save, -1 means save all
Guolin Ke's avatar
Guolin Ke committed
138
  * \param filename filename that want to save to
Guolin Ke's avatar
Guolin Ke committed
139
  */
Guolin Ke's avatar
Guolin Ke committed
140
  virtual void SaveModelToFile(int num_iterations, const char* filename) const = 0;
Guolin Ke's avatar
Guolin Ke committed
141
142
143
144
145

  /*!
  * \brief Restore from a serialized string
  * \param model_str The string of model
  */
Guolin Ke's avatar
Guolin Ke committed
146
  virtual void LoadModelFromString(const std::string& model_str) = 0;
Guolin Ke's avatar
Guolin Ke committed
147
148
149
150
151
152
153

  /*!
  * \brief Get max feature index of this model
  * \return Max feature index of this model
  */
  virtual int MaxFeatureIdx() const = 0;

Guolin Ke's avatar
Guolin Ke committed
154
155
156
157
158
159
  /*!
  * \brief Get index of label column
  * \return index of label column
  */
  virtual int LabelIdx() const = 0;

Guolin Ke's avatar
Guolin Ke committed
160
161
162
163
  /*!
  * \brief Get number of weak sub-models
  * \return Number of weak sub-models
  */
Guolin Ke's avatar
Guolin Ke committed
164
  virtual int NumberOfTotalModel() const = 0;
165
166
167
168
169
  
  /*!
  * \brief Get number of classes
  * \return Number of classes
  */
Guolin Ke's avatar
Guolin Ke committed
170
  virtual int NumberOfClasses() const = 0;
171
172
173
174

  /*!
  * \brief Set number of used model for prediction
  */
Guolin Ke's avatar
Guolin Ke committed
175
  virtual void SetNumIterationForPred(int num_iteration) = 0;
176
  
177
178
179
180
181
  /*!
  * \brief Get Type name of this boosting object
  */
  virtual const char* Name() const = 0;

Guolin Ke's avatar
Guolin Ke committed
182
183
184
185
186
187
  Boosting() = default;
  /*! \brief Disable copy */
  Boosting& operator=(const Boosting&) = delete;
  /*! \brief Disable copy */
  Boosting(const Boosting&) = delete;

188
  static void LoadFileToBoosting(Boosting* boosting, const char* filename);
Guolin Ke's avatar
Guolin Ke committed
189

Guolin Ke's avatar
Guolin Ke committed
190
191
192
  /*!
  * \brief Create boosting object
  * \param type Type of boosting
193
194
  * \param config config for boosting
  * \param filename name of model file, if existing will continue to train from this model
Guolin Ke's avatar
Guolin Ke committed
195
196
  * \return The boosting object
  */
197
198
199
200
201
202
203
204
205
  static Boosting* CreateBoosting(BoostingType type, const char* filename);

  /*!
  * \brief Create boosting object from model file
  * \param filename name of model file
  * \return The boosting object
  */
  static Boosting* CreateBoosting(const char* filename);

Guolin Ke's avatar
Guolin Ke committed
206
207
208
209
};

}  // namespace LightGBM

#endif   // LIGHTGBM_BOOSTING_H_