tree.h 21.2 KB
Newer Older
1
2
3
4
/*!
 * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
Guolin Ke's avatar
Guolin Ke committed
5
6
7
#ifndef LIGHTGBM_TREE_H_
#define LIGHTGBM_TREE_H_

8
9
10
#include <LightGBM/dataset.h>
#include <LightGBM/meta.h>

Guolin Ke's avatar
Guolin Ke committed
11
#include <string>
12
#include <map>
13
14
15
#include <memory>
#include <unordered_map>
#include <vector>
Guolin Ke's avatar
Guolin Ke committed
16
17
18

namespace LightGBM {

Guolin Ke's avatar
Guolin Ke committed
19
20
#define kCategoricalMask (1)
#define kDefaultLeftMask (2)
Guolin Ke's avatar
Guolin Ke committed
21
22
23
24
25

/*!
* \brief Tree model
*/
class Tree {
Nikita Titov's avatar
Nikita Titov committed
26
 public:
Guolin Ke's avatar
Guolin Ke committed
27
28
29
  /*!
  * \brief Constructor
  * \param max_leaves The number of max leaves
30
  * \param track_branch_features Whether to keep track of ancestors of leaf nodes
Guolin Ke's avatar
Guolin Ke committed
31
  */
32
  explicit Tree(int max_leaves, bool track_branch_features);
Guolin Ke's avatar
Guolin Ke committed
33
34

  /*!
35
  * \brief Constructor, from a string
Guolin Ke's avatar
Guolin Ke committed
36
  * \param str Model string
37
  * \param used_len used count of str
Guolin Ke's avatar
Guolin Ke committed
38
  */
39
  Tree(const char* str, size_t* used_len);
Guolin Ke's avatar
Guolin Ke committed
40
41
42
43

  ~Tree();

  /*!
Qiwei Ye's avatar
Qiwei Ye committed
44
45
46
  * \brief Performing a split on tree leaves.
  * \param leaf Index of leaf to be split
  * \param feature Index of feature; the converted index after removing useless features
Guolin Ke's avatar
Guolin Ke committed
47
  * \param real_feature Index of feature, the original index on data
48
  * \param threshold_bin Threshold(bin) of split
49
  * \param threshold_double Threshold on feature value
Guolin Ke's avatar
Guolin Ke committed
50
51
  * \param left_value Model Left child output
  * \param right_value Model Right child output
Guolin Ke's avatar
Guolin Ke committed
52
53
  * \param left_cnt Count of left child
  * \param right_cnt Count of right child
54
55
  * \param left_weight Weight of left child
  * \param right_weight Weight of right child
Guolin Ke's avatar
Guolin Ke committed
56
  * \param gain Split gain
Guolin Ke's avatar
Guolin Ke committed
57
58
  * \param missing_type missing type
  * \param default_left default direction for missing value
Guolin Ke's avatar
Guolin Ke committed
59
60
  * \return The index of new leaf.
  */
61
62
  int Split(int leaf, int feature, int real_feature, uint32_t threshold_bin,
            double threshold_double, double left_value, double right_value,
63
64
            int left_cnt, int right_cnt, double left_weight, double right_weight,
            float gain, MissingType missing_type, bool default_left);
Guolin Ke's avatar
Guolin Ke committed
65

66
67
68
69
70
71
72
  /*!
  * \brief Performing a split on tree leaves, with categorical feature
  * \param leaf Index of leaf to be split
  * \param feature Index of feature; the converted index after removing useless features
  * \param real_feature Index of feature, the original index on data
  * \param threshold_bin Threshold(bin) of split, use bitset to represent
  * \param num_threshold_bin size of threshold_bin
73
74
  * \param threshold Thresholds of real feature value, use bitset to represent
  * \param num_threshold size of threshold
75
76
77
78
  * \param left_value Model Left child output
  * \param right_value Model Right child output
  * \param left_cnt Count of left child
  * \param right_cnt Count of right child
79
80
  * \param left_weight Weight of left child
  * \param right_weight Weight of right child
81
82
83
  * \param gain Split gain
  * \param missing_type missing type
  * \return The index of new leaf.
  */
84
85
  int SplitCategorical(int leaf, int feature, int real_feature, const uint32_t* threshold_bin, int num_threshold_bin,
                       const uint32_t* threshold, int num_threshold, double left_value, double right_value,
86
                       int left_cnt, int right_cnt, double left_weight, double right_weight, float gain, MissingType missing_type);
87

Guolin Ke's avatar
Guolin Ke committed
88
  /*! \brief Get the output of one leaf */
89
  // Unchecked vector access: the caller must pass a valid leaf index.
  inline double LeafOutput(int leaf) const { return leaf_value_[leaf]; }
Guolin Ke's avatar
Guolin Ke committed
90

Guolin Ke's avatar
Guolin Ke committed
91
92
  /*! \brief Set the output of one leaf */
  inline void SetLeafOutput(int leaf, double output) {
93
    leaf_value_[leaf] = MaybeRoundToZero(output);
Guolin Ke's avatar
Guolin Ke committed
94
95
  }

Guolin Ke's avatar
Guolin Ke committed
96
  /*!
Qiwei Ye's avatar
Qiwei Ye committed
97
  * \brief Adding prediction value of this tree model to scores
Guolin Ke's avatar
Guolin Ke committed
98
99
100
101
  * \param data The dataset
  * \param num_data Number of total data
  * \param score Will add prediction to score
  */
102
103
104
  void AddPredictionToScore(const Dataset* data,
                            data_size_t num_data,
                            double* score) const;
Guolin Ke's avatar
Guolin Ke committed
105
106

