/*!
 * Copyright (c) 2017 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
#ifndef LIGHTGBM_BOOSTING_GOSS_H_
#define LIGHTGBM_BOOSTING_GOSS_H_

#include <LightGBM/boosting.h>
#include <LightGBM/utils/array_args.h>
#include <LightGBM/utils/log.h>

#include <algorithm>
#include <chrono>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <fstream>
#include <string>
#include <vector>

#include "gbdt.h"
#include "score_updater.hpp"

namespace LightGBM {

class GOSS: public GBDT {
Nikita Titov's avatar
Nikita Titov committed
25
 public:
Guolin Ke's avatar
Guolin Ke committed
26
27
28
  /*!
  * \brief Constructor
  */
29
  GOSS() : GBDT() {
Guolin Ke's avatar
Guolin Ke committed
30
31
32
33
34
  }

  ~GOSS() {
  }

Guolin Ke's avatar
Guolin Ke committed
35
  void Init(const Config* config, const Dataset* train_data, const ObjectiveFunction* objective_function,
36
37
            const std::vector<const Metric*>& training_metrics) override {
    GBDT::Init(config, train_data, objective_function, training_metrics);
38
    ResetGoss();
39
40
41
42
43
44
    if (objective_function_ == nullptr) {
      // use customized objective function
      size_t total_size = static_cast<size_t>(num_data_) * num_tree_per_iteration_;
      gradients_.resize(total_size, 0.0f);
      hessians_.resize(total_size, 0.0f);
    }
45
46
47
48
49
50
51
52
  }

  void ResetTrainingData(const Dataset* train_data, const ObjectiveFunction* objective_function,
                         const std::vector<const Metric*>& training_metrics) override {
    GBDT::ResetTrainingData(train_data, objective_function, training_metrics);
    ResetGoss();
  }

Guolin Ke's avatar
Guolin Ke committed
53
  void ResetConfig(const Config* config) override {
54
55
56
57
    GBDT::ResetConfig(config);
    ResetGoss();
  }

58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
  bool TrainOneIter(const score_t* gradients, const score_t* hessians) override {
    if (gradients != nullptr) {
      // use customized objective function
      CHECK(hessians != nullptr && objective_function_ == nullptr);
      size_t total_size = static_cast<size_t>(num_data_) * num_tree_per_iteration_;
      #pragma omp parallel for schedule(static)
      for (size_t i = 0; i < total_size; ++i) {
        gradients_[i] = gradients[i];
        hessians_[i] = hessians[i];
      }
      return GBDT::TrainOneIter(gradients_.data(), hessians_.data());
    } else {
      CHECK(hessians == nullptr);
      return GBDT::TrainOneIter(nullptr, nullptr);
    }
  }

75
  void ResetGoss() {
Nikita Titov's avatar
Nikita Titov committed
76
    CHECK_LE(config_->top_rate + config_->other_rate, 1.0f);
Guolin Ke's avatar
Guolin Ke committed
77
78
    CHECK(config_->top_rate > 0.0f && config_->other_rate > 0.0f);
    if (config_->bagging_freq > 0 && config_->bagging_fraction != 1.0f) {
79
      Log::Fatal("Cannot use bagging in GOSS");
Guolin Ke's avatar
Guolin Ke committed
80
    }
81
    Log::Info("Using GOSS");
82
    balanced_bagging_ = false;
Guolin Ke's avatar
Guolin Ke committed
83
    bag_data_indices_.resize(num_data_);
84
    bagging_runner_.ReSize(num_data_);
Guolin Ke's avatar
Guolin Ke committed
85
86
87
88
89
    bagging_rands_.clear();
    for (int i = 0;
         i < (num_data_ + bagging_rand_block_ - 1) / bagging_rand_block_; ++i) {
      bagging_rands_.emplace_back(config_->bagging_seed + i);
    }
Guolin Ke's avatar
Guolin Ke committed
90
    is_use_subset_ = false;
Guolin Ke's avatar
Guolin Ke committed
91
92
    if (config_->top_rate + config_->other_rate <= 0.5) {
      auto bag_data_cnt = static_cast<data_size_t>((config_->top_rate + config_->other_rate) * num_data_);
93
      bag_data_cnt = std::max(1, bag_data_cnt);
Guolin Ke's avatar
Guolin Ke committed
94
95
96
97
98
99
100
101
      tmp_subset_.reset(new Dataset(bag_data_cnt));
      tmp_subset_->CopyFeatureMapperFrom(train_data_);
      is_use_subset_ = true;
    }
    // flag to not bagging first
    bag_data_cnt_ = num_data_;
  }

