// boosting.h
// Contributors (from revision history): Guolin Ke, Qiwei Ye, Hui Xue, wxchan
#ifndef LIGHTGBM_BOOSTING_H_
#define LIGHTGBM_BOOSTING_H_

#include <LightGBM/meta.h>
#include <LightGBM/config.h>

#include <vector>
#include <string>

namespace LightGBM {

/*! \brief forward declaration */
class Dataset;
class ObjectiveFunction;
class Metric;

/*!
* \brief The interface for Boosting
*/
class Boosting {
public:
  /*! \brief virtual destructor */
  virtual ~Boosting() {}

  /*!
Qiwei Ye's avatar
Qiwei Ye committed
26
27
  * \brief Initialization logic
  * \param config Configs for boosting
Guolin Ke's avatar
Guolin Ke committed
28
29
30
31
  * \param train_data Training data
  * \param object_function Training objective function
  * \param training_metrics Training metric
  */
32
33
34
  virtual void Init(
    const BoostingConfig* config,
    const Dataset* train_data,
Guolin Ke's avatar
Guolin Ke committed
35
    const ObjectiveFunction* object_function,
36
    const std::vector<const Metric*>& training_metrics) = 0;
Guolin Ke's avatar
Guolin Ke committed
37

Guolin Ke's avatar
Guolin Ke committed
38
39
40
41
42
  /*!
  * \brief Merge model from other boosting object
  * \param other
  */
  virtual void MergeFrom(const Boosting* other) = 0;
43
44
45
46
47
48
  /*!
  * \brief Reset Config for current boosting
  * \param config Configs for boosting
  */
  virtual void ResetConfig(const BoostingConfig* config) = 0;

49
50
51
52
53
54
55
56
  /*!
  * \brief Reset training data for current boosting
  * \param train_data Training data
  * \param object_function Training objective function
  * \param training_metrics Training metric
  */
  virtual void ResetTrainingData(const Dataset* train_data, const ObjectiveFunction* object_function, const std::vector<const Metric*>& training_metrics) = 0;

Guolin Ke's avatar
Guolin Ke committed
57
58
59
60
61
  /*!
  * \brief Add a validation data
  * \param valid_data Validation data
  * \param valid_metrics Metric for validation data
  */
62
  virtual void AddValidDataset(const Dataset* valid_data,
Guolin Ke's avatar
Guolin Ke committed
63
64
    const std::vector<const Metric*>& valid_metrics) = 0;

Guolin Ke's avatar
Guolin Ke committed
65
66
67
68
  /*!
  * \brief Training logic
  * \param gradient nullptr for using default objective, otherwise use self-defined boosting
  * \param hessian nullptr for using default objective, otherwise use self-defined boosting
Guolin Ke's avatar
Guolin Ke committed
69
  * \param is_eval true if need evaluation or early stop
Guolin Ke's avatar
Guolin Ke committed
70
71
  * \return True if meet early stopping or cannot boosting
  */
72
73
  virtual bool TrainOneIter(const score_t* gradient, const score_t* hessian, bool is_eval) = 0;

74
75
76
77
78
79
80
81
82
83
84
85
86
  /*!
  * \brief Rollback one iteration
  */
  virtual void RollbackOneIter() = 0;

  /*!
  * \brief return current iteration
  */
  virtual int GetCurrentIteration() const = 0;

  /*!
  * \brief Eval metrics and check is met early stopping or not
  */
Guolin Ke's avatar
Guolin Ke committed
87
  virtual bool EvalAndCheckEarlyStopping() = 0;
Guolin Ke's avatar
Guolin Ke committed
88
89
90
91
92
  /*!
  * \brief Get evaluation result at data_idx data
  * \param data_idx 0: training data, 1: 1st validation data
  * \return evaluation result
  */
93
  virtual std::vector<double> GetEvalAt(int data_idx) const = 0;
94

Guolin Ke's avatar
Guolin Ke committed
95
96
  /*!
  * \brief Get current training score
Guolin Ke's avatar
Guolin Ke committed
97
  * \param out_len length of returned score
Guolin Ke's avatar
Guolin Ke committed
98
99
  * \return training score
  */
100
  virtual const score_t* GetTrainingScore(data_size_t* out_len) = 0;
Guolin Ke's avatar
Guolin Ke committed
101

Guolin Ke's avatar
Guolin Ke committed
102
103
104
105
106
107
  /*!
  * \brief Get prediction result at data_idx data
  * \param data_idx 0: training data, 1: 1st validation data
  * \param result used to store prediction result, should allocate memory before call this function
  * \param out_len lenght of returned score
  */
Guolin Ke's avatar
Guolin Ke committed
108
  virtual void GetPredictAt(int data_idx, score_t* result, data_size_t* out_len) = 0;
Guolin Ke's avatar
Guolin Ke committed
109

Guolin Ke's avatar
Guolin Ke committed
110
  /*!
Hui Xue's avatar
Hui Xue committed
111
  * \brief Prediction for one record, not sigmoid transform
Guolin Ke's avatar
Guolin Ke committed
112
113
114
  * \param feature_values Feature value on this record
  * \return Prediction result for this record
  */
115
  virtual std::vector<double> PredictRaw(const double* feature_values) const = 0;
Guolin Ke's avatar
Guolin Ke committed
116
117