  /*!
107
  * \brief Adding prediction value of this tree model to scores
Guolin Ke's avatar
Guolin Ke committed
108
109
110
111
112
113
  * \param data The dataset
  * \param used_data_indices Indices of used data
  * \param num_data Number of total data
  * \param score Will add prediction to score
  */
  void AddPredictionToScore(const Dataset* data,
Qiwei Ye's avatar
Qiwei Ye committed
114
                            const data_size_t* used_data_indices,
115
                            data_size_t num_data, double* score) const;
Guolin Ke's avatar
Guolin Ke committed
116

117
118
119
120
121
122
123
124
125
126
  /*!
  * \brief Get upper bound leaf value of this tree model
  */
  double GetUpperBoundValue() const;

  /*!
  * \brief Get lower bound leaf value of this tree model
  */
  double GetLowerBoundValue() const;

Guolin Ke's avatar
Guolin Ke committed
127
  /*!
128
  * \brief Prediction on one record
Guolin Ke's avatar
Guolin Ke committed
129
130
131
  * \param feature_values Feature value of this record
  * \return Prediction result
  */
132
  inline double Predict(const double* feature_values) const;
133
  inline double PredictByMap(const std::unordered_map<int, double>& feature_values) const;
134

135
  inline int PredictLeafIndex(const double* feature_values) const;
136
137
  inline int PredictLeafIndexByMap(const std::unordered_map<int, double>& feature_values) const;

138
  inline void PredictContrib(const double* feature_values, int num_features, double* output);
139
140
  inline void PredictContribByMap(const std::unordered_map<int, double>& feature_values,
                                  int num_features, std::unordered_map<int, double>* output);
141

Guolin Ke's avatar
Guolin Ke committed
142
143
144
  /*! \brief Get Number of leaves*/
  // Returns the current leaf count (num_leaves_), not the max_leaves_ capacity.
  inline int num_leaves() const { return num_leaves_; }

Guolin Ke's avatar
Guolin Ke committed
145
146
147
  /*! \brief Get depth of specific leaf*/
  // Unchecked vector access: leaf_idx must be a valid index into leaf_depth_.
  inline int leaf_depth(int leaf_idx) const { return leaf_depth_[leaf_idx]; }

Belinda Trotta's avatar
Belinda Trotta committed
148
149
150
  /*! \brief Get parent of specific leaf*/
  // The result is an internal-node index; Split() treats a negative value as "no parent".
  inline int leaf_parent(int leaf_idx) const {return leaf_parent_[leaf_idx]; }

wxchan's avatar
wxchan committed
151
  /*! \brief Get feature of specific split*/
Guolin Ke's avatar
Guolin Ke committed
152
  // Returns the original feature index; split_feature_inner() returns the converted one.
  inline int split_feature(int split_idx) const { return split_feature_[split_idx]; }
wxchan's avatar
wxchan committed
153

154
155
156
  /*! \brief Get features on leaf's branch*/
  // Returns a copy of the per-leaf vector.
  // NOTE(review): Split() only fills branch_features_ when track_branch_features_ is set;
  // presumably this must not be called otherwise - confirm against the constructor.
  inline std::vector<int> branch_features(int leaf) const { return branch_features_[leaf]; }

Guolin Ke's avatar
Guolin Ke committed
157
158
  /*! \brief Get gain of specific split (stored as float, widened to double on return)*/
  inline double split_gain(int split_idx) const { return split_gain_[split_idx]; }

159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
  /*! \brief Get output value of specific non-leaf node*/
  inline double internal_value(int node_idx) const {
    return internal_value_[node_idx];
  }

