dart.hpp 6.73 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
#ifndef LIGHTGBM_BOOSTING_DART_H_
#define LIGHTGBM_BOOSTING_DART_H_

#include <LightGBM/boosting.h>
#include "score_updater.hpp"
#include "gbdt.h"

#include <cstdio>
#include <vector>
#include <string>
#include <fstream>

namespace LightGBM {
/*!
* \brief DART algorithm implementation. including Training, prediction, bagging.
*/
class DART: public GBDT {
public:
  /*!
  * \brief Constructor
  */
22
  DART() : GBDT() { }
Guolin Ke's avatar
Guolin Ke committed
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
  /*!
  * \brief Destructor
  */
  ~DART() { }
  /*!
  * \brief Initialization logic
  * \param config Config for boosting
  * \param train_data Training data
  * \param object_function Training objective function
  * \param training_metrics Training metrics
  * \param output_model_filename Filename of output model
  */
  void Init(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* object_function,
    const std::vector<const Metric*>& training_metrics) override {
    GBDT::Init(config, train_data, object_function, training_metrics);
    random_for_drop_ = Random(gbdt_config_->drop_seed);
39
    sum_weight_ = 0.0f;
Guolin Ke's avatar
Guolin Ke committed
40
  }
Guolin Ke's avatar
Guolin Ke committed
41
42
43
44
45

