// dart.hpp
#ifndef LIGHTGBM_BOOSTING_DART_H_
#define LIGHTGBM_BOOSTING_DART_H_

#include <LightGBM/boosting.h>
#include "score_updater.hpp"
#include "gbdt.h"

#include <algorithm>
#include <cstdio>
#include <fstream>
#include <string>
#include <vector>

namespace LightGBM {
/*!
* \brief DART algorithm implementation. including Training, prediction, bagging.
*/
class DART: public GBDT {
public:
  /*!
  * \brief Constructor
  */
22
  DART() : GBDT() { }
Guolin Ke's avatar
Guolin Ke committed
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
  /*!
  * \brief Destructor
  */
  ~DART() { }
  /*!
  * \brief Initialization logic
  * \param config Config for boosting
  * \param train_data Training data
  * \param object_function Training objective function
  * \param training_metrics Training metrics
  * \param output_model_filename Filename of output model
  */
  void Init(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* object_function,
    const std::vector<const Metric*>& training_metrics) override {
    GBDT::Init(config, train_data, object_function, training_metrics);
    random_for_drop_ = Random(gbdt_config_->drop_seed);
39
    sum_weight_ = 0.0f;
Guolin Ke's avatar
Guolin Ke committed
40
41
42
43
44
  }
  /*!
  * \brief one training iteration
  */
  bool TrainOneIter(const score_t* gradient, const score_t* hessian, bool is_eval) override {
Guolin Ke's avatar
Guolin Ke committed
45
    is_update_score_cur_iter_ = false;
Guolin Ke's avatar
Guolin Ke committed
46
47
48
    GBDT::TrainOneIter(gradient, hessian, false);
    // normalize
    Normalize();
49
50
51
52
    if (!gbdt_config_->uniform_drop) {
      tree_weight_.push_back(shrinkage_rate_);
      sum_weight_ += shrinkage_rate_;
    }
Guolin Ke's avatar
Guolin Ke committed
53
54
55
56
57
58
    if (is_eval) {
      return EvalAndCheckEarlyStopping();
    } else {
      return false;
    }
  }
Guolin Ke's avatar
Guolin Ke committed
59

Guolin Ke's avatar
Guolin Ke committed
60
61
62
63
64
  /*!
  * \brief Get current training score
  * \param out_len length of returned score
  * \return training score
  */
65
  const score_t* GetTrainingScore(int64_t* out_len) override {
Guolin Ke's avatar
Guolin Ke committed
66
67
68
69
70
    if (!is_update_score_cur_iter_) {
      // only drop one time in one iteration
      DroppingTrees();
      is_update_score_cur_iter_ = true;
    }
71
    *out_len = static_cast<int64_t>(train_score_updater_->num_data()) * num_class_;
Guolin Ke's avatar
Guolin Ke committed
72
73
    return train_score_updater_->score();
  }
Guolin Ke's avatar
Guolin Ke committed
74

Guolin Ke's avatar
Guolin Ke committed
75
76
77
78
79
80
private:
  /*!
  * \brief drop trees based on drop_rate
  */
  void DroppingTrees() {
    drop_index_.clear();
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
    bool is_skip = random_for_drop_.NextDouble() < gbdt_config_->skip_drop;
    // select dropping tree indexes based on drop_rate and tree weights
    if (!is_skip) {
      double drop_rate = gbdt_config_->drop_rate;
      if (!gbdt_config_->uniform_drop) {
        double inv_average_weight = static_cast<double>(tree_weight_.size()) / sum_weight_;
        if (gbdt_config_->max_drop > 0) {
          drop_rate = std::min(drop_rate, gbdt_config_->max_drop * inv_average_weight / sum_weight_);
        }
        for (int i = 0; i < iter_; ++i) {
          if (random_for_drop_.NextDouble() < drop_rate * tree_weight_[i] * inv_average_weight) {
            drop_index_.push_back(i);
          }
        }
      } else {
        if (gbdt_config_->max_drop > 0) {
          drop_rate = std::min(drop_rate, gbdt_config_->max_drop / static_cast<double>(iter_));
        }
        for (int i = 0; i < iter_; ++i) {
          if (random_for_drop_.NextDouble() < drop_rate) {
            drop_index_.push_back(i);
          }
Guolin Ke's avatar
Guolin Ke committed
103
104
105
106
107
108
109
110
111
112
113
        }
      }
    }
    // drop trees
    for (auto i : drop_index_) {
      for (int curr_class = 0; curr_class < num_class_; ++curr_class) {
        auto curr_tree = i * num_class_ + curr_class;
        models_[curr_tree]->Shrinkage(-1.0);
        train_score_updater_->AddScore(models_[curr_tree].get(), curr_class);
      }
    }
114
115
116
117
118
119
120
121
122
    if (!gbdt_config_->xgboost_dart_mode) {
      shrinkage_rate_ = gbdt_config_->learning_rate / (1.0f + static_cast<double>(drop_index_.size()));
    } else {
      if (drop_index_.empty()) {
        shrinkage_rate_ = gbdt_config_->learning_rate;
      } else {
        shrinkage_rate_ = gbdt_config_->learning_rate / (gbdt_config_->learning_rate + static_cast<double>(drop_index_.size()));
      }
    }
Guolin Ke's avatar
Guolin Ke committed
123
124
125
  }
  /*!
  * \brief normalize dropped trees
126
  * NOTE: num_drop_tree(k), learning_rate(lr), shrinkage_rate_ = lr / (k + 1)
wxchan's avatar
wxchan committed
127
  *       step 1: shrink tree to -1 -> drop tree
128
  *       step 2: shrink tree to k / (k + 1) - 1 from -1, by 1/(k+1)
wxchan's avatar
wxchan committed
129
  *               -> normalize for valid data
130
  *       step 3: shrink tree to k / (k + 1) from k / (k + 1) - 1, by -k
wxchan's avatar
wxchan committed
131
  *               -> normalize for train data
132
  *       end with tree weight = (k / (k + 1)) * old_weight
Guolin Ke's avatar
Guolin Ke committed
133
134
135
  */
  void Normalize() {
    double k = static_cast<double>(drop_index_.size());
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
    if (!gbdt_config_->xgboost_dart_mode) {
      for (auto i : drop_index_) {
        for (int curr_class = 0; curr_class < num_class_; ++curr_class) {
          auto curr_tree = i * num_class_ + curr_class;
          // update validation score
          models_[curr_tree]->Shrinkage(1.0f / (k + 1.0f));
          for (auto& score_updater : valid_score_updater_) {
            score_updater->AddScore(models_[curr_tree].get(), curr_class);
          }
          // update training score
          models_[curr_tree]->Shrinkage(-k);
          train_score_updater_->AddScore(models_[curr_tree].get(), curr_class);
        }
        if (!gbdt_config_->uniform_drop) {
          sum_weight_ -= tree_weight_[i] * (1.0f / (k + 1.0f));
          tree_weight_[i] *= (k / (k + 1.0f));
        }
      }
    } else {
      for (auto i : drop_index_) {
        for (int curr_class = 0; curr_class < num_class_; ++curr_class) {
          auto curr_tree = i * num_class_ + curr_class;
          // update validation score
          models_[curr_tree]->Shrinkage(shrinkage_rate_);
          for (auto& score_updater : valid_score_updater_) {
            score_updater->AddScore(models_[curr_tree].get(), curr_class);
          }
          // update training score
          models_[curr_tree]->Shrinkage(-k / gbdt_config_->learning_rate);
          train_score_updater_->AddScore(models_[curr_tree].get(), curr_class);
        }
        if (!gbdt_config_->uniform_drop) {
          sum_weight_ -= tree_weight_[i] * (1.0f / (k + gbdt_config_->learning_rate));;
          tree_weight_[i] *= (k / (k + gbdt_config_->learning_rate));
Guolin Ke's avatar
Guolin Ke committed
170
171
172
173
        }
      }
    }
  }
174
175
176
177
  /*! \brief The weights of all trees, used to choose drop trees */
  std::vector<double> tree_weight_;
  /*! \brief sum weights of all trees */
  double sum_weight_;
Guolin Ke's avatar
Guolin Ke committed
178
  /*! \brief The indexes of dropping trees */
179
  std::vector<int> drop_index_;
Guolin Ke's avatar
Guolin Ke committed
180
181
  /*! \brief Random generator, used to select dropping trees */
  Random random_for_drop_;
Guolin Ke's avatar
Guolin Ke committed
182
183
  /*! \brief Flag that the score is update on current iter or not*/
  bool is_update_score_cur_iter_;
Guolin Ke's avatar
Guolin Ke committed
184
185
186
187
};

}  // namespace LightGBM
#endif   // LightGBM_BOOSTING_DART_H_