score_updater.hpp 3.11 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
#ifndef LIGHTGBM_BOOSTING_SCORE_UPDATER_HPP_
#define LIGHTGBM_BOOSTING_SCORE_UPDATER_HPP_

#include <LightGBM/meta.h>
#include <LightGBM/dataset.h>
#include <LightGBM/tree.h>
#include <LightGBM/tree_learner.h>

#include <algorithm>
#include <cstddef>
#include <cstring>
#include <vector>

namespace LightGBM {
/*!
* \brief Used to store and update score for data
*/
class ScoreUpdater {
public:
  /*!
  * \brief Constructor, will pass a const pointer of dataset
  * \param data This class will bind with this data set
  */
Guolin Ke's avatar
Guolin Ke committed
21
  ScoreUpdater(const Dataset* data, int num_class) : data_(data) {
Guolin Ke's avatar
Guolin Ke committed
22
    num_data_ = data->num_data();
23
24
    size_t total_size = static_cast<size_t>(num_data_) * num_class;
    score_.resize(total_size);
Guolin Ke's avatar
Guolin Ke committed
25
    // default start score is zero
Guolin Ke's avatar
Guolin Ke committed
26
    std::fill(score_.begin(), score_.end(), 0.0f);
Guolin Ke's avatar
Guolin Ke committed
27
    const float* init_score = data->metadata().init_score();
Guolin Ke's avatar
Guolin Ke committed
28
29
    // if exists initial score, will start from it
    if (init_score != nullptr) {
30
      for (size_t i = 0; i < total_size; ++i) {
Guolin Ke's avatar
Guolin Ke committed
31
32
33
34
35
36
        score_[i] = init_score[i];
      }
    }
  }
  /*! \brief Destructor */
  ~ScoreUpdater() {
Guolin Ke's avatar
Guolin Ke committed
37

Guolin Ke's avatar
Guolin Ke committed
38
39
  }
  /*!
Qiwei Ye's avatar
Qiwei Ye committed
40
41
  * \brief Using tree model to get prediction number, then adding to scores for all data
  *        Note: this function generally will be used on validation data too.
Guolin Ke's avatar
Guolin Ke committed
42
  * \param tree Trained tree model
43
  * \param curr_class Current class for multiclass training
Guolin Ke's avatar
Guolin Ke committed
44
  */
45
  inline void AddScore(const Tree* tree, int curr_class) {
Guolin Ke's avatar
Guolin Ke committed
46
    tree->AddPredictionToScore(data_, num_data_, score_.data() + curr_class * num_data_);
Guolin Ke's avatar
Guolin Ke committed
47
48
  }
  /*!
Qiwei Ye's avatar
Qiwei Ye committed
49
50
51
  * \brief Adding prediction score, only used for training data.
  *        The training data is partitioned into tree leaves after training
  *        Based on which We can get prediction quckily.
Guolin Ke's avatar
Guolin Ke committed
52
  * \param tree_learner
53
  * \param curr_class Current class for multiclass training
Guolin Ke's avatar
Guolin Ke committed
54
  */
55
  inline void AddScore(const TreeLearner* tree_learner, int curr_class) {
Guolin Ke's avatar
Guolin Ke committed
56
    tree_learner->AddPredictionToScore(score_.data() + curr_class * num_data_);
Guolin Ke's avatar
Guolin Ke committed
57
58
  }
  /*!
Qiwei Ye's avatar
Qiwei Ye committed
59
60
  * \brief Using tree model to get prediction number, then adding to scores for parts of data
  *        Used for prediction of training out-of-bag data
Guolin Ke's avatar
Guolin Ke committed
61
  * \param tree Trained tree model
Hui Xue's avatar
Hui Xue committed
62
63
  * \param data_indices Indices of data that will be proccessed
  * \param data_cnt Number of data that will be proccessed
64
  * \param curr_class Current class for multiclass training
Guolin Ke's avatar
Guolin Ke committed
65
66
  */
  inline void AddScore(const Tree* tree, const data_size_t* data_indices,
67
                                                  data_size_t data_cnt, int curr_class) {
Guolin Ke's avatar
Guolin Ke committed
68
    tree->AddPredictionToScore(data_, data_indices, data_cnt, score_.data() + curr_class * num_data_);
Guolin Ke's avatar
Guolin Ke committed
69
70
  }
  /*! \brief Pointer of score */
Guolin Ke's avatar
Guolin Ke committed
71
72
73
74
75
76
77
  inline const score_t* score() const { return score_.data(); }
  inline const data_size_t num_data() const { return num_data_; }

  /*! \brief Disable copy */
  ScoreUpdater& operator=(const ScoreUpdater&) = delete;
  /*! \brief Disable copy */
  ScoreUpdater(const ScoreUpdater&) = delete;
Guolin Ke's avatar
Guolin Ke committed
78
79
80
81
82
private:
  /*! \brief Number of total data */
  data_size_t num_data_;
  /*! \brief Pointer of data set */
  const Dataset* data_;
83
  /*! \brief Scores for data set */
Guolin Ke's avatar
Guolin Ke committed
84
  std::vector<score_t> score_;
Guolin Ke's avatar
Guolin Ke committed
85
86
87
};

}  // namespace LightGBM
Guolin Ke's avatar
Guolin Ke committed
88
#endif   // LightGBM_BOOSTING_SCORE_UPDATER_HPP_