/*!
 * Copyright (c) 2020 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for
 * license information.
 */
#ifndef LIGHTGBM_SRC_TREELEARNER_COL_SAMPLER_HPP_
#define LIGHTGBM_SRC_TREELEARNER_COL_SAMPLER_HPP_

#include <LightGBM/dataset.h>
#include <LightGBM/meta.h>
#include <LightGBM/utils/common.h>
#include <LightGBM/utils/openmp_wrapper.h>
#include <LightGBM/utils/random.h>

#include <algorithm>
#include <unordered_set>
#include <vector>

namespace LightGBM {
class ColSampler {
 public:
22
  explicit ColSampler(const Config* config)
23
24
25
26
      : fraction_bytree_(config->feature_fraction),
        fraction_bynode_(config->feature_fraction_bynode),
        seed_(config->feature_fraction_seed),
        random_(config->feature_fraction_seed) {
27
28
29
30
    for (auto constraint : config->interaction_constraints_vector) {
      std::unordered_set<int> constraint_set(constraint.begin(), constraint.end());
      interaction_constraints_.push_back(constraint_set);
    }
31
32
33
  }

  static int GetCnt(size_t total_cnt, double fraction) {
34
    const int min = std::min(1, static_cast<int>(total_cnt));
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
    int used_feature_cnt = static_cast<int>(Common::RoundInt(total_cnt * fraction));
    return std::max(used_feature_cnt, min);
  }

  void SetTrainingData(const Dataset* train_data) {
    train_data_ = train_data;
    is_feature_used_.resize(train_data_->num_features(), 1);
    valid_feature_indices_ = train_data->ValidFeatureIndices();
    if (fraction_bytree_ >= 1.0f) {
      need_reset_bytree_ = false;
      used_cnt_bytree_ = static_cast<int>(valid_feature_indices_.size());
    } else {
      need_reset_bytree_ = true;
      used_cnt_bytree_ =
          GetCnt(valid_feature_indices_.size(), fraction_bytree_);
    }
    ResetByTree();
  }

  void SetConfig(const Config* config) {
    fraction_bytree_ = config->feature_fraction;
    fraction_bynode_ = config->feature_fraction_bynode;
    is_feature_used_.resize(train_data_->num_features(), 1);
    // seed is changed
    if (seed_ != config->feature_fraction_seed) {
      seed_ = config->feature_fraction_seed;
      random_ = Random(seed_);
    }
    if (fraction_bytree_ >= 1.0f) {
      need_reset_bytree_ = false;
      used_cnt_bytree_ = static_cast<int>(valid_feature_indices_.size());
    } else {
      need_reset_bytree_ = true;
      used_cnt_bytree_ =
          GetCnt(valid_feature_indices_.size(), fraction_bytree_);
    }
    ResetByTree();
  }