  /*! \brief Whether the split of a non-leaf node is numerical, i.e. its kCategoricalMask bit is not set*/
  inline bool IsNumericalSplit(int node_idx) const {
    return !GetDecisionType(decision_type_[node_idx], kCategoricalMask);
  }

  /*! \brief Get left child of a non-leaf node; a negative value is a leaf stored as ~leaf_index*/
  inline int left_child(int node_idx) const { return left_child_[node_idx]; }

  /*! \brief Get right child of a non-leaf node; a negative value is a leaf stored as ~leaf_index*/
  inline int right_child(int node_idx) const { return right_child_[node_idx]; }

  /*! \brief Get split feature of a non-leaf node, as the converted (inner) feature index*/
  inline int split_feature_inner(int node_idx) const {
    return split_feature_inner_[node_idx];
  }

  /*! \brief Get split threshold of a non-leaf node, expressed in bin space*/
  inline uint32_t threshold_in_bin(int node_idx) const {
    return threshold_in_bin_[node_idx];
  }

179
  /*! \brief Get the number of data points that fall at or below this node*/
Guolin Ke's avatar
Guolin Ke committed
180
  // node >= 0 reads the internal-node count; a negative node is a leaf encoded as ~leaf_index.
  inline int data_count(int node) const { return node >= 0 ? internal_count_[node] : leaf_count_[~node]; }
181

Guolin Ke's avatar
Guolin Ke committed
182
183
  /*!
  * \brief Shrinkage for the tree's output
184
  *        shrinkage rate (a.k.a learning rate) is used to tune the training process
Guolin Ke's avatar
Guolin Ke committed
185
186
  * \param rate The factor of shrinkage
  */
187
  inline void Shrinkage(double rate) {
188
189
190
191
#pragma omp parallel for schedule(static, 1024) if (num_leaves_ >= 2048)
    for (int i = 0; i < num_leaves_ - 1; ++i) {
      leaf_value_[i] = MaybeRoundToZero(leaf_value_[i] * rate);
      internal_value_[i] = MaybeRoundToZero(internal_value_[i] * rate);
Guolin Ke's avatar
Guolin Ke committed
192
    }
193
194
    leaf_value_[num_leaves_ - 1] =
        MaybeRoundToZero(leaf_value_[num_leaves_ - 1] * rate);
Guolin Ke's avatar
Guolin Ke committed
195
    shrinkage_ *= rate;
Guolin Ke's avatar
Guolin Ke committed
196
197
  }

198
  /*! \brief Get the accumulated shrinkage rate of this tree (updated by Shrinkage(), reset to 1 by AddBias() and AsConstantTree())*/
  inline double shrinkage() const { return shrinkage_; }
199

Guolin Ke's avatar
Guolin Ke committed
200
  inline void AddBias(double val) {
201
202
203
204
#pragma omp parallel for schedule(static, 1024) if (num_leaves_ >= 2048)
    for (int i = 0; i < num_leaves_ - 1; ++i) {
      leaf_value_[i] = MaybeRoundToZero(leaf_value_[i] + val);
      internal_value_[i] = MaybeRoundToZero(internal_value_[i] + val);
Guolin Ke's avatar
Guolin Ke committed
205
    }
206
207
    leaf_value_[num_leaves_ - 1] =
        MaybeRoundToZero(leaf_value_[num_leaves_ - 1] + val);
Guolin Ke's avatar
Guolin Ke committed
208
209
210
211
    // force to 1.0
    shrinkage_ = 1.0f;
  }

212
213
214
215
216
217
  /*! \brief Reduce this tree to a single leaf that outputs val (stored as-is, without zero rounding) */
  inline void AsConstantTree(double val) {
    leaf_value_[0] = val;
    shrinkage_ = 1.0f;
    num_leaves_ = 1;
  }

wxchan's avatar
wxchan committed
218
  /*! \brief Serialize this object to string*/
Guolin Ke's avatar
Guolin Ke committed
219
  std::string ToString() const;
Guolin Ke's avatar
Guolin Ke committed
220

wxchan's avatar
wxchan committed
221
  /*! \brief Serialize this object to json*/
Guolin Ke's avatar
Guolin Ke committed
222
  std::string ToJSON() const;
wxchan's avatar
wxchan committed
223

224
  /*! \brief Serialize this object to if-else statement*/
Guolin Ke's avatar
Guolin Ke committed
225
  std::string ToIfElse(int index, bool predict_leaf_index) const;
226

Guolin Ke's avatar
Guolin Ke committed
227
  inline static bool IsZero(double fval) {
228
    return (fval >= -kZeroThreshold && fval <= kZeroThreshold);
Guolin Ke's avatar
Guolin Ke committed
229
230
  }

231
  inline static double MaybeRoundToZero(double fval) {
232
    return IsZero(fval) ? 0 : fval;
233
234
  }

Guolin Ke's avatar
Guolin Ke committed
235
236
237
238
239
240
241
  /*! \brief Test whether the flag bit(s) selected by mask are set in decision_type */
  inline static bool GetDecisionType(int8_t decision_type, int8_t mask) {
    const int selected_bits = decision_type & mask;
    return selected_bits > 0;
  }