  void ResetTrainingData(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* object_function,
    const std::vector<const Metric*>& training_metrics) override {
    GBDT::ResetTrainingData(config, train_data, object_function, training_metrics);
  }
Guolin Ke's avatar
Guolin Ke committed
46
47
48
49
  /*!
  * \brief one training iteration
  */
  bool TrainOneIter(const score_t* gradient, const score_t* hessian, bool is_eval) override {
Guolin Ke's avatar
Guolin Ke committed
50
    is_update_score_cur_iter_ = false;
Guolin Ke's avatar
Guolin Ke committed
51
52
53
    GBDT::TrainOneIter(gradient, hessian, false);
    // normalize
    Normalize();
54
55
56
57
    if (!gbdt_config_->uniform_drop) {
      tree_weight_.push_back(shrinkage_rate_);
      sum_weight_ += shrinkage_rate_;
    }
Guolin Ke's avatar
Guolin Ke committed
58
59
60
61
62
63
    if (is_eval) {
      return EvalAndCheckEarlyStopping();
    } else {
      return false;
    }
  }
Guolin Ke's avatar
Guolin Ke committed
64

Guolin Ke's avatar
Guolin Ke committed
65
66
67
68
69
  /*!
  * \brief Get current training score
  * \param out_len length of returned score
  * \return training score
  */
70
  const double* GetTrainingScore(int64_t* out_len) override {
Guolin Ke's avatar
Guolin Ke committed
71
72
73
74
75
    if (!is_update_score_cur_iter_) {
      // only drop one time in one iteration
      DroppingTrees();
      is_update_score_cur_iter_ = true;
    }
76
    *out_len = static_cast<int64_t>(train_score_updater_->num_data()) * num_class_;
Guolin Ke's avatar
Guolin Ke committed
77
78
    return train_score_updater_->score();
  }
Guolin Ke's avatar
Guolin Ke committed
79

Guolin Ke's avatar
Guolin Ke committed
80
81
82
83
84
85
private:
  /*!
  * \brief drop trees based on drop_rate
  */
  void DroppingTrees() {
    drop_index_.clear();
Guolin Ke's avatar
Guolin Ke committed
86
    bool is_skip = random_for_drop_.NextFloat() < gbdt_config_->skip_drop;
zhangyafeikimi's avatar
zhangyafeikimi committed
87
    // select dropping tree indices based on drop_rate and tree weights
88
89
90
91
92
93
94
95
    if (!is_skip) {
      double drop_rate = gbdt_config_->drop_rate;
      if (!gbdt_config_->uniform_drop) {
        double inv_average_weight = static_cast<double>(tree_weight_.size()) / sum_weight_;
        if (gbdt_config_->max_drop > 0) {
          drop_rate = std::min(drop_rate, gbdt_config_->max_drop * inv_average_weight / sum_weight_);
        }
        for (int i = 0; i < iter_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
96
          if (random_for_drop_.NextFloat() < drop_rate * tree_weight_[i] * inv_average_weight) {
97
98
99
100
101
102
103
104
            drop_index_.push_back(i);
          }
        }
      } else {
        if (gbdt_config_->max_drop > 0) {
          drop_rate = std::min(drop_rate, gbdt_config_->max_drop / static_cast<double>(iter_));
        }
        for (int i = 0; i < iter_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
105
          if (random_for_drop_.NextFloat() < drop_rate) {
106
107
            drop_index_.push_back(i);
          }
Guolin Ke's avatar
Guolin Ke committed
108
109
110
111
112
113
114
115
116
117
118
        }
      }
    }
    // drop trees
    for (auto i : drop_index_) {
      for (int curr_class = 0; curr_class < num_class_; ++curr_class) {
        auto curr_tree = i * num_class_ + curr_class;
        models_[curr_tree]->Shrinkage(-1.0);
        train_score_updater_->AddScore(models_[curr_tree].get(), curr_class);
      }
    }
119
120
121
122
123
124
125
126
127
    if (!gbdt_config_->xgboost_dart_mode) {
      shrinkage_rate_ = gbdt_config_->learning_rate / (1.0f + static_cast<double>(drop_index_.size()));
    } else {
      if (drop_index_.empty()) {
        shrinkage_rate_ = gbdt_config_->learning_rate;
      } else {
        shrinkage_rate_ = gbdt_config_->learning_rate / (gbdt_config_->learning_rate + static_cast<double>(drop_index_.size()));
      }
    }
Guolin Ke's avatar
Guolin Ke committed
128
129
130
  }
  /*!
  * \brief normalize dropped trees
131
  * NOTE: num_drop_tree(k), learning_rate(lr), shrinkage_rate_ = lr / (k + 1)
wxchan's avatar
wxchan committed
132
  *       step 1: shrink tree to -1 -> drop tree
133
  *       step 2: shrink tree to k / (k + 1) - 1 from -1, by 1/(k+1)
wxchan's avatar
wxchan committed
134
  *               -> normalize for valid data
135
  *       step 3: shrink tree to k / (k + 1) from k / (k + 1) - 1, by -k
wxchan's avatar
wxchan committed
136
  *               -> normalize for train data
137
  *       end with tree weight = (k / (k + 1)) * old_weight
Guolin Ke's avatar
Guolin Ke committed
138
139
140
  */
  void Normalize() {
    double k = static_cast<double>(drop_index_.size());
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
    if (!gbdt_config_->xgboost_dart_mode) {
      for (auto i : drop_index_) {
        for (int curr_class = 0; curr_class < num_class_; ++curr_class) {
          auto curr_tree = i * num_class_ + curr_class;
          // update validation score
          models_[curr_tree]->Shrinkage(1.0f / (k + 1.0f));
          for (auto& score_updater : valid_score_updater_) {
            score_updater->AddScore(models_[curr_tree].get(), curr_class);
          }
          // update training score
          models_[curr_tree]->Shrinkage(-k);
          train_score_updater_->AddScore(models_[curr_tree].get(), curr_class);
        }
        if (!gbdt_config_->uniform_drop) {
          sum_weight_ -= tree_weight_[i] * (1.0f / (k + 1.0f));
          tree_weight_[i] *= (k / (k + 1.0f));
        }
      }
    } else {
      for (auto i : drop_index_) {
        for (int curr_class = 0; curr_class < num_class_; ++curr_class) {
          auto curr_tree = i * num_class_ + curr_class;
          // update validation score
          models_[curr_tree]->Shrinkage(shrinkage_rate_);
          for (auto& score_updater : valid_score_updater_) {
            score_updater->AddScore(models_[curr_tree].get(), curr_class);
          }
          // update training score
          models_[curr_tree]->Shrinkage(-k / gbdt_config_->learning_rate);
          train_score_updater_->AddScore(models_[curr_tree].get(), curr_class);
        }
        if (!gbdt_config_->uniform_drop) {
          sum_weight_ -= tree_weight_[i] * (1.0f / (k + gbdt_config_->learning_rate));;
          tree_weight_[i] *= (k / (k + gbdt_config_->learning_rate));
Guolin Ke's avatar
Guolin Ke committed
175
176
177
178
        }
      }
    }
  }
179
180
181
182
  /*! \brief The weights of all trees, used to choose drop trees */
  std::vector<double> tree_weight_;
  /*! \brief sum weights of all trees */
  double sum_weight_;
zhangyafeikimi's avatar
zhangyafeikimi committed
183
  /*! \brief The indices of dropping trees */
184
  std::vector<int> drop_index_;
Guolin Ke's avatar
Guolin Ke committed
185
186
  /*! \brief Random generator, used to select dropping trees */
  Random random_for_drop_;
Guolin Ke's avatar
Guolin Ke committed
187
188
  /*! \brief Flag that the score is update on current iter or not*/
  bool is_update_score_cur_iter_;
Guolin Ke's avatar
Guolin Ke committed
189
190
191
192
};

}  // namespace LightGBM
#endif   // LIGHTGBM_BOOSTING_DART_H_