boosting.h 3.66 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
#ifndef LIGHTGBM_BOOSTING_H_
#define LIGHTGBM_BOOSTING_H_

#include <LightGBM/meta.h>
#include <LightGBM/config.h>

#include <vector>
#include <string>

namespace LightGBM {

/*! \brief Forward declarations of types the Boosting interface only uses through pointers/references */
class Dataset;
class Metric;
class ObjectiveFunction;

/*!
* \brief The interface for Boosting
*/
class Boosting {
public:
  /*! \brief virtual destructor */
  virtual ~Boosting() {}

  /*!
Qiwei Ye's avatar
Qiwei Ye committed
26
27
  * \brief Initialization logic
  * \param config Configs for boosting
Guolin Ke's avatar
Guolin Ke committed
28
29
30
31
  * \param train_data Training data
  * \param object_function Training objective function
  * \param training_metrics Training metric
  */
32
33
34
  virtual void Init(
    const BoostingConfig* config,
    const Dataset* train_data,
Guolin Ke's avatar
Guolin Ke committed
35
    const ObjectiveFunction* object_function,
36
    const std::vector<const Metric*>& training_metrics) = 0;
Guolin Ke's avatar
Guolin Ke committed
37
38
39
40
41
42
43
44
45
46

  /*!
  * \brief Add a validation data
  * \param valid_data Validation data
  * \param valid_metrics Metric for validation data
  */
  virtual void AddDataset(const Dataset* valid_data,
    const std::vector<const Metric*>& valid_metrics) = 0;

  /*! \brief Training logic */
47
48
49
50
51
52
53
  virtual bool TrainOneIter(const score_t* gradient, const score_t* hessian, bool is_eval) = 0;

  /*! \brief Get eval result */
  virtual std::vector<std::string> EvalCurrent(bool is_eval_train) const = 0 ;

  /*! \brief Get prediction result */
  virtual const std::vector<const score_t*> PredictCurrent(bool is_predict_train) const = 0;
Guolin Ke's avatar
Guolin Ke committed
54
55

  /*!
Hui Xue's avatar
Hui Xue committed
56
  * \brief Prediction for one record, not sigmoid transform
Guolin Ke's avatar
Guolin Ke committed
57
  * \param feature_values Feature value on this record
58
  * \param num_used_model Number of used model
Guolin Ke's avatar
Guolin Ke committed
59
60
  * \return Prediction result for this record
  */
61
62
  virtual float PredictRaw(const float* feature_values,
    int num_used_model) const = 0;
Guolin Ke's avatar
Guolin Ke committed
63
64

  /*!
Qiwei Ye's avatar
Qiwei Ye committed
65
  * \brief Prediction for one record, sigmoid transformation will be used if needed
Guolin Ke's avatar
Guolin Ke committed
66
  * \param feature_values Feature value on this record
67
  * \param num_used_model Number of used model
Guolin Ke's avatar
Guolin Ke committed
68
69
  * \return Prediction result for this record
  */
70
71
  virtual float Predict(const float* feature_values, 
    int num_used_model) const = 0;
wxchan's avatar
wxchan committed
72
73
74
75
  
  /*!
  * \brief Predtion for one record with leaf index
  * \param feature_values Feature value on this record
76
  * \param num_used_model Number of used model
wxchan's avatar
wxchan committed
77
78
  * \return Predicted leaf index for this record
  */
79
80
81
  virtual std::vector<int> PredictLeafIndex(
    const float* feature_values,
    int num_used_model) const = 0;
wxchan's avatar
wxchan committed
82
  
Guolin Ke's avatar
Guolin Ke committed
83
  /*!
84
  * \brief save model to file
Guolin Ke's avatar
Guolin Ke committed
85
  */
86
  virtual void SaveModelToFile(bool is_finish, const char* filename) = 0;
Guolin Ke's avatar
Guolin Ke committed
87
88
89
90
91

  /*!
  * \brief Restore from a serialized string
  * \param model_str The string of model
  */
92
  virtual void ModelsFromString(const std::string& model_str) = 0;
Guolin Ke's avatar
Guolin Ke committed
93
94
95
96
97
98
99

  /*!
  * \brief Get max feature index of this model
  * \return Max feature index of this model
  */
  virtual int MaxFeatureIdx() const = 0;

Guolin Ke's avatar
Guolin Ke committed
100
101
102
103
104
105
  /*!
  * \brief Get index of label column
  * \return index of label column
  */
  virtual int LabelIdx() const = 0;

Guolin Ke's avatar
Guolin Ke committed
106
107
108
109
110
111
  /*!
  * \brief Get number of weak sub-models
  * \return Number of weak sub-models
  */
  virtual int NumberOfSubModels() const = 0;

112
113
114
115
116
  /*!
  * \brief Get Type name of this boosting object
  */
  virtual const char* Name() const = 0;

Guolin Ke's avatar
Guolin Ke committed
117
118
119
  /*!
  * \brief Create boosting object
  * \param type Type of boosting
120
121
  * \param config config for boosting
  * \param filename name of model file, if existing will continue to train from this model
Guolin Ke's avatar
Guolin Ke committed
122
123
  * \return The boosting object
  */
124
125
126
127
128
129
130
131
132
  static Boosting* CreateBoosting(BoostingType type, const char* filename);

  /*!
  * \brief Create boosting object from model file
  * \param filename name of model file
  * \return The boosting object
  */
  static Boosting* CreateBoosting(const char* filename);

Guolin Ke's avatar
Guolin Ke committed
133
134
135
136
};

}  // namespace LightGBM

Guolin Ke's avatar
Guolin Ke committed
137
#endif   // LightGBM_BOOSTING_H_