  /*! \brief Set (input == true) or clear (input == false) the flag bit(s) selected by mask */
  inline static void SetDecisionType(int8_t* decision_type, bool input, int8_t mask) {
    int8_t& flags = *decision_type;
    if (!input) {
      // 127 - mask keeps every low-7-bit flag except the masked one(s).
      flags &= (127 - mask);
      return;
    }
    flags |= mask;
  }

Guolin Ke's avatar
Guolin Ke committed
247
248
249
250
251
252
253
254
255
  /*! \brief Extract the 2-bit missing-type field stored in bits 2..3 of decision_type */
  inline static int8_t GetMissingType(int8_t decision_type) {
    const int without_flag_bits = decision_type >> 2;
    return static_cast<int8_t>(without_flag_bits & 3);
  }

  /*! \brief Overwrite the missing-type field: keep only the two low flag bits, then store input shifted into bits 2 and up */
  inline static void SetMissingType(int8_t* decision_type, int8_t input) {
    const int8_t kept_flags = static_cast<int8_t>((*decision_type) & 3);
    *decision_type = static_cast<int8_t>(kept_flags | (input << 2));
  }

256
257
  void RecomputeMaxDepth();

258
  /*! \brief Get the current leaf count, which is the index the inline Split() overload assigns to the newly created leaf */
  int NextLeafId() const { return num_leaves_; }
259

Nikita Titov's avatar
Nikita Titov committed
260
 private:
Guolin Ke's avatar
Guolin Ke committed
261
  std::string NumericalDecisionIfElse(int node) const;
Guolin Ke's avatar
Guolin Ke committed
262

Guolin Ke's avatar
Guolin Ke committed
263
  std::string CategoricalDecisionIfElse(int node) const;
264
265
266

