boosting.h 5.87 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
#ifndef LIGHTGBM_BOOSTING_H_
#define LIGHTGBM_BOOSTING_H_

#include <LightGBM/meta.h>
#include <LightGBM/config.h>

#include <vector>
#include <string>

namespace LightGBM {

/*! \brief forward declaration */
class Dataset;
class ObjectiveFunction;
class Metric;

/*!
* \brief The interface for Boosting
*/
class Boosting {
public:
  /*! \brief virtual destructor */
  virtual ~Boosting() {}

  /*!
Qiwei Ye's avatar
Qiwei Ye committed
26
27
  * \brief Initialization logic
  * \param config Configs for boosting
Guolin Ke's avatar
Guolin Ke committed
28
29
30
31
  * \param train_data Training data
  * \param object_function Training objective function
  * \param training_metrics Training metric
  */
32
33
34
  virtual void Init(
    const BoostingConfig* config,
    const Dataset* train_data,
Guolin Ke's avatar
Guolin Ke committed
35
    const ObjectiveFunction* object_function,
36
    const std::vector<const Metric*>& training_metrics) = 0;
Guolin Ke's avatar
Guolin Ke committed
37

Guolin Ke's avatar
Guolin Ke committed
38
39
  /*!
  * \brief Merge model from other boosting object
Guolin Ke's avatar
Guolin Ke committed
40
           Will insert to the front of current boosting object
Guolin Ke's avatar
Guolin Ke committed
41
42
43
  * \param other
  */
  virtual void MergeFrom(const Boosting* other) = 0;
44

45
46
  /*!
  * \brief Reset training data for current boosting
47
  * \param config Configs for boosting
48
49
50
51
  * \param train_data Training data
  * \param object_function Training objective function
  * \param training_metrics Training metric
  */
52
  virtual void ResetTrainingData(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* object_function, const std::vector<const Metric*>& training_metrics) = 0;
53

Guolin Ke's avatar
Guolin Ke committed
54
55
56
57
58
  /*!
  * \brief Add a validation data
  * \param valid_data Validation data
  * \param valid_metrics Metric for validation data
  */
59
  virtual void AddValidDataset(const Dataset* valid_data,
Guolin Ke's avatar
Guolin Ke committed
60
61
    const std::vector<const Metric*>& valid_metrics) = 0;

Guolin Ke's avatar
Guolin Ke committed
62
63
64
65
  /*!
  * \brief Training logic
  * \param gradient nullptr for using default objective, otherwise use self-defined boosting
  * \param hessian nullptr for using default objective, otherwise use self-defined boosting
Guolin Ke's avatar
Guolin Ke committed
66
  * \param is_eval true if need evaluation or early stop
Guolin Ke's avatar
Guolin Ke committed
67
68
  * \return True if meet early stopping or cannot boosting
  */
69
70
  virtual bool TrainOneIter(const score_t* gradient, const score_t* hessian, bool is_eval) = 0;

71
72
73
74
75
76
77
78
79
80
81
82
83
  /*!
  * \brief Rollback one iteration
  */
  virtual void RollbackOneIter() = 0;

  /*!
  * \brief return current iteration
  */
  virtual int GetCurrentIteration() const = 0;

  /*!
  * \brief Eval metrics and check is met early stopping or not
  */
Guolin Ke's avatar
Guolin Ke committed
84
  virtual bool EvalAndCheckEarlyStopping() = 0;
Guolin Ke's avatar
Guolin Ke committed
85
86
87
88
89
  /*!
  * \brief Get evaluation result at data_idx data
  * \param data_idx 0: training data, 1: 1st validation data
  * \return evaluation result
  */
90
  virtual std::vector<double> GetEvalAt(int data_idx) const = 0;
91

Guolin Ke's avatar
Guolin Ke committed
92
93
  /*!
  * \brief Get current training score
Guolin Ke's avatar
Guolin Ke committed
94
  * \param out_len length of returned score
Guolin Ke's avatar
Guolin Ke committed
95
96
  * \return training score
  */
97
  virtual const score_t* GetTrainingScore(data_size_t* out_len) = 0;
Guolin Ke's avatar
Guolin Ke committed
98

Guolin Ke's avatar
Guolin Ke committed
99
100
101
102
103
104
  /*!
  * \brief Get prediction result at data_idx data
  * \param data_idx 0: training data, 1: 1st validation data
  * \param result used to store prediction result, should allocate memory before call this function
  * \param out_len lenght of returned score
  */
Guolin Ke's avatar
Guolin Ke committed
105
  virtual void GetPredictAt(int data_idx, score_t* result, data_size_t* out_len) = 0;
Guolin Ke's avatar
Guolin Ke committed
106

Guolin Ke's avatar
Guolin Ke committed
107
  /*!
Hui Xue's avatar
Hui Xue committed
108
  * \brief Prediction for one record, not sigmoid transform
Guolin Ke's avatar
Guolin Ke committed
109
110
111
  * \param feature_values Feature value on this record
  * \return Prediction result for this record
  */
112
  virtual std::vector<double> PredictRaw(const double* feature_values) const = 0;
Guolin Ke's avatar
Guolin Ke committed
113
114

