// dart.hpp
#ifndef LIGHTGBM_BOOSTING_DART_H_
#define LIGHTGBM_BOOSTING_DART_H_

#include <LightGBM/boosting.h>
#include "score_updater.hpp"
#include "gbdt.h"

#include <cstdio>
#include <vector>
#include <string>
#include <fstream>

namespace LightGBM {
/*!
* \brief DART algorithm implementation. including Training, prediction, bagging.
*/
class DART: public GBDT {
public:
  /*!
  * \brief Constructor
  */
  DART(): GBDT() { }
  /*!
  * \brief Destructor
  */
  ~DART() { }
  /*!
  * \brief Initialization logic
  * \param config Config for boosting
  * \param train_data Training data
  * \param object_function Training objective function
  * \param training_metrics Training metrics
  * \param output_model_filename Filename of output model
  */
  void Init(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* object_function,
    const std::vector<const Metric*>& training_metrics) override {
    GBDT::Init(config, train_data, object_function, training_metrics);
    shrinkage_rate_ = 1.0;
    random_for_drop_ = Random(gbdt_config_->drop_seed);
  }
  /*!
  * \brief one training iteration
  */
  bool TrainOneIter(const score_t* gradient, const score_t* hessian, bool is_eval) override {
Guolin Ke's avatar
Guolin Ke committed
45
    is_update_score_cur_iter_ = false;
Guolin Ke's avatar
Guolin Ke committed
46
47
48
49
50
51
52
53
54
    GBDT::TrainOneIter(gradient, hessian, false);
    // normalize
    Normalize();
    if (is_eval) {
      return EvalAndCheckEarlyStopping();
    } else {
      return false;
    }
  }
Guolin Ke's avatar
Guolin Ke committed
55
56
57
58
59
60
61
62

  void ResetTrainingData(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* object_function,
    const std::vector<const Metric*>& training_metrics) {
    GBDT::ResetTrainingData(config, train_data, object_function, training_metrics);
    shrinkage_rate_ = 1.0;
    random_for_drop_ = Random(gbdt_config_->drop_seed);
  }

Guolin Ke's avatar
Guolin Ke committed
63
64
65
66
67
68
  /*!
  * \brief Get current training score
  * \param out_len length of returned score
  * \return training score
  */
  const score_t* GetTrainingScore(data_size_t* out_len) override {
Guolin Ke's avatar
Guolin Ke committed
69
70
71
72
73
    if (!is_update_score_cur_iter_) {
      // only drop one time in one iteration
      DroppingTrees();
      is_update_score_cur_iter_ = true;
    }
Guolin Ke's avatar
Guolin Ke committed
74
75
76
    *out_len = train_score_updater_->num_data() * num_class_;
    return train_score_updater_->score();
  }
Guolin Ke's avatar
Guolin Ke committed
77

Guolin Ke's avatar
Guolin Ke committed
78
79
80
81
82
83
84
85
86
87
88
89
90
  /*!
  * \brief Get Type name of this boosting object
  */
  const char* Name() const override { return "dart"; }

private:
  /*!
  * \brief drop trees based on drop_rate
  */
  void DroppingTrees() {
    drop_index_.clear();
    // select dropping tree indexes based on drop_rate
    // if drop rate is too small, skip this step, drop one tree randomly
Guolin Ke's avatar
Guolin Ke committed
91
    if (gbdt_config_->drop_rate > kEpsilon) {
92
      for (int i = 0; i < iter_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
93
        if (random_for_drop_.NextDouble() < gbdt_config_->drop_rate) {
Guolin Ke's avatar
Guolin Ke committed
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
          drop_index_.push_back(i);
        }
      }
    }
    // binomial-plus-one, at least one tree will be dropped
    if (drop_index_.empty()) {
      drop_index_ = random_for_drop_.Sample(iter_, 1);
    }
    // drop trees
    for (auto i : drop_index_) {
      for (int curr_class = 0; curr_class < num_class_; ++curr_class) {
        auto curr_tree = i * num_class_ + curr_class;
        models_[curr_tree]->Shrinkage(-1.0);
        train_score_updater_->AddScore(models_[curr_tree].get(), curr_class);
      }
    }
    shrinkage_rate_ = 1.0 / (1.0 + drop_index_.size());
  }
  /*!
  * \brief normalize dropped trees
  */
  void Normalize() {
    double k = static_cast<double>(drop_index_.size());
    for (auto i : drop_index_) {
      for (int curr_class = 0; curr_class < num_class_; ++curr_class) {
        auto curr_tree = i * num_class_ + curr_class;
        // update validation score
        models_[curr_tree]->Shrinkage(shrinkage_rate_);
        for (auto& score_updater : valid_score_updater_) {
          score_updater->AddScore(models_[curr_tree].get(), curr_class);
        }
        // update training score
        models_[curr_tree]->Shrinkage(-k);
        train_score_updater_->AddScore(models_[curr_tree].get(), curr_class);
      }
    }
  }
  /*! \brief The indexes of dropping trees */
132
  std::vector<int> drop_index_;
Guolin Ke's avatar
Guolin Ke committed
133
134
  /*! \brief Random generator, used to select dropping trees */
  Random random_for_drop_;
Guolin Ke's avatar
Guolin Ke committed
135
136
  /*! \brief Flag that the score is update on current iter or not*/
  bool is_update_score_cur_iter_;
Guolin Ke's avatar
Guolin Ke committed
137
138
139
140
};

}  // namespace LightGBM
#endif   // LIGHTGBM_BOOSTING_DART_H_