  inline int NumericalDecision(double fval, int node) const {
    uint8_t missing_type = GetMissingType(decision_type_[node]);
267
268
    if (std::isnan(fval) && missing_type != MissingType::NaN) {
      fval = 0.0f;
Guolin Ke's avatar
Guolin Ke committed
269
    }
270
271
    if ((missing_type == MissingType::Zero && IsZero(fval))
        || (missing_type == MissingType::NaN && std::isnan(fval))) {
272
273
      if (GetDecisionType(decision_type_[node], kDefaultLeftMask)) {
        return left_child_[node];
Guolin Ke's avatar
Guolin Ke committed
274
      } else {
275
        return right_child_[node];
Guolin Ke's avatar
Guolin Ke committed
276
277
      }
    }
278
279
280
281
282
    if (fval <= threshold_[node]) {
      return left_child_[node];
    } else {
      return right_child_[node];
    }
Guolin Ke's avatar
Guolin Ke committed
283
  }
Guolin Ke's avatar
Guolin Ke committed
284

285
286
  inline int NumericalDecisionInner(uint32_t fval, int node, uint32_t default_bin, uint32_t max_bin) const {
    uint8_t missing_type = GetMissingType(decision_type_[node]);
287
288
    if ((missing_type == MissingType::Zero && fval == default_bin)
        || (missing_type == MissingType::NaN && fval == max_bin)) {
289
290
291
292
293
294
295
296
      if (GetDecisionType(decision_type_[node], kDefaultLeftMask)) {
        return left_child_[node];
      } else {
        return right_child_[node];
      }
    }
    if (fval <= threshold_in_bin_[node]) {
      return left_child_[node];
297
    } else {
298
      return right_child_[node];
299
300
    }
  }
Guolin Ke's avatar
Guolin Ke committed
301

302
303
304
305
306
307
308
  inline int CategoricalDecision(double fval, int node) const {
    uint8_t missing_type = GetMissingType(decision_type_[node]);
    int int_fval = static_cast<int>(fval);
    if (int_fval < 0) {
      return right_child_[node];;
    } else if (std::isnan(fval)) {
      // NaN is always in the right
309
      if (missing_type == MissingType::NaN) {
310
311
312
313
        return right_child_[node];
      }
      int_fval = 0;
    }
314
    int cat_idx = static_cast<int>(threshold_[node]);
315
316
    if (Common::FindInBitset(cat_threshold_.data() + cat_boundaries_[cat_idx],
                             cat_boundaries_[cat_idx + 1] - cat_boundaries_[cat_idx], int_fval)) {
317
318
319
320
      return left_child_[node];
    }
    return right_child_[node];
  }
Guolin Ke's avatar
Guolin Ke committed
321

322
  inline int CategoricalDecisionInner(uint32_t fval, int node) const {
323
    int cat_idx = static_cast<int>(threshold_in_bin_[node]);
324
325
    if (Common::FindInBitset(cat_threshold_inner_.data() + cat_boundaries_inner_[cat_idx],
                             cat_boundaries_inner_[cat_idx + 1] - cat_boundaries_inner_[cat_idx], fval)) {
326
327
328
329
      return left_child_[node];
    }
    return right_child_[node];
  }
Guolin Ke's avatar
Guolin Ke committed
330

331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
  /*! \brief Dispatch a raw feature value to the categorical or numerical decision of a node */
  inline int Decision(double fval, int node) const {
    const bool is_categorical = GetDecisionType(decision_type_[node], kCategoricalMask);
    return is_categorical ? CategoricalDecision(fval, node)
                          : NumericalDecision(fval, node);
  }

  /*! \brief Dispatch a binned feature value to the categorical or numerical decision of a node */
  inline int DecisionInner(uint32_t fval, int node, uint32_t default_bin, uint32_t max_bin) const {
    const bool is_categorical = GetDecisionType(decision_type_[node], kCategoricalMask);
    return is_categorical ? CategoricalDecisionInner(fval, node)
                          : NumericalDecisionInner(fval, node, default_bin, max_bin);
  }

347
348
  inline void Split(int leaf, int feature, int real_feature, double left_value, double right_value, int left_cnt, int right_cnt,
                    double left_weight, double right_weight, float gain);
Guolin Ke's avatar
Guolin Ke committed
349
  /*!
Qiwei Ye's avatar
Qiwei Ye committed
350
  * \brief Find the index of the leaf that a record falls into, based on its feature values
Guolin Ke's avatar
Guolin Ke committed
351
352
353
  * \param feature_values Feature value of this record
  * \return Leaf index
  */
354
  inline int GetLeaf(const double* feature_values) const;
355
  inline int GetLeafByMap(const std::unordered_map<int, double>& feature_values) const;
Guolin Ke's avatar
Guolin Ke committed
356

wxchan's avatar
wxchan committed
357
  /*! \brief Serialize one node to json*/
Guolin Ke's avatar
Guolin Ke committed
358
  std::string NodeToJSON(int index) const;
wxchan's avatar
wxchan committed
359

360
  /*! \brief Serialize one node to if-else statement*/
Guolin Ke's avatar
Guolin Ke committed
361
  std::string NodeToIfElse(int index, bool predict_leaf_index) const;
Guolin Ke's avatar
Guolin Ke committed
362

Guolin Ke's avatar
Guolin Ke committed
363
  std::string NodeToIfElseByMap(int index, bool predict_leaf_index) const;
364

365
  double ExpectedValue() const;
Guolin Ke's avatar
Guolin Ke committed
366

367
368
  /*! \brief This is used to fill in leaf_depth_ after reloading a model*/
  inline void RecomputeLeafDepths(int node = 0, int depth = 0);
Guolin Ke's avatar
Guolin Ke committed
369
370
371
372
373
374
375
376
377
378