  /*!
Qiwei Ye's avatar
Qiwei Ye committed
115
  * \brief Prediction for one record, sigmoid transformation will be used if needed
Guolin Ke's avatar
Guolin Ke committed
116
117
118
  * \param feature_values Feature value on this record
  * \return Prediction result for this record
  */
119
  virtual std::vector<double> Predict(const double* feature_values) const = 0;
wxchan's avatar
wxchan committed
120
121
122
123
124
125
  
  /*!
  * \brief Predtion for one record with leaf index
  * \param feature_values Feature value on this record
  * \return Predicted leaf index for this record
  */
126
  virtual std::vector<int> PredictLeafIndex(
127
    const double* feature_values) const = 0;
128

Guolin Ke's avatar
Guolin Ke committed
129
  /*!
130
  * \brief save model to file
Guolin Ke's avatar
Guolin Ke committed
131
132
133
  * \param num_used_model number of model that want to save, -1 means save all
  * \param is_finish is training finished or not
  * \param filename filename that want to save to
Guolin Ke's avatar
Guolin Ke committed
134
  */
Guolin Ke's avatar
Guolin Ke committed
135
  virtual void SaveModelToFile(int num_used_model, bool is_finish, const char* filename) = 0;
Guolin Ke's avatar
Guolin Ke committed
136
137
138
139
140

  /*!
  * \brief Restore from a serialized string
  * \param model_str The string of model
  */
Guolin Ke's avatar
Guolin Ke committed
141
  virtual void LoadModelFromString(const std::string& model_str) = 0;
Guolin Ke's avatar
Guolin Ke committed
142
143
144
145
146
147
148

  /*!
  * \brief Get max feature index of this model
  * \return Max feature index of this model
  */
  virtual int MaxFeatureIdx() const = 0;

Guolin Ke's avatar
Guolin Ke committed
149
150
151
152
153
154
  /*!
  * \brief Get index of label column
  * \return index of label column
  */
  virtual int LabelIdx() const = 0;

Guolin Ke's avatar
Guolin Ke committed
155
156
157
158
  /*!
  * \brief Get number of weak sub-models
  * \return Number of weak sub-models
  */
Guolin Ke's avatar
Guolin Ke committed
159
  virtual int NumberOfTotalModel() const = 0;
160
161
162
163
164
  
  /*!
  * \brief Get number of classes
  * \return Number of classes
  */
Guolin Ke's avatar
Guolin Ke committed
165
  virtual int NumberOfClasses() const = 0;
166
167
168
169

  /*!
  * \brief Set number of used model for prediction
  */
Guolin Ke's avatar
Guolin Ke committed
170
  virtual void SetNumIterationForPred(int num_iteration) = 0;
171
  
172
173
174
175
176
  /*!
  * \brief Get Type name of this boosting object
  */
  virtual const char* Name() const = 0;

Guolin Ke's avatar
Guolin Ke committed
177
178
179
180
181
182
  Boosting() = default;
  /*! \brief Disable copy */
  Boosting& operator=(const Boosting&) = delete;
  /*! \brief Disable copy */
  Boosting(const Boosting&) = delete;

183
  static void LoadFileToBoosting(Boosting* boosting, const char* filename);
Guolin Ke's avatar
Guolin Ke committed
184

Guolin Ke's avatar
Guolin Ke committed
185
186
187
  /*!
  * \brief Create boosting object
  * \param type Type of boosting
188
189
  * \param config config for boosting
  * \param filename name of model file, if existing will continue to train from this model
Guolin Ke's avatar
Guolin Ke committed
190
191
  * \return The boosting object
  */
192
193
194
195
196
197
198
199
200
  static Boosting* CreateBoosting(BoostingType type, const char* filename);

  /*!
  * \brief Create boosting object from model file
  * \param filename name of model file
  * \return The boosting object
  */
  static Boosting* CreateBoosting(const char* filename);

Guolin Ke's avatar
Guolin Ke committed
201
202
203
204
};

}  // namespace LightGBM

Guolin Ke's avatar
Guolin Ke committed
205
#endif   // LIGHTGBM_BOOSTING_H_