  void ResetByTree() {
    if (need_reset_bytree_) {
      std::memset(is_feature_used_.data(), 0,
                  sizeof(int8_t) * is_feature_used_.size());
      used_feature_indices_ = random_.Sample(
          static_cast<int>(valid_feature_indices_.size()), used_cnt_bytree_);
      int omp_loop_size = static_cast<int>(used_feature_indices_.size());

82
#pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static, 512) if (omp_loop_size >= 1024)
83
84
85
86
87
88
89
90
      for (int i = 0; i < omp_loop_size; ++i) {
        int used_feature = valid_feature_indices_[used_feature_indices_[i]];
        int inner_feature_index = train_data_->InnerFeatureIndex(used_feature);
        is_feature_used_[inner_feature_index] = 1;
      }
    }
  }

91
92
93
94
95
96
97
98
99
100
101
102
  std::vector<int8_t> GetByNode(const Tree* tree, int leaf) {
    // get interaction constraints for current branch
    std::unordered_set<int> allowed_features;
    if (!interaction_constraints_.empty()) {
      std::vector<int> branch_features = tree->branch_features(leaf);
      allowed_features.insert(branch_features.begin(), branch_features.end());
      for (auto constraint : interaction_constraints_) {
        int num_feat_found = 0;
        if (branch_features.size() == 0) {
          allowed_features.insert(constraint.begin(), constraint.end());
        }
        for (int feat : branch_features) {
103
104
105
          if (constraint.count(feat) == 0) {
            break;
          }
106
107
108
109
110
111
112
          ++num_feat_found;
          if (num_feat_found == static_cast<int>(branch_features.size())) {
            allowed_features.insert(constraint.begin(), constraint.end());
            break;
          }
        }
      }
113
    }
114

115
    std::vector<int8_t> ret(train_data_->num_features(), 0);
116
117
118
119
120
121
    if (fraction_bynode_ >= 1.0f) {
      if (interaction_constraints_.empty()) {
        return std::vector<int8_t>(train_data_->num_features(), 1);
      } else {
        for (int feat : allowed_features) {
          int inner_feat = train_data_->InnerFeatureIndex(feat);
122
123
124
          if (inner_feat >= 0) {
            ret[inner_feat] = 1;
          }
125
126
127
128
        }
        return ret;
      }
    }
129
130
    if (need_reset_bytree_) {
      auto used_feature_cnt = GetCnt(used_feature_indices_.size(), fraction_bynode_);
131
132
133
134
135
136
137
138
139
140
141
142
143
      std::vector<int>* allowed_used_feature_indices;
      std::vector<int> filtered_feature_indices;
      if (interaction_constraints_.empty()) {
        allowed_used_feature_indices = &used_feature_indices_;
      } else {
        for (int feat_ind : used_feature_indices_) {
          if (allowed_features.count(valid_feature_indices_[feat_ind]) == 1) {
            filtered_feature_indices.push_back(feat_ind);
          }
        }
        used_feature_cnt = std::min(used_feature_cnt, static_cast<int>(filtered_feature_indices.size()));
        allowed_used_feature_indices = &filtered_feature_indices;
      }
144
      auto sampled_indices = random_.Sample(
145
          static_cast<int>((*allowed_used_feature_indices).size()), used_feature_cnt);
146
      int omp_loop_size = static_cast<int>(sampled_indices.size());
147
#pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static, 512) if (omp_loop_size >= 1024)
148
149
      for (int i = 0; i < omp_loop_size; ++i) {
        int used_feature =
150
            valid_feature_indices_[(*allowed_used_feature_indices)[sampled_indices[i]]];
151
152
153
154
155
156
        int inner_feature_index = train_data_->InnerFeatureIndex(used_feature);
        ret[inner_feature_index] = 1;
      }
    } else {
      auto used_feature_cnt =
          GetCnt(valid_feature_indices_.size(), fraction_bynode_);
157
158
159
160
161
162
163
164
165
166
167
168
169
      std::vector<int>* allowed_valid_feature_indices;
      std::vector<int> filtered_feature_indices;
      if (interaction_constraints_.empty()) {
        allowed_valid_feature_indices = &valid_feature_indices_;
      } else {
        for (int feat : valid_feature_indices_) {
          if (allowed_features.count(feat) == 1) {
            filtered_feature_indices.push_back(feat);
          }
        }
        allowed_valid_feature_indices = &filtered_feature_indices;
        used_feature_cnt = std::min(used_feature_cnt, static_cast<int>(filtered_feature_indices.size()));
      }
170
      auto sampled_indices = random_.Sample(
171
          static_cast<int>((*allowed_valid_feature_indices).size()), used_feature_cnt);
172
      int omp_loop_size = static_cast<int>(sampled_indices.size());
173
#pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static, 512) if (omp_loop_size >= 1024)
174
      for (int i = 0; i < omp_loop_size; ++i) {
175
        int used_feature = (*allowed_valid_feature_indices)[sampled_indices[i]];
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
        int inner_feature_index = train_data_->InnerFeatureIndex(used_feature);
        ret[inner_feature_index] = 1;
      }
    }
    return ret;
  }

  const std::vector<int8_t>& is_feature_used_bytree() const {
    return is_feature_used_;
  }

  void SetIsFeatureUsedByTree(int fid, bool val) {
    is_feature_used_[fid] = val;
  }

 private:
  const Dataset* train_data_;
  double fraction_bytree_;
  double fraction_bynode_;
  bool need_reset_bytree_;
  int used_cnt_bytree_;
  int seed_;
  Random random_;
  std::vector<int8_t> is_feature_used_;
  std::vector<int> used_feature_indices_;
  std::vector<int> valid_feature_indices_;
202
203
  /*! \brief interaction constraints index in original (raw data) features */
  std::vector<std::unordered_set<int>> interaction_constraints_;
204
205
206
};

}  // namespace LightGBM
#endif  // LIGHTGBM_SRC_TREELEARNER_COL_SAMPLER_HPP_