tree_learner.h 1.48 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
#ifndef LIGHTGBM_TREE_LEARNER_H_
#define LIGHTGBM_TREE_LEARNER_H_


#include <LightGBM/meta.h>
#include <LightGBM/config.h>

#include <vector>

namespace LightGBM {

/*! \brief forward declarations */
class Tree;
class Dataset;

/*!
* \brief Interface for tree learner
*/
class TreeLearner {
public:
  /*! \brief virtual destructor */
  virtual ~TreeLearner() = default;

  /*!
  * \brief Initialize the tree learner with the training data set
  * \param train_data The training data used to learn trees
  * \note Tree settings are not passed here; CreateTreeLearner() receives
  *       the TreeConfig instead.
  */
  virtual void Init(const Dataset* train_data) = 0;

  /*!
  * \brief Fit the training data set and return a trained tree
  * \param gradients The first order gradients
  * \param hessians The second order gradients
  * \return A trained tree
  * NOTE(review): ownership of the returned Tree is not stated by this
  * interface; presumably it passes to the caller — confirm with implementations.
  */
  virtual Tree* Train(const score_t* gradients, const score_t* hessians) = 0;

  /*!
  * \brief Set bagging data (the subset of rows used for the next training)
  * \param used_indices Used data indices
  * \param num_data Number of used data, i.e. the length of used_indices
  */
  virtual void SetBaggingData(const data_size_t* used_indices,
    data_size_t num_data) = 0;

  /*!
  * \brief Use the last trained tree to predict the training score, and add
  *        the prediction to out_score
  * \param out_score Output score; predictions are added to its existing values
  */
  virtual void AddPredictionToScore(score_t *out_score) const = 0;

  /*!
  * \brief Create an object of tree learner
  * \param type Type of tree learner
  * \param tree_config The tree setting used by the created learner
  * \return A newly created tree learner
  * NOTE(review): presumably the caller owns and must delete the returned
  * object — confirm against the factory implementation.
  */
  static TreeLearner* CreateTreeLearner(TreeLearnerType type,
    const TreeConfig& tree_config);
};

}  // namespace LightGBM

Guolin Ke's avatar
Guolin Ke committed
63
#endif   // LightGBM_TREE_LEARNER_H_