  /*!
Qiwei Ye's avatar
Qiwei Ye committed
118
  * \brief Prediction for one record, sigmoid transformation will be used if needed
Guolin Ke's avatar
Guolin Ke committed
119
120
121
  * \param feature_values Feature value on this record
  * \return Prediction result for this record
  */
122
  virtual std::vector<double> Predict(const double* feature_values) const = 0;
wxchan's avatar
wxchan committed
123
124
125
126
127
128
  
  /*!
  * \brief Predtion for one record with leaf index
  * \param feature_values Feature value on this record
  * \return Predicted leaf index for this record
  */
129
  virtual std::vector<int> PredictLeafIndex(
130
    const double* feature_values) const = 0;
131

Guolin Ke's avatar
Guolin Ke committed
132
  /*!
133
  * \brief save model to file
Guolin Ke's avatar
Guolin Ke committed
134
135
136
  * \param num_used_model number of model that want to save, -1 means save all
  * \param is_finish is training finished or not
  * \param filename filename that want to save to
Guolin Ke's avatar
Guolin Ke committed
137
  */
Guolin Ke's avatar
Guolin Ke committed
138
  virtual void SaveModelToFile(int num_used_model, bool is_finish, const char* filename) = 0;
Guolin Ke's avatar
Guolin Ke committed
139
140
141
142
143

  /*!
  * \brief Restore from a serialized string
  * \param model_str The string of model
  */
Guolin Ke's avatar
Guolin Ke committed
144
  virtual void LoadModelFromString(const std::string& model_str) = 0;
Guolin Ke's avatar
Guolin Ke committed
145
146
147
148
149
150
151

  /*!
  * \brief Get max feature index of this model
  * \return Max feature index of this model
  */
  virtual int MaxFeatureIdx() const = 0;

Guolin Ke's avatar
Guolin Ke committed
152
153
154
155
156
157
  /*!
  * \brief Get index of label column
  * \return index of label column
  */
  virtual int LabelIdx() const = 0;

Guolin Ke's avatar
Guolin Ke committed
158
159
160
161
  /*!
  * \brief Get number of weak sub-models
  * \return Number of weak sub-models
  */
Guolin Ke's avatar
Guolin Ke committed
162
  virtual int NumberOfTotalModel() const = 0;
163
164
165
166
167
  
  /*!
  * \brief Get number of classes
  * \return Number of classes
  */
Guolin Ke's avatar
Guolin Ke committed
168
  virtual int NumberOfClasses() const = 0;
169
170
171
172

  /*!
  * \brief Set number of used model for prediction
  */
Guolin Ke's avatar
Guolin Ke committed
173
  virtual void SetNumIterationForPred(int num_iteration) = 0;
174
  
175
176
177
178
179
  /*!
  * \brief Get Type name of this boosting object
  */
  virtual const char* Name() const = 0;

Guolin Ke's avatar
Guolin Ke committed
180
181
182
183
184
185
  Boosting() = default;
  /*! \brief Disable copy */
  Boosting& operator=(const Boosting&) = delete;
  /*! \brief Disable copy */
  Boosting(const Boosting&) = delete;

186
  static void LoadFileToBoosting(Boosting* boosting, const char* filename);
Guolin Ke's avatar
Guolin Ke committed
187

Guolin Ke's avatar
Guolin Ke committed
188
189
190
  /*!
  * \brief Create boosting object
  * \param type Type of boosting
191
192
  * \param config config for boosting
  * \param filename name of model file, if existing will continue to train from this model
Guolin Ke's avatar
Guolin Ke committed
193
194
  * \return The boosting object
  */
195
196
197
198
199
200
201
202
203
  static Boosting* CreateBoosting(BoostingType type, const char* filename);

  /*!
  * \brief Create boosting object from model file
  * \param filename name of model file
  * \return The boosting object
  */
  static Boosting* CreateBoosting(const char* filename);

Guolin Ke's avatar
Guolin Ke committed
204
205
206
207
};

}  // namespace LightGBM

#endif   // LIGHTGBM_BOOSTING_H_