dart.hpp 4.19 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
#ifndef LIGHTGBM_BOOSTING_DART_H_
#define LIGHTGBM_BOOSTING_DART_H_

#include <LightGBM/boosting.h>
#include "score_updater.hpp"
#include "gbdt.h"

#include <cstdio>
#include <vector>
#include <string>
#include <fstream>

namespace LightGBM {
/*!
* \brief DART algorithm implementation. including Training, prediction, bagging.
*/
class DART: public GBDT {
public:
  /*!
  * \brief Constructor
  */
  DART(): GBDT() { }
  /*!
  * \brief Destructor
  */
  ~DART() { }
  /*!
  * \brief Initialization logic
  * \param config Config for boosting
  * \param train_data Training data
  * \param object_function Training objective function
  * \param training_metrics Training metrics
  * \param output_model_filename Filename of output model
  */
  void Init(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* object_function,
    const std::vector<const Metric*>& training_metrics) override {
    GBDT::Init(config, train_data, object_function, training_metrics);
    drop_rate_ = gbdt_config_->drop_rate;
    shrinkage_rate_ = 1.0;
    random_for_drop_ = Random(gbdt_config_->drop_seed);
  }
  /*!
  * \brief one training iteration
  */
  bool TrainOneIter(const score_t* gradient, const score_t* hessian, bool is_eval) override {
    GBDT::TrainOneIter(gradient, hessian, false);
    // normalize
    Normalize();
    if (is_eval) {
      return EvalAndCheckEarlyStopping();
    } else {
      return false;
    }
  }
  /*!
  * \brief Get current training score
  * \param out_len length of returned score
  * \return training score
  */
  const score_t* GetTrainingScore(data_size_t* out_len) override {
    DroppingTrees();
    *out_len = train_score_updater_->num_data() * num_class_;
    return train_score_updater_->score();
  }
  /*!
  * \brief save model to file
  * \param num_used_model number of model that want to save, -1 means save all
  * \param is_finish is training finished or not
  * \param filename filename that want to save to
  */
  void SaveModelToFile(int num_used_model, bool is_finish, const char* filename) override {
    // only save model once when is_finish = true
    if (is_finish && saved_model_size_ < 0) {
      GBDT::SaveModelToFile(num_used_model, is_finish, filename);
    }
  }
  /*!
  * \brief Get Type name of this boosting object
  */
  const char* Name() const override { return "dart"; }

private:
  /*!
  * \brief drop trees based on drop_rate
  */
  void DroppingTrees() {
    drop_index_.clear();
    // select dropping tree indexes based on drop_rate
    // if drop rate is too small, skip this step, drop one tree randomly
    if (drop_rate_ > kEpsilon) {
91
      for (int i = 0; i < iter_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
        if (random_for_drop_.NextDouble() < drop_rate_) {
          drop_index_.push_back(i);
        }
      }
    }
    // binomial-plus-one, at least one tree will be dropped
    if (drop_index_.empty()) {
      drop_index_ = random_for_drop_.Sample(iter_, 1);
    }
    // drop trees
    for (auto i : drop_index_) {
      for (int curr_class = 0; curr_class < num_class_; ++curr_class) {
        auto curr_tree = i * num_class_ + curr_class;
        models_[curr_tree]->Shrinkage(-1.0);
        train_score_updater_->AddScore(models_[curr_tree].get(), curr_class);
      }
    }
    shrinkage_rate_ = 1.0 / (1.0 + drop_index_.size());
  }
  /*!
  * \brief normalize dropped trees
  */
  void Normalize() {
    double k = static_cast<double>(drop_index_.size());
    for (auto i : drop_index_) {
      for (int curr_class = 0; curr_class < num_class_; ++curr_class) {
        auto curr_tree = i * num_class_ + curr_class;
        // update validation score
        models_[curr_tree]->Shrinkage(shrinkage_rate_);
        for (auto& score_updater : valid_score_updater_) {
          score_updater->AddScore(models_[curr_tree].get(), curr_class);
        }
        // update training score
        models_[curr_tree]->Shrinkage(-k);
        train_score_updater_->AddScore(models_[curr_tree].get(), curr_class);
      }
    }
  }
  /*! \brief The indexes of dropping trees */
131
  std::vector<int> drop_index_;
Guolin Ke's avatar
Guolin Ke committed
132
133
134
135
136
137
138
139
  /*! \brief Dropping rate */
  double drop_rate_;
  /*! \brief Random generator, used to select dropping trees */
  Random random_for_drop_;
};

}  // namespace LightGBM
#endif   // LightGBM_BOOSTING_DART_H_