  /*!
  * \brief Used by TreeSHAP for data we keep about our decision path
  */
  struct PathElement {
    /*! \brief Feature index this path element refers to */
    int feature_index;
    double zero_fraction;
    double one_fraction;
    // pweight is kept here for convenience only and is not tied to the other attributes:
    // the pweight of the i'th path element is the permutation weight of paths with i-1 ones in them
    double pweight;

    // Intentionally leaves all members uninitialized.
    PathElement() {}

    PathElement(int i, double z, double o, double w)
        : feature_index(i),
          zero_fraction(z),
          one_fraction(o),
          pweight(w) {}
  };

386
  /*! \brief Polynomial time algorithm for SHAP values (arXiv:1706.06060)*/
Guolin Ke's avatar
Guolin Ke committed
387
388
389
390
  void TreeSHAP(const double *feature_values, double *phi,
                int node, int unique_depth,
                PathElement *parent_unique_path, double parent_zero_fraction,
                double parent_one_fraction, int parent_feature_index) const;
391

392
393
394
395
396
397
  void TreeSHAPByMap(const std::unordered_map<int, double>& feature_values,
                     std::unordered_map<int, double>* phi,
                     int node, int unique_depth,
                     PathElement *parent_unique_path, double parent_zero_fraction,
                     double parent_one_fraction, int parent_feature_index) const;

398
  /*! \brief Extend our decision path with a fraction of one and zero extensions for TreeSHAP*/
Guolin Ke's avatar
Guolin Ke committed
399
400
  static void ExtendPath(PathElement *unique_path, int unique_depth,
                         double zero_fraction, double one_fraction, int feature_index);
401
402

  /*! \brief Undo a previous extension of the decision path for TreeSHAP*/
Guolin Ke's avatar
Guolin Ke committed
403
  static void UnwindPath(PathElement *unique_path, int unique_depth, int path_index);
404

405
  /*! \brief Determine what the total permutation weight would be if we unwound a previous extension in the decision path*/
Guolin Ke's avatar
Guolin Ke committed
406
  static double UnwoundPathSum(const PathElement *unique_path, int unique_depth, int path_index);
407

Guolin Ke's avatar
Guolin Ke committed
408
409
  /*! \brief Number of max leaves*/
  int max_leaves_;
410
  /*! \brief Number of current leaves*/
Guolin Ke's avatar
Guolin Ke committed
411
412
413
  int num_leaves_;
  // following values used for non-leaf node
  /*! \brief A non-leaf node's left child */
Guolin Ke's avatar
Guolin Ke committed
414
  std::vector<int> left_child_;
Guolin Ke's avatar
Guolin Ke committed
415
  /*! \brief A non-leaf node's right child */
Guolin Ke's avatar
Guolin Ke committed
416
  std::vector<int> right_child_;
Guolin Ke's avatar
Guolin Ke committed
417
  /*! \brief A non-leaf node's split feature */
Guolin Ke's avatar
Guolin Ke committed
418
  std::vector<int> split_feature_inner_;
Guolin Ke's avatar
Guolin Ke committed
419
  /*! \brief A non-leaf node's split feature, the original index */
Guolin Ke's avatar
Guolin Ke committed
420
  std::vector<int> split_feature_;
Guolin Ke's avatar
Guolin Ke committed
421
  /*! \brief A non-leaf node's split threshold in bin */
Guolin Ke's avatar
Guolin Ke committed
422
  std::vector<uint32_t> threshold_in_bin_;
Guolin Ke's avatar
Guolin Ke committed
423
  /*! \brief A non-leaf node's split threshold in feature value */
Guolin Ke's avatar
Guolin Ke committed
424
  std::vector<double> threshold_;
425
  int num_cat_;
426
427
428
429
  std::vector<int> cat_boundaries_inner_;
  std::vector<uint32_t> cat_threshold_inner_;
  std::vector<int> cat_boundaries_;
  std::vector<uint32_t> cat_threshold_;
430
  /*! \brief Store the information for categorical feature handle and missing value handle. */
431
  std::vector<int8_t> decision_type_;
Guolin Ke's avatar
Guolin Ke committed
432
  /*! \brief A non-leaf node's split gain */
433
  std::vector<float> split_gain_;
Guolin Ke's avatar
Guolin Ke committed
434
435
  // used for leaf node
  /*! \brief The parent of leaf */
Guolin Ke's avatar
Guolin Ke committed
436
  std::vector<int> leaf_parent_;
Guolin Ke's avatar
Guolin Ke committed
437
  /*! \brief Output of leaves */
Guolin Ke's avatar
Guolin Ke committed
438
  std::vector<double> leaf_value_;
439
440
  /*! \brief weight of leaves */
  std::vector<double> leaf_weight_;
Guolin Ke's avatar
Guolin Ke committed
441
  /*! \brief DataCount of leaves */
442
  std::vector<int> leaf_count_;
Guolin Ke's avatar
Guolin Ke committed
443
444
  /*! \brief Output of non-leaf nodes */
  std::vector<double> internal_value_;
445
446
  /*! \brief weight of non-leaf nodes */
  std::vector<double> internal_weight_;
Guolin Ke's avatar
Guolin Ke committed
447
  /*! \brief DataCount of non-leaf nodes */
448
  std::vector<int> internal_count_;
Guolin Ke's avatar
Guolin Ke committed
449
  /*! \brief Depth for leaves */
Guolin Ke's avatar
Guolin Ke committed
450
  std::vector<int> leaf_depth_;
451
452
453
454
  /*! \brief whether to keep track of ancestor nodes for each leaf (only needed when feature interactions are restricted) */
  bool track_branch_features_;
  /*! \brief Features on leaf's branch, original index */
  std::vector<std::vector<int>> branch_features_;
Guolin Ke's avatar
Guolin Ke committed
455
  double shrinkage_;
456
  int max_depth_;
Guolin Ke's avatar
Guolin Ke committed
457
458
};

