boosting.h 2.47 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
#ifndef LIGHTGBM_BOOSTING_H_
#define LIGHTGBM_BOOSTING_H_

#include <LightGBM/meta.h>
#include <LightGBM/config.h>

#include <vector>
#include <string>

namespace LightGBM {

/*!
* \brief Forward declarations. The Boosting interface below only refers to
*        these types through pointers, so their full definitions are not
*        required in this header.
*/
class Dataset;
class Metric;
class ObjectiveFunction;

/*!
* \brief Abstract interface implemented by boosting algorithms.
*
* Instances are created through the static factory CreateBoosting().
*/
class Boosting {
public:
  /*!
  * \brief Virtual destructor, so that an implementation can be destroyed
  *        through a Boosting pointer
  */
  virtual ~Boosting() {}

  /*!
  * \brief Initialize the booster before training
  * \param train_data Training data
  * \param object_function Training objective function
  * \param training_metrics Metrics evaluated on the training data
  * \param output_model_filename Filename of the output model
  * \note NOTE(review): ownership/lifetime of the pointer arguments is not
  *       specified here; presumably the caller keeps them alive while the
  *       booster is in use -- confirm against implementations.
  */
  virtual void Init(const Dataset* train_data,
    const ObjectiveFunction* object_function,
    const std::vector<const Metric*>& training_metrics,
    const char* output_model_filename) = 0;

  /*!
  * \brief Add a validation dataset
  * \param valid_data Validation data
  * \param valid_metrics Metrics evaluated on the validation data
  */
  virtual void AddDataset(const Dataset* valid_data,
    const std::vector<const Metric*>& valid_metrics) = 0;

  /*! \brief Run the training logic */
  virtual void Train() = 0;

  /*!
  * \brief Prediction for one record, without the sigmoid transform
  * \param feature_values Feature values of this record
  *        (NOTE(review): presumably must hold at least MaxFeatureIdx() + 1
  *        entries -- confirm)
  * \return Raw prediction result for this record
  */
  virtual double PredictRaw(const double * feature_values) const = 0;

  /*!
  * \brief Prediction for one record, applying the sigmoid transform if needed
  * \param feature_values Feature values of this record
  * \return Prediction result for this record
  */
  virtual double Predict(const double * feature_values) const = 0;

  /*!
  * \brief Serialize the models to a string
  * \return String representation of the trained model
  */
  virtual std::string ModelsToString() const = 0;

  /*!
  * \brief Restore models from a serialized string
  * \param model_str The string of the model, e.g. as produced by ModelsToString()
  * \param num_used_model Number of sub-models to restore
  *        (NOTE(review): meaning of special values such as negative numbers
  *        is defined by implementations -- confirm)
  */
  virtual void ModelsFromString(const std::string& model_str, int num_used_model) = 0;

  /*!
  * \brief Get the max feature index of this model
  * \return Max feature index of this model
  */
  virtual int MaxFeatureIdx() const = 0;

  /*!
  * \brief Get the number of weak sub-models
  * \return Number of weak sub-models
  */
  virtual int NumberOfSubModels() const = 0;

  /*!
  * \brief Create a boosting object
  * \param type Type of boosting
  * \param config Configuration for the boosting object
  * \return The boosting object
  * \note NOTE(review): returned as a raw pointer; presumably the caller owns
  *       it and must delete it -- confirm against the implementation.
  */
  static Boosting* CreateBoosting(BoostingType type,
    const BoostingConfig* config);
};

}  // namespace LightGBM

#endif  // LIGHTGBM_BOOSTING_H_