dart.hpp 3.99 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
#ifndef LIGHTGBM_BOOSTING_DART_H_
#define LIGHTGBM_BOOSTING_DART_H_

#include <LightGBM/boosting.h>
#include "score_updater.hpp"
#include "gbdt.h"

#include <cstdio>
#include <vector>
#include <string>
#include <fstream>

namespace LightGBM {
/*!
* \brief DART algorithm implementation. including Training, prediction, bagging.
*/
class DART: public GBDT {
public:
  /*!
  * \brief Constructor
  */
  DART(): GBDT() { }
  /*!
  * \brief Destructor
  */
  ~DART() { }
  /*!
  * \brief Initialization logic
  * \param config Config for boosting
  * \param train_data Training data
  * \param object_function Training objective function
  * \param training_metrics Training metrics
  * \param output_model_filename Filename of output model
  */
  void Init(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* object_function,
    const std::vector<const Metric*>& training_metrics) override {
    GBDT::Init(config, train_data, object_function, training_metrics);
    drop_rate_ = gbdt_config_->drop_rate;
    shrinkage_rate_ = 1.0;
    random_for_drop_ = Random(gbdt_config_->drop_seed);
  }
  /*!
  * \brief one training iteration
  */
  bool TrainOneIter(const score_t* gradient, const score_t* hessian, bool is_eval) override {
Guolin Ke's avatar
Guolin Ke committed
46
    is_update_score_cur_iter_ = false;
Guolin Ke's avatar
Guolin Ke committed
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
    GBDT::TrainOneIter(gradient, hessian, false);
    // normalize
    Normalize();
    if (is_eval) {
      return EvalAndCheckEarlyStopping();
    } else {
      return false;
    }
  }
  /*!
  * \brief Get current training score
  * \param out_len length of returned score
  * \return training score
  */
  const score_t* GetTrainingScore(data_size_t* out_len) override {
Guolin Ke's avatar
Guolin Ke committed
62
63
64
65
66
    if (!is_update_score_cur_iter_) {
      // only drop one time in one iteration
      DroppingTrees();
      is_update_score_cur_iter_ = true;
    }
Guolin Ke's avatar
Guolin Ke committed
67
68
69
    *out_len = train_score_updater_->num_data() * num_class_;
    return train_score_updater_->score();
  }
Guolin Ke's avatar
Guolin Ke committed
70

Guolin Ke's avatar
Guolin Ke committed
71
72
73
74
75
76
77
78
79
80
81
82
83
84
  /*!
  * \brief Get Type name of this boosting object
  */
  const char* Name() const override { return "dart"; }

private:
  /*!
  * \brief drop trees based on drop_rate
  */
  void DroppingTrees() {
    drop_index_.clear();
    // select dropping tree indexes based on drop_rate
    // if drop rate is too small, skip this step, drop one tree randomly
    if (drop_rate_ > kEpsilon) {
85
      for (int i = 0; i < iter_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
        if (random_for_drop_.NextDouble() < drop_rate_) {
          drop_index_.push_back(i);
        }
      }
    }
    // binomial-plus-one, at least one tree will be dropped
    if (drop_index_.empty()) {
      drop_index_ = random_for_drop_.Sample(iter_, 1);
    }
    // drop trees
    for (auto i : drop_index_) {
      for (int curr_class = 0; curr_class < num_class_; ++curr_class) {
        auto curr_tree = i * num_class_ + curr_class;
        models_[curr_tree]->Shrinkage(-1.0);
        train_score_updater_->AddScore(models_[curr_tree].get(), curr_class);
      }
    }
    shrinkage_rate_ = 1.0 / (1.0 + drop_index_.size());
  }
  /*!
  * \brief normalize dropped trees
  */
  void Normalize() {
    double k = static_cast<double>(drop_index_.size());
    for (auto i : drop_index_) {
      for (int curr_class = 0; curr_class < num_class_; ++curr_class) {
        auto curr_tree = i * num_class_ + curr_class;
        // update validation score
        models_[curr_tree]->Shrinkage(shrinkage_rate_);
        for (auto& score_updater : valid_score_updater_) {
          score_updater->AddScore(models_[curr_tree].get(), curr_class);
        }
        // update training score
        models_[curr_tree]->Shrinkage(-k);
        train_score_updater_->AddScore(models_[curr_tree].get(), curr_class);
      }
    }
  }
  /*! \brief The indexes of dropping trees */
125
  std::vector<int> drop_index_;
Guolin Ke's avatar
Guolin Ke committed
126
127
128
129
  /*! \brief Dropping rate */
  double drop_rate_;
  /*! \brief Random generator, used to select dropping trees */
  Random random_for_drop_;
Guolin Ke's avatar
Guolin Ke committed
130
131
  /*! \brief Flag that the score is update on current iter or not*/
  bool is_update_score_cur_iter_;
Guolin Ke's avatar
Guolin Ke committed
132
133
134
135
};

}  // namespace LightGBM
#endif   // LIGHTGBM_BOOSTING_DART_H_