459
inline void Tree::Split(int leaf, int feature, int real_feature,
460
                        double left_value, double right_value, int left_cnt, int right_cnt,
461
                        double left_weight, double right_weight, float gain) {
462
463
464
465
466
467
468
469
470
471
472
473
474
475
  int new_node_idx = num_leaves_ - 1;
  // update parent info
  int parent = leaf_parent_[leaf];
  if (parent >= 0) {
    // if cur node is left child
    if (left_child_[parent] == ~leaf) {
      left_child_[parent] = new_node_idx;
    } else {
      right_child_[parent] = new_node_idx;
    }
  }
  // add new node
  split_feature_inner_[new_node_idx] = feature;
  split_feature_[new_node_idx] = real_feature;
Guolin Ke's avatar
Guolin Ke committed
476
  split_gain_[new_node_idx] = gain;
477
478
479
480
481
482
483
  // add two new leaves
  left_child_[new_node_idx] = ~leaf;
  right_child_[new_node_idx] = ~num_leaves_;
  // update new leaves
  leaf_parent_[leaf] = new_node_idx;
  leaf_parent_[num_leaves_] = new_node_idx;
  // save current leaf value to internal node before change
484
  internal_weight_[new_node_idx] = leaf_weight_[leaf];
485
486
487
  internal_value_[new_node_idx] = leaf_value_[leaf];
  internal_count_[new_node_idx] = left_cnt + right_cnt;
  leaf_value_[leaf] = std::isnan(left_value) ? 0.0f : left_value;
488
  leaf_weight_[leaf] = left_weight;
489
490
  leaf_count_[leaf] = left_cnt;
  leaf_value_[num_leaves_] = std::isnan(right_value) ? 0.0f : right_value;
491
  leaf_weight_[num_leaves_] = right_weight;
492
493
494
495
  leaf_count_[num_leaves_] = right_cnt;
  // update leaf depth
  leaf_depth_[num_leaves_] = leaf_depth_[leaf] + 1;
  leaf_depth_[leaf]++;
496
497
498
499
500
  if (track_branch_features_) {
    branch_features_[num_leaves_] = branch_features_[leaf];
    branch_features_[num_leaves_].push_back(split_feature_[new_node_idx]);
    branch_features_[leaf].push_back(split_feature_[new_node_idx]);
  }
501
502
}

503
inline double Tree::Predict(const double* feature_values) const {
Guolin Ke's avatar
Guolin Ke committed
504
505
506
507
  if (num_leaves_ > 1) {
    int leaf = GetLeaf(feature_values);
    return LeafOutput(leaf);
  } else {
508
    return leaf_value_[0];
Guolin Ke's avatar
Guolin Ke committed
509
  }
Guolin Ke's avatar
Guolin Ke committed
510
511
}

512
513
514
515
516
517
518
519
520
inline double Tree::PredictByMap(const std::unordered_map<int, double>& feature_values) const {
  // A tree without splits has a single leaf; return its output directly.
  if (num_leaves_ <= 1) {
    return leaf_value_[0];
  }
  return LeafOutput(GetLeafByMap(feature_values));
}

