/*!
 * Copyright (c) 2019 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for
 * license information.
 */
#ifndef LIGHTGBM_TREELEARNER_COST_EFFECTIVE_GRADIENT_BOOSTING_HPP_
#define LIGHTGBM_TREELEARNER_COST_EFFECTIVE_GRADIENT_BOOSTING_HPP_

#include <LightGBM/config.h>
#include <LightGBM/dataset.h>
#include <LightGBM/utils/common.h>
#include <LightGBM/utils/log.h>

#include <vector>

#include "data_partition.hpp"
#include "serial_tree_learner.h"
#include "split_info.hpp"

namespace LightGBM {

class CostEfficientGradientBoosting {
 public:
24
25
  explicit CostEfficientGradientBoosting(const SerialTreeLearner* tree_learner)
      : init_(false), tree_learner_(tree_learner) {}
26
  static bool IsEnable(const Config* config) {
27
28
29
    if (config->cegb_tradeoff >= 1.0f && config->cegb_penalty_split <= 0.0f &&
        config->cegb_penalty_feature_coupled.empty() &&
        config->cegb_penalty_feature_lazy.empty()) {
30
31
32
33
34
35
36
      return false;
    } else {
      return true;
    }
  }
  void Init() {
    auto train_data = tree_learner_->train_data_;
37
38
39
40
41
42
43
    if (!init_) {
      splits_per_leaf_.resize(
          static_cast<size_t>(tree_learner_->config_->num_leaves) *
          train_data->num_features());
      is_feature_used_in_split_.clear();
      is_feature_used_in_split_.resize(train_data->num_features());
    }
44

45
46
47
48
49
50
    if (!tree_learner_->config_->cegb_penalty_feature_coupled.empty() &&
        tree_learner_->config_->cegb_penalty_feature_coupled.size() !=
            static_cast<size_t>(train_data->num_total_features())) {
      Log::Fatal(
          "cegb_penalty_feature_coupled should be the same size as feature "
          "number.");
51
52
    }
    if (!tree_learner_->config_->cegb_penalty_feature_lazy.empty()) {
53
54
55
56
57
58
59
60
61
      if (tree_learner_->config_->cegb_penalty_feature_lazy.size() !=
          static_cast<size_t>(train_data->num_total_features())) {
        Log::Fatal(
            "cegb_penalty_feature_lazy should be the same size as feature "
            "number.");
      }
      if (!init_) {
        feature_used_in_data_ = Common::EmptyBitset(train_data->num_features() *
                                                    tree_learner_->num_data_);
62
63
      }
    }
64
    init_ = true;
65
  }
66
67
  double DetlaGain(int feature_index, int real_fidx, int leaf_index,
                   int num_data_in_leaf, SplitInfo split_info) {
68
    auto config = tree_learner_->config_;
69
70
71
72
73
74
    double delta =
        config->cegb_tradeoff * config->cegb_penalty_split * num_data_in_leaf;
    if (!config->cegb_penalty_feature_coupled.empty() &&
        !is_feature_used_in_split_[feature_index]) {
      delta += config->cegb_tradeoff *
               config->cegb_penalty_feature_coupled[real_fidx];
75
76
    }
    if (!config->cegb_penalty_feature_lazy.empty()) {
77
78
      delta += config->cegb_tradeoff *
               CalculateOndemandCosts(feature_index, real_fidx, leaf_index);
79
    }
80
81
82
    splits_per_leaf_[static_cast<size_t>(leaf_index) *
                         tree_learner_->train_data_->num_features() +
                     feature_index] = split_info;
83
84
    return delta;
  }
85
86
87
  void UpdateLeafBestSplits(Tree* tree, int best_leaf,
                            const SplitInfo* best_split_info,
                            std::vector<SplitInfo>* best_split_per_leaf) {
88
89
    auto config = tree_learner_->config_;
    auto train_data = tree_learner_->train_data_;
90
91
    const int inner_feature_index =
        train_data->InnerFeatureIndex(best_split_info->feature);
Guolin Ke's avatar
Guolin Ke committed
92
    auto& ref_best_split_per_leaf = *best_split_per_leaf;
93
94
    if (!config->cegb_penalty_feature_coupled.empty() &&
        !is_feature_used_in_split_[inner_feature_index]) {
95
96
97
      is_feature_used_in_split_[inner_feature_index] = true;
      for (int i = 0; i < tree->num_leaves(); ++i) {
        if (i == best_leaf) continue;
98
99
100
101
102
103
        auto split = &splits_per_leaf_[static_cast<size_t>(i) *
                                           train_data->num_features() +
                                       inner_feature_index];
        split->gain +=
            config->cegb_tradeoff *
            config->cegb_penalty_feature_coupled[best_split_info->feature];
Guolin Ke's avatar
Guolin Ke committed
104
105
106
        // Avoid to update the leaf that cannot split
        if (ref_best_split_per_leaf[i].gain > kMinScore &&
            *split > ref_best_split_per_leaf[i]) {
Guolin Ke's avatar
Guolin Ke committed
107
          ref_best_split_per_leaf[i] = *split;
Guolin Ke's avatar
Guolin Ke committed
108
        }
109
110
111
112
      }
    }
    if (!config->cegb_penalty_feature_lazy.empty()) {
      data_size_t cnt_leaf_data = 0;
113
114
      auto tmp_idx = tree_learner_->data_partition_->GetIndexOnLeaf(
          best_leaf, &cnt_leaf_data);
115
116
      for (data_size_t i_input = 0; i_input < cnt_leaf_data; ++i_input) {
        int real_idx = tmp_idx[i_input];
117
118
119
        Common::InsertBitset(
            &feature_used_in_data_,
            train_data->num_data() * inner_feature_index + real_idx);
120
121
122
123
124
      }
    }
  }

 private:
125
126
  double CalculateOndemandCosts(int feature_index, int real_fidx,
                                int leaf_index) const {
127
128
129
130
    if (tree_learner_->config_->cegb_penalty_feature_lazy.empty()) {
      return 0.0f;
    }
    auto train_data = tree_learner_->train_data_;
131
132
    double penalty =
        tree_learner_->config_->cegb_penalty_feature_lazy[real_fidx];
133
134
135

    double total = 0.0f;
    data_size_t cnt_leaf_data = 0;
136
137
    auto tmp_idx = tree_learner_->data_partition_->GetIndexOnLeaf(
        leaf_index, &cnt_leaf_data);
138
139
140

    for (data_size_t i_input = 0; i_input < cnt_leaf_data; ++i_input) {
      int real_idx = tmp_idx[i_input];
141
142
143
144
      if (Common::FindInBitset(
              feature_used_in_data_.data(),
              train_data->num_data() * train_data->num_features(),
              train_data->num_data() * feature_index + real_idx)) {
145
146
147
148
149
150
        continue;
      }
      total += penalty;
    }
    return total;
  }
151
  bool init_;
152
153
154
155
156
157
158
159
160
  const SerialTreeLearner* tree_learner_;
  std::vector<SplitInfo> splits_per_leaf_;
  std::vector<bool> is_feature_used_in_split_;
  std::vector<uint32_t> feature_used_in_data_;
};

}  // namespace LightGBM

#endif  // LIGHTGBM_TREELEARNER_COST_EFFECTIVE_GRADIENT_BOOSTING_HPP_