/*!
 * Copyright (c) 2020 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
#include "linear_tree_learner.h"

#include <Eigen/Dense>

#include <algorithm>

namespace LightGBM {

template <typename TREE_LEARNER_TYPE>
void LinearTreeLearner<TREE_LEARNER_TYPE>::Init(const Dataset* train_data, bool is_constant_hessian) {
  // Run the wrapped tree learner's own initialization first, then allocate the
  // extra state needed for fitting linear models in the leaves.
  TREE_LEARNER_TYPE::Init(train_data, is_constant_hessian);
  LinearTreeLearner::InitLinear(train_data, this->config_->num_leaves);
}

template <typename TREE_LEARNER_TYPE>
void LinearTreeLearner<TREE_LEARNER_TYPE>::InitLinear(const Dataset* train_data, const int max_leaves) {
  // leaf_map_[row] will hold the leaf index each data row falls into (-1 = unassigned).
  leaf_map_ = std::vector<int>(train_data->num_data(), -1);
  contains_nan_ = std::vector<int8_t>(train_data->num_features(), 0);
  // identify features containing nans
#pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static)
  for (int feat = 0; feat < train_data->num_features(); ++feat) {
    auto bin_mapper = this->train_data_->FeatureBinMapper(feat);
    if (bin_mapper->bin_type() == BinType::NumericalBin) {
      // only numerical features participate in the linear models, so only they
      // need the NaN scan
      const float* feat_ptr = this->train_data_->raw_index(feat);
      for (int i = 0; i < train_data->num_data(); ++i) {
        if (std::isnan(feat_ptr[i])) {
          contains_nan_[feat] = 1;
          break;
        }
      }
    }
  }
  // any_nan_ lets Train() choose the cheaper HAS_NAN=false template path when possible
  any_nan_ = false;
  for (int feat = 0; feat < train_data->num_features(); ++feat) {
    if (contains_nan_[feat]) {
      any_nan_ = true;
      break;
    }
  }
  // preallocate the matrix used to calculate linear model coefficients
  int max_num_feat = std::min(max_leaves, this->train_data_->num_numeric_features());
  XTHX_.clear();
  XTg_.clear();
  for (int i = 0; i < max_leaves; ++i) {
    // store only upper triangular half of matrix as an array, in row-major order
    // this requires (max_num_feat + 1) * (max_num_feat + 2) / 2 entries (including the constant terms of the regression)
    // we add another 8 to ensure cache lines are not shared among processors
    XTHX_.push_back(std::vector<double>((max_num_feat + 1) * (max_num_feat + 2) / 2 + 8, 0));
    XTg_.push_back(std::vector<double>(max_num_feat + 9, 0.0));
  }
  // per-thread copies of the accumulators, merged after the parallel pass in CalculateLinear
  XTHX_by_thread_.clear();
  XTg_by_thread_.clear();
  int max_threads = OMP_NUM_THREADS();
  for (int i = 0; i < max_threads; ++i) {
    XTHX_by_thread_.push_back(XTHX_);
    XTg_by_thread_.push_back(XTg_);
  }
}

template <typename TREE_LEARNER_TYPE>
Tree* LinearTreeLearner<TREE_LEARNER_TYPE>::Train(const score_t* gradients, const score_t *hessians, bool is_first_tree) {
  // Grow one tree with the wrapped learner's split machinery, then fit a linear
  // model in each leaf via CalculateLinear.
  // Fix: the timer was labelled "SerialTreeLearner::Train" (copied from the serial
  // learner), which attributed this learner's time to the wrong class in profiles.
  Common::FunctionTimer fun_timer("LinearTreeLearner::Train", global_timer);
  this->gradients_ = gradients;
  this->hessians_ = hessians;
  int num_threads = OMP_NUM_THREADS();
  if (this->share_state_->num_threads != num_threads && this->share_state_->num_threads > 0) {
    Log::Warning(
        "Detected that num_threads changed during training (from %d to %d), "
        "it may cause unexpected errors.",
        this->share_state_->num_threads, num_threads);
  }
  this->share_state_->num_threads = num_threads;

  // some initial works before training
  this->BeforeTrain();

  auto tree = std::unique_ptr<Tree>(new Tree(this->config_->num_leaves, true, true));
  auto tree_ptr = tree.get();
  this->constraints_->ShareTreePointer(tree_ptr);

  // root leaf
  int left_leaf = 0;
  int cur_depth = 1;
  // only root leaf can be splitted on first time
  int right_leaf = -1;

  int init_splits = this->ForceSplits(tree_ptr, &left_leaf, &right_leaf, &cur_depth);

  for (int split = init_splits; split < this->config_->num_leaves - 1; ++split) {
    // some initial works before finding best split
    if (this->BeforeFindBestSplit(tree_ptr, left_leaf, right_leaf)) {
      // find best threshold for every feature
      this->FindBestSplits(tree_ptr);
    }
    // Get a leaf with max split gain
    int best_leaf = static_cast<int>(ArrayArgs<SplitInfo>::ArgMax(this->best_split_per_leaf_));
    // Get split information for best leaf
    const SplitInfo& best_leaf_SplitInfo = this->best_split_per_leaf_[best_leaf];
    // cannot split, quit
    if (best_leaf_SplitInfo.gain <= 0.0) {
      Log::Warning("No further splits with positive gain, best gain: %f", best_leaf_SplitInfo.gain);
      break;
    }
    // split tree with best leaf
    this->Split(tree_ptr, best_leaf, &left_leaf, &right_leaf);
    cur_depth = std::max(cur_depth, tree->leaf_depth(left_leaf));
  }

  // decide which CalculateLinear instantiation to use: the HAS_NAN=true path is
  // only needed if some split feature of this tree actually contains NaNs
  bool has_nan = false;
  if (any_nan_) {
    for (int i = 0; i < tree->num_leaves() - 1 ; ++i) {
      if (contains_nan_[tree_ptr->split_feature_inner(i)]) {
        has_nan = true;
        break;
      }
    }
  }

  GetLeafMap(tree_ptr);

  if (has_nan) {
    CalculateLinear<true>(tree_ptr, false, this->gradients_, this->hessians_, is_first_tree);
  } else {
    CalculateLinear<false>(tree_ptr, false, this->gradients_, this->hessians_, is_first_tree);
  }

  Log::Debug("Trained a tree with leaves = %d and depth = %d", tree->num_leaves(), cur_depth);
  return tree.release();
}

template <typename TREE_LEARNER_TYPE>
Tree* LinearTreeLearner<TREE_LEARNER_TYPE>::FitByExistingTree(const Tree* old_tree, const score_t* gradients, const score_t *hessians) const {
  // Refit an existing tree structure on new gradients/hessians, then refit the
  // per-leaf linear models (is_refit=true blends old and new coefficients).
  auto tree = TREE_LEARNER_TYPE::FitByExistingTree(old_tree, gradients, hessians);
  bool has_nan = false;
  if (any_nan_) {
    for (int i = 0; i < tree->num_leaves() - 1 ; ++i) {
      // split_feature() returns the raw feature index, so map it back to the
      // inner index before consulting contains_nan_
      if (contains_nan_[this->train_data_->InnerFeatureIndex(tree->split_feature(i))]) {
        has_nan = true;
        break;
      }
    }
  }
  GetLeafMap(tree);
  if (has_nan) {
    CalculateLinear<true>(tree, true, gradients, hessians, false);
  } else {
    CalculateLinear<false>(tree, true, gradients, hessians, false);
  }
  return tree;
}

template <typename TREE_LEARNER_TYPE>
Tree* LinearTreeLearner<TREE_LEARNER_TYPE>::FitByExistingTree(const Tree* old_tree, const std::vector<int>& leaf_pred,
                                           const score_t* gradients, const score_t *hessians) const {
  // Variant that first restores the data partition from precomputed per-row leaf
  // predictions, then delegates to the two-argument overload above.
  this->data_partition_->ResetByLeafPred(leaf_pred, old_tree->num_leaves());
  return LinearTreeLearner::FitByExistingTree(old_tree, gradients, hessians);
}

template <typename TREE_LEARNER_TYPE>
void LinearTreeLearner<TREE_LEARNER_TYPE>::GetLeafMap(Tree* tree) const {
  // Rebuild the row -> leaf index mapping from the current data partition.
  // Rows not assigned to any leaf keep the sentinel -1.
  std::fill(leaf_map_.begin(), leaf_map_.end(), -1);
  const data_size_t* data_indices = this->data_partition_->indices();
#pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(dynamic)
  for (int leaf = 0; leaf < tree->num_leaves(); ++leaf) {
    // each leaf owns a contiguous slice [begin, begin + count) of data_indices
    const data_size_t begin = this->data_partition_->leaf_begin(leaf);
    const auto count = this->data_partition_->leaf_count(leaf);
    for (int offset = 0; offset < count; ++offset) {
      leaf_map_[data_indices[begin + offset]] = leaf;
    }
  }
}

template<typename TREE_LEARNER_TYPE>
template <bool HAS_NAN>
void LinearTreeLearner<TREE_LEARNER_TYPE>::CalculateLinear(Tree* tree, bool is_refit, const score_t* gradients, const score_t* hessians, bool is_first_tree) const {
  // Fit a ridge-regularised linear model in every leaf of `tree`.
  // HAS_NAN selects (at compile time) the variant that skips rows containing NaN
  // feature values; is_refit blends new coefficients with the existing ones using
  // refit_decay_rate.
  tree->SetIsLinear(true);
  int num_leaves = tree->num_leaves();
  int num_threads = OMP_NUM_THREADS();
  if (is_first_tree) {
    // the first tree is a constant, so the "linear model" is just the leaf output
    for (int leaf_num = 0; leaf_num < num_leaves; ++leaf_num) {
      tree->SetLeafConst(leaf_num, tree->LeafOutput(leaf_num));
    }
    return;
  }

  // calculate coefficients using the method described in Eq 3 of https://arxiv.org/abs/1802.05640
  // the coefficients vector is given by
  // - (X_T * H * X + lambda) ^ (-1) * (X_T * g)
  // where:
  // X is the matrix where the first column is the feature values and the second is all ones,
  // H is the diagonal matrix of the hessian,
  // lambda is the diagonal matrix with diagonal entries equal to the regularisation term linear_lambda
  // g is the vector of gradients
  // the subscript _T denotes the transpose

  // create array of pointers to raw data, and coefficient matrices, for each leaf
  std::vector<std::vector<int>> leaf_features;
  std::vector<int> leaf_num_features;
  std::vector<std::vector<const float*>> raw_data_ptr;
  size_t max_num_features = 0;
  for (int i = 0; i < num_leaves; ++i) {
    // candidate features: on refit reuse the leaf's stored features, otherwise
    // the (deduplicated) features used on the path from root to this leaf
    std::vector<int> raw_features;
    if (is_refit) {
      raw_features = tree->LeafFeatures(i);
    } else {
      raw_features = tree->branch_features(i);
    }
    std::sort(raw_features.begin(), raw_features.end());
    auto new_end = std::unique(raw_features.begin(), raw_features.end());
    raw_features.erase(new_end, raw_features.end());
    std::vector<int> numerical_features;
    std::vector<const float*> data_ptr;
    for (size_t j = 0; j < raw_features.size(); ++j) {
      int feat = this->train_data_->InnerFeatureIndex(raw_features[j]);
      auto bin_mapper = this->train_data_->FeatureBinMapper(feat);
      // only numerical features enter the regression
      if (bin_mapper->bin_type() == BinType::NumericalBin) {
        numerical_features.push_back(feat);
        data_ptr.push_back(this->train_data_->raw_index(feat));
      }
    }
    leaf_features.push_back(numerical_features);
    raw_data_ptr.push_back(data_ptr);
    leaf_num_features.push_back(static_cast<int>(numerical_features.size()));
    if (numerical_features.size() > max_num_features) {
      max_num_features = numerical_features.size();
    }
  }
  // clear the coefficient matrices
#pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static)
  for (int i = 0; i < num_threads; ++i) {
    for (int leaf_num = 0; leaf_num < num_leaves; ++leaf_num) {
      size_t num_feat = leaf_features[leaf_num].size();
      // only the first (num_feat+1)*(num_feat+2)/2 entries are used this round;
      // the buffers were sized for the worst case in InitLinear
      std::fill(XTHX_by_thread_[i][leaf_num].begin(), XTHX_by_thread_[i][leaf_num].begin() + (num_feat + 1) * (num_feat + 2) / 2, 0.0f);
      std::fill(XTg_by_thread_[i][leaf_num].begin(), XTg_by_thread_[i][leaf_num].begin() + num_feat + 1, 0.0f);
    }
  }
#pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static)
  for (int leaf_num = 0; leaf_num < num_leaves; ++leaf_num) {
    size_t num_feat = leaf_features[leaf_num].size();
    std::fill(XTHX_[leaf_num].begin(), XTHX_[leaf_num].begin() + (num_feat + 1) * (num_feat + 2) / 2, 0.0f);
    std::fill(XTg_[leaf_num].begin(), XTg_[leaf_num].begin() + num_feat + 1, 0.0f);
  }
  // per-thread count of rows with no NaN in any used feature (only needed when HAS_NAN)
  std::vector<std::vector<int>> num_nonzero;
  for (int i = 0; i < num_threads; ++i) {
    if (HAS_NAN) {
      num_nonzero.push_back(std::vector<int>(num_leaves, 0));
    }
  }
  OMP_INIT_EX();
#pragma omp parallel num_threads(OMP_NUM_THREADS()) if (this->num_data_ > 1024)
  {
    // scratch row [feature values..., 1.0 bias term], reused per thread
    std::vector<float> curr_row(max_num_features + 1);
    int tid = omp_get_thread_num();
#pragma omp for schedule(static)
    for (int i = 0; i < this->num_data_; ++i) {
      OMP_LOOP_EX_BEGIN();
      int leaf_num = leaf_map_[i];
      if (leaf_num < 0) {
        continue;
      }
      bool nan_found = false;
      int num_feat = leaf_num_features[leaf_num];
      for (int feat = 0; feat < num_feat; ++feat) {
        if (HAS_NAN) {
          float val = raw_data_ptr[leaf_num][feat][i];
          if (std::isnan(val)) {
            nan_found = true;
            break;
          }
          num_nonzero[tid][leaf_num] += 1;
          curr_row[feat] = val;
        } else {
          curr_row[feat] = raw_data_ptr[leaf_num][feat][i];
        }
      }
      if (HAS_NAN) {
        // rows containing any NaN are excluded from the regression entirely
        if (nan_found) {
          continue;
        }
      }
      curr_row[num_feat] = 1.0;
      float h = static_cast<float>(hessians[i]);
      float g = static_cast<float>(gradients[i]);
      // accumulate the upper triangle of X^T H X (row-major, flattened index j)
      // and X^T g into this thread's private buffers
      int j = 0;
      for (int feat1 = 0; feat1 < num_feat + 1; ++feat1) {
        double f1_val = static_cast<double>(curr_row[feat1]);
        XTg_by_thread_[tid][leaf_num][feat1] += f1_val * g;
        f1_val *= h;
        for (int feat2 = feat1; feat2 < num_feat + 1; ++feat2) {
          XTHX_by_thread_[tid][leaf_num][j] += f1_val * curr_row[feat2];
          ++j;
        }
      }
      OMP_LOOP_EX_END();
    }
  }
  OMP_THROW_EX();
  auto total_nonzero = std::vector<int>(tree->num_leaves());
  // aggregate results from different threads
  for (int tid = 0; tid < num_threads; ++tid) {
#pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static)
    for (int leaf_num = 0; leaf_num < num_leaves; ++leaf_num) {
      size_t num_feat = leaf_features[leaf_num].size();
      for (size_t j = 0; j < (num_feat + 1) * (num_feat + 2) / 2; ++j) {
        XTHX_[leaf_num][j] += XTHX_by_thread_[tid][leaf_num][j];
      }
      for (size_t feat1 = 0; feat1 < num_feat + 1; ++feat1) {
        XTg_[leaf_num][feat1] += XTg_by_thread_[tid][leaf_num][feat1];
      }
      if (HAS_NAN) {
        total_nonzero[leaf_num] += num_nonzero[tid][leaf_num];
      }
    }
  }
  if (!HAS_NAN) {
    // without NaNs every row in the leaf contributed
    for (int leaf_num = 0; leaf_num < num_leaves; ++leaf_num) {
      total_nonzero[leaf_num] = this->data_partition_->leaf_count(leaf_num);
    }
  }
  double shrinkage = tree->shrinkage();
  double decay_rate = this->config_->refit_decay_rate;
  // copy into eigen matrices and solve
#pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static)
  for (int leaf_num = 0; leaf_num < num_leaves; ++leaf_num) {
    if (total_nonzero[leaf_num] < static_cast<int>(leaf_features[leaf_num].size()) + 1) {
      // underdetermined system (fewer usable rows than unknowns): fall back to a
      // constant leaf rather than solving
      if (is_refit) {
        double old_const = tree->LeafConst(leaf_num);
        tree->SetLeafConst(leaf_num, decay_rate * old_const + (1.0 - decay_rate) * tree->LeafOutput(leaf_num) * shrinkage);
        tree->SetLeafCoeffs(leaf_num, std::vector<double>(leaf_features[leaf_num].size(), 0));
        tree->SetLeafFeaturesInner(leaf_num, leaf_features[leaf_num]);
      } else {
        tree->SetLeafConst(leaf_num, tree->LeafOutput(leaf_num));
      }
      continue;
    }
    size_t num_feat = leaf_features[leaf_num].size();
    Eigen::MatrixXd XTHX_mat(num_feat + 1, num_feat + 1);
    Eigen::MatrixXd XTg_mat(num_feat + 1, 1);
    // expand the flattened upper triangle into a full symmetric matrix
    size_t j = 0;
    for (size_t feat1 = 0; feat1 < num_feat + 1; ++feat1) {
      for (size_t feat2 = feat1; feat2 < num_feat + 1; ++feat2) {
        XTHX_mat(feat1, feat2) = XTHX_[leaf_num][j];
        XTHX_mat(feat2, feat1) = XTHX_mat(feat1, feat2);
        if ((feat1 == feat2) && (feat1 < num_feat)) {
          // add the ridge penalty to the diagonal, but not to the bias term
          XTHX_mat(feat1, feat2) += this->config_->linear_lambda;
        }
        ++j;
      }
      XTg_mat(feat1) = XTg_[leaf_num][feat1];
    }
    Eigen::MatrixXd coeffs = - XTHX_mat.fullPivLu().inverse() * XTg_mat;
    std::vector<double> coeffs_vec;
    std::vector<int> features_new;
    std::vector<double> old_coeffs = tree->LeafCoeffs(leaf_num);
    for (size_t i = 0; i < leaf_features[leaf_num].size(); ++i) {
      if (is_refit) {
        // blend old and new coefficients; keep the full feature list on refit
        features_new.push_back(leaf_features[leaf_num][i]);
        coeffs_vec.push_back(decay_rate * old_coeffs[i] + (1.0 - decay_rate) * coeffs(i) * shrinkage);
      } else {
        // drop features whose coefficient is effectively zero
        if (coeffs(i) < -kZeroThreshold || coeffs(i) > kZeroThreshold) {
          coeffs_vec.push_back(coeffs(i));
          int feat = leaf_features[leaf_num][i];
          features_new.push_back(feat);
        }
      }
    }
    // update the tree properties
    tree->SetLeafFeaturesInner(leaf_num, features_new);
    std::vector<int> features_raw(features_new.size());
    for (size_t i = 0; i < features_new.size(); ++i) {
      features_raw[i] = this->train_data_->RealFeatureIndex(features_new[i]);
    }
    tree->SetLeafFeatures(leaf_num, features_raw);
    tree->SetLeafCoeffs(leaf_num, coeffs_vec);
    if (is_refit) {
      double old_const = tree->LeafConst(leaf_num);
      // coeffs(num_feat) is the fitted bias/constant term
      tree->SetLeafConst(leaf_num, decay_rate * old_const + (1.0 - decay_rate) * coeffs(num_feat) * shrinkage);
    } else {
      tree->SetLeafConst(leaf_num, coeffs(num_feat));
    }
  }
}

// Explicit instantiations for the two supported base learners, so the template
// definitions in this translation unit are emitted for linkers of other TUs.
template void LinearTreeLearner<SerialTreeLearner>::Init(const Dataset* train_data, bool is_constant_hessian);
template void LinearTreeLearner<SerialTreeLearner>::InitLinear(const Dataset* train_data, const int max_leaves);
template Tree* LinearTreeLearner<SerialTreeLearner>::Train(const score_t* gradients, const score_t *hessians, bool is_first_tree);
template Tree* LinearTreeLearner<SerialTreeLearner>::FitByExistingTree(const Tree* old_tree, const score_t* gradients, const score_t *hessians) const;
template Tree* LinearTreeLearner<SerialTreeLearner>::FitByExistingTree(const Tree* old_tree, const std::vector<int>& leaf_pred,
                                           const score_t* gradients, const score_t *hessians) const;

template void LinearTreeLearner<GPUTreeLearner>::Init(const Dataset* train_data, bool is_constant_hessian);
template void LinearTreeLearner<GPUTreeLearner>::InitLinear(const Dataset* train_data, const int max_leaves);
template Tree* LinearTreeLearner<GPUTreeLearner>::Train(const score_t* gradients, const score_t *hessians, bool is_first_tree);
template Tree* LinearTreeLearner<GPUTreeLearner>::FitByExistingTree(const Tree* old_tree, const score_t* gradients, const score_t *hessians) const;
template Tree* LinearTreeLearner<GPUTreeLearner>::FitByExistingTree(const Tree* old_tree, const std::vector<int>& leaf_pred,
                                           const score_t* gradients, const score_t *hessians) const;

}  // namespace LightGBM