521
inline int Tree::PredictLeafIndex(const double* feature_values) const {
Guolin Ke's avatar
Guolin Ke committed
522
523
524
525
526
527
  if (num_leaves_ > 1) {
    int leaf = GetLeaf(feature_values);
    return leaf;
  } else {
    return 0;
  }
wxchan's avatar
wxchan committed
528
529
}

530
531
532
533
534
535
536
537
538
inline int Tree::PredictLeafIndexByMap(const std::unordered_map<int, double>& feature_values) const {
  // With a single leaf there is nothing to traverse: the answer is leaf 0.
  return (num_leaves_ > 1) ? GetLeafByMap(feature_values) : 0;
}

539
540
inline void Tree::PredictContrib(const double* feature_values, int num_features, double* output) {
  output[num_features] += ExpectedValue();
541
  // Run the recursion with preallocated space for the unique path data
542
  if (num_leaves_ > 1) {
543
    CHECK_GE(max_depth_, 0);
544
545
546
    const int max_path_len = max_depth_ + 1;
    std::vector<PathElement> unique_path_data(max_path_len*(max_path_len + 1) / 2);
    TreeSHAP(feature_values, output, 0, 0, unique_path_data.data(), 1, 1, -1);
547
548
549
  }
}

550
551
552
553
554
555
556
557
558
559
560
561
inline void Tree::PredictContribByMap(const std::unordered_map<int, double>& feature_values,
                                      int num_features, std::unordered_map<int, double>* output) {
  // The entry keyed by num_features accumulates the tree's expected value.
  (*output)[num_features] += ExpectedValue();
  if (num_leaves_ <= 1) {
    return;
  }
  CHECK_GE(max_depth_, 0);
  // Preallocate the unique-path storage for the whole recursion:
  // max_path_len * (max_path_len + 1) / 2 elements.
  const int max_path_len = max_depth_ + 1;
  const int path_slots = max_path_len * (max_path_len + 1) / 2;
  std::vector<PathElement> unique_path_data(path_slots);
  TreeSHAPByMap(feature_values, output, 0, 0, unique_path_data.data(), 1, 1, -1);
}

562
563
564
565
566
inline void Tree::RecomputeLeafDepths(int node, int depth) {
  if (node == 0) leaf_depth_.resize(num_leaves());
  if (node < 0) {
    leaf_depth_[~node] = depth;
  } else {
Guolin Ke's avatar
Guolin Ke committed
567
568
    RecomputeLeafDepths(left_child_[node], depth + 1);
    RecomputeLeafDepths(right_child_[node], depth + 1);
569
  }
570
571
}

572
inline int Tree::GetLeaf(const double* feature_values) const {
Guolin Ke's avatar
Guolin Ke committed
573
  int node = 0;
574
  if (num_cat_ > 0) {
Guolin Ke's avatar
Guolin Ke committed
575
    while (node >= 0) {
576
      node = Decision(feature_values[split_feature_[node]], node);
Guolin Ke's avatar
Guolin Ke committed
577
578
579
    }
  } else {
    while (node >= 0) {
580
      node = NumericalDecision(feature_values[split_feature_[node]], node);
Guolin Ke's avatar
Guolin Ke committed
581
582
583
584
585
    }
  }
  return ~node;
}

586
587
588
589
590
591
592
593
594
595
596
597
598
599
inline int Tree::GetLeafByMap(const std::unordered_map<int, double>& feature_values) const {
  // Look a feature up once with find() instead of count() followed by at();
  // features absent from the map are treated as 0.
  const auto value_or_zero = [&feature_values](int feature_idx) -> double {
    const auto it = feature_values.find(feature_idx);
    return it != feature_values.end() ? it->second : 0.0;
  };
  // Walk down from the root (node 0); a negative value is a leaf stored as ~leaf_index.
  int node = 0;
  if (num_cat_ > 0) {
    while (node >= 0) {
      node = Decision(value_or_zero(split_feature_[node]), node);
    }
  } else {
    // No categorical split in this tree: call the numerical decision directly.
    while (node >= 0) {
      node = NumericalDecision(value_or_zero(split_feature_[node]), node);
    }
  }
  return ~node;
}

Guolin Ke's avatar
Guolin Ke committed
600
601
}  // namespace LightGBM

Guolin Ke's avatar
Guolin Ke committed
602
#endif   // LIGHTGBM_TREE_H_