boosting.h 2.47 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
#ifndef LIGHTGBM_BOOSTING_H_
#define LIGHTGBM_BOOSTING_H_

#include <LightGBM/meta.h>
#include <LightGBM/config.h>

#include <vector>
#include <string>

namespace LightGBM {

/*! \brief forward declaration */
class Dataset;
class ObjectiveFunction;
class Metric;

/*!
* \brief The interface for Boosting
*
* Abstract base class: every member function except the static factory
* CreateBoosting is pure virtual, so concrete boosting algorithms are provided
* by derived classes and obtained through CreateBoosting.
*
* NOTE(review): the declaration order of the virtual functions below fixes the
* vtable layout; reordering them breaks binary compatibility with code that was
* compiled against this header.
*/
class Boosting {
public:
  /*! \brief virtual destructor, so derived objects can be deleted through a Boosting* */
  virtual ~Boosting() {}

  /*!
  * \brief Initial logic
  * \param train_data Training data
  * \param object_function Training objective function
  * \param training_metrics Training metrics
  * \param output_model_filename Filename of output model
  *
  * NOTE(review): all arguments are non-owning pointers/references; presumably
  * the caller must keep them alive for as long as this object uses them —
  * confirm against the implementations.
  */
  virtual void Init(const Dataset* train_data,
    const ObjectiveFunction* object_function,
    const std::vector<const Metric*>& training_metrics,
    const char* output_model_filename) = 0;

  /*!
  * \brief Add a validation dataset
  * \param valid_data Validation data
  * \param valid_metrics Metrics for the validation data
  */
  virtual void AddDataset(const Dataset* valid_data,
    const std::vector<const Metric*>& valid_metrics) = 0;

  /*! \brief Training logic */
  virtual void Train() = 0;

  /*!
  * \brief Prediction for one record, without applying the sigmoid transformation
  * \param feature_values Feature values of this record
  * \return Prediction result for this record
  *
  * NOTE(review): feature_values is a bare pointer with no length; presumably it
  * must hold at least MaxFeatureIdx() + 1 values — confirm against implementations.
  */
  virtual double PredictRaw(const double * feature_values) const = 0;

  /*!
  * \brief Prediction for one record, applying the sigmoid transformation if needed
  * \param feature_values Feature values of this record
  * \return Prediction result for this record
  *
  * NOTE(review): when the transformation is "needed" is decided by the
  * implementation and is not visible here.
  */
  virtual double Predict(const double * feature_values) const = 0;

  /*!
  * \brief Serialize models to a string (counterpart of ModelsFromString)
  * \return String output of the trained model
  */
  virtual std::string ModelsToString() const = 0;

  /*!
  * \brief Restore models from a serialized string
  * \param model_str The string of the model
  * \param num_used_model Number of models to use from model_str.
  *        NOTE(review): meaning inferred from the name only — confirm the
  *        semantics (e.g. how non-positive values are treated) against the
  *        implementations.
  */
  virtual void ModelsFromString(const std::string& model_str, int num_used_model) = 0;

  /*!
  * \brief Get max feature index of this model
  * \return Max feature index of this model
  */
  virtual int MaxFeatureIdx() const = 0;

  /*!
  * \brief Get number of weak sub-models
  * \return Number of weak sub-models
  */
  virtual int NumberOfSubModels() const = 0;

  /*!
  * \brief Create boosting object
  * \param type Type of boosting
  * \param config Configuration for the boosting object
  * \return The boosting object
  *
  * NOTE(review): returns a raw pointer; presumably the caller takes ownership
  * and must delete it (the virtual destructor makes deletion through a
  * Boosting* well-defined) — confirm against the implementation. Whether the
  * created object keeps the config pointer, and whether null can be returned
  * for an unsupported type, is not visible here.
  */
  static Boosting* CreateBoosting(BoostingType type,
    const BoostingConfig* config);
};

}  // namespace LightGBM

Guolin Ke's avatar
Guolin Ke committed
98
#endif   // LIGHTGBM_BOOSTING_H_