102
  data_size_t BaggingHelper(data_size_t start, data_size_t cnt, data_size_t* buffer) override {
103
    if (cnt <= 0) {
104
105
      return 0;
    }
106
    std::vector<score_t> tmp_gradients(cnt, 0.0f);
Guolin Ke's avatar
Guolin Ke committed
107
    for (data_size_t i = 0; i < cnt; ++i) {
108
      for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
109
        size_t idx = static_cast<size_t>(cur_tree_id) * num_data_ + start + i;
Guolin Ke's avatar
Guolin Ke committed
110
111
        tmp_gradients[i] += std::fabs(gradients_[idx] * hessians_[idx]);
      }
Guolin Ke's avatar
Guolin Ke committed
112
    }
Guolin Ke's avatar
Guolin Ke committed
113
114
    data_size_t top_k = static_cast<data_size_t>(cnt * config_->top_rate);
    data_size_t other_k = static_cast<data_size_t>(cnt * config_->other_rate);
Guolin Ke's avatar
Guolin Ke committed
115
    top_k = std::max(1, top_k);
Guolin Ke's avatar
Guolin Ke committed
116
    ArrayArgs<score_t>::ArgMaxAtK(&tmp_gradients, 0, static_cast<int>(tmp_gradients.size()), top_k - 1);
117
    score_t threshold = tmp_gradients[top_k - 1];
Guolin Ke's avatar
Guolin Ke committed
118

119
    score_t multiply = static_cast<score_t>(cnt - top_k) / other_k;
Guolin Ke's avatar
Guolin Ke committed
120
    data_size_t cur_left_cnt = 0;
121
    data_size_t cur_right_pos = cnt;
Guolin Ke's avatar
Guolin Ke committed
122
123
    data_size_t big_weight_cnt = 0;
    for (data_size_t i = 0; i < cnt; ++i) {
124
      auto cur_idx = start + i;
125
      score_t grad = 0.0f;
126
      for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
127
        size_t idx = static_cast<size_t>(cur_tree_id) * num_data_ + cur_idx;
Guolin Ke's avatar
Guolin Ke committed
128
129
130
        grad += std::fabs(gradients_[idx] * hessians_[idx]);
      }
      if (grad >= threshold) {
131
        buffer[cur_left_cnt++] = cur_idx;
Guolin Ke's avatar
Guolin Ke committed
132
133
134
135
136
137
        ++big_weight_cnt;
      } else {
        data_size_t sampled = cur_left_cnt - big_weight_cnt;
        data_size_t rest_need = other_k - sampled;
        data_size_t rest_all = (cnt - i) - (top_k - big_weight_cnt);
        double prob = (rest_need) / static_cast<double>(rest_all);
138
139
        if (bagging_rands_[cur_idx / bagging_rand_block_].NextFloat() < prob) {
          buffer[cur_left_cnt++] = cur_idx;
140
          for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
141
            size_t idx = static_cast<size_t>(cur_tree_id) * num_data_ + cur_idx;
Guolin Ke's avatar
Guolin Ke committed
142
143
144
            gradients_[idx] *= multiply;
            hessians_[idx] *= multiply;
          }
Guolin Ke's avatar
Guolin Ke committed
145
        } else {
146
          buffer[--cur_right_pos] = cur_idx;
Guolin Ke's avatar
Guolin Ke committed
147
148
149
150
151
152
153
154
155
        }
      }
    }
    return cur_left_cnt;
  }

  void Bagging(int iter) override {
    bag_data_cnt_ = num_data_;
    // not subsample for first iterations
Guolin Ke's avatar
Guolin Ke committed
156
    if (iter < static_cast<int>(1.0f / config_->learning_rate)) { return; }
Guolin Ke's avatar
Guolin Ke committed
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
    auto left_cnt = bagging_runner_.Run<true>(
        num_data_,
        [=](int, data_size_t cur_start, data_size_t cur_cnt, data_size_t* left,
            data_size_t*) {
          data_size_t cur_left_count = 0;
          cur_left_count = BaggingHelper(cur_start, cur_cnt, left);
          return cur_left_count;
        },
        bag_data_indices_.data());
    bag_data_cnt_ = left_cnt;
    // set bagging data to tree learner
    if (!is_use_subset_) {
      tree_learner_->SetBaggingData(nullptr, bag_data_indices_.data(), bag_data_cnt_);
    } else {
      // get subset
      tmp_subset_->ReSize(bag_data_cnt_);
      tmp_subset_->CopySubrow(train_data_, bag_data_indices_.data(),
                              bag_data_cnt_, false);
      tree_learner_->SetBaggingData(tmp_subset_.get(), bag_data_indices_.data(),
                                    bag_data_cnt_);
    }
Guolin Ke's avatar
Guolin Ke committed
178
  }
179
180
181
182
183

 protected:
  bool GetIsConstHessian(const ObjectiveFunction*) override {
    return false;
  }
Guolin Ke's avatar
Guolin Ke committed
184
185
186
187
};

}  // namespace LightGBM
#endif   // LIGHTGBM_BOOSTING_GOSS_H_