/*!
 * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for
 * license information.
 */
#ifndef LIGHTGBM_TREELEARNER_FEATURE_HISTOGRAM_HPP_
#define LIGHTGBM_TREELEARNER_FEATURE_HISTOGRAM_HPP_

#include <LightGBM/bin.h>
#include <LightGBM/dataset.h>
#include <LightGBM/utils/array_args.h>

#include <algorithm>
#include <cmath>
#include <cstring>
#include <memory>
#include <utility>
#include <vector>

#include "monotone_constraints.hpp"
#include "split_info.hpp"

namespace LightGBM {

class FeatureMetainfo {
 public:
  int num_bin;
  MissingType missing_type;
  int8_t offset = 0;
  uint32_t default_bin;
  int8_t monotone_type = 0;
  double penalty = 1.0;
  /*! \brief pointer to tree config */
  const Config* config;
  BinType bin_type;
  /*! \brief random number generator for extremely randomized trees */
  mutable Random rand;
};
/*!
 * \brief FeatureHistogram is used to construct and store a histogram for a
 * feature.
 */
class FeatureHistogram {
 public:
  FeatureHistogram() { data_ = nullptr; }

  ~FeatureHistogram() {}

  /*! \brief Disable copy */
  FeatureHistogram& operator=(const FeatureHistogram&) = delete;
  /*! \brief Disable copy */
  FeatureHistogram(const FeatureHistogram&) = delete;

  /*!
   * \brief Init the feature histogram
   * \param data where histogram data is stored
   * \param data_int16 where 16-bit histogram data is stored
   * \param meta meta information about this feature
   */
  void Init(hist_t* data, int16_t* data_int16, const FeatureMetainfo* meta) {
    meta_ = meta;
    data_ = data;
    data_int16_ = data_int16;
    ResetFunc();
  }

  /*!
   * \brief Init the feature histogram
   * \param data where histogram data is stored
   * \param meta meta information about this feature
   */
  void Init(hist_t* data, const FeatureMetainfo* meta) {
    meta_ = meta;
    data_ = data;
    data_int16_ = nullptr;
    ResetFunc();
  }

  void ResetFunc() {
    if (meta_->bin_type == BinType::NumericalBin) {
      FuncForNumrical();
    } else {
      FuncForCategorical();
    }
  }

  hist_t* RawData() { return data_; }

  int32_t* RawDataInt32() { return reinterpret_cast<int32_t*>(data_); }

  int16_t* RawDataInt16() { return data_int16_; }

  /*!
   * \brief Subtract another histogram from this one
   * \param other The histogram to subtract
   */
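  // Layout note (inferred from the casts below, not documented upstream): in
  // the distributed/quantized path each histogram entry packs the gradient
  // and hessian sums into one integer. A 32-bit entry keeps the signed
  // gradient sum in its high 16 bits and the hessian sum in its low 16 bits;
  // a 64-bit entry uses a 32/32 split. Widening a 32-bit entry therefore
  // looks like:
  //   const int16_t g = static_cast<int16_t>(packed32 >> 16);
  //   const uint16_t h = static_cast<uint16_t>(packed32 & 0xffff);
  //   const int64_t packed64 =
  //       (static_cast<int64_t>(g) << 32) | static_cast<int64_t>(h);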
  template <bool USE_DIST_GRAD = false,
    typename THIS_HIST_T = hist_t, typename OTHER_HIST_T = hist_t, typename RESULT_HIST_T = hist_t,
    int THIS_HIST_BITS = 0, int OTHER_HIST_BITS = 0, int RESULT_HIST_BITS = 0>
  void Subtract(const FeatureHistogram& other, const int32_t* buffer = nullptr) {
    if (USE_DIST_GRAD) {
      const THIS_HIST_T* this_int_data = THIS_HIST_BITS == 16 ?
        reinterpret_cast<const THIS_HIST_T*>(data_int16_) :
        (RESULT_HIST_BITS == 16 ?
          reinterpret_cast<const THIS_HIST_T*>(buffer) :
          reinterpret_cast<const THIS_HIST_T*>(data_));
      const OTHER_HIST_T* other_int_data = OTHER_HIST_BITS == 16 ?
        reinterpret_cast<OTHER_HIST_T*>(other.data_int16_) :
        reinterpret_cast<OTHER_HIST_T*>(other.data_);
      RESULT_HIST_T* result_int_data = RESULT_HIST_BITS == 16 ?
        reinterpret_cast<RESULT_HIST_T*>(data_int16_) :
        reinterpret_cast<RESULT_HIST_T*>(data_);
      if (THIS_HIST_BITS == 32 && OTHER_HIST_BITS == 16 && RESULT_HIST_BITS == 32) {
        for (int i = 0; i < meta_->num_bin - meta_->offset; ++i) {
          const int32_t other_grad_hess = static_cast<int32_t>(other_int_data[i]);
          const int64_t this_grad_hess = this_int_data[i];
          const int64_t other_grad_hess_int64 =
            (static_cast<int64_t>(static_cast<int16_t>(other_grad_hess >> 16)) << 32) |
            (static_cast<int64_t>(other_grad_hess & 0x0000ffff));
          const int64_t result_grad_hess = this_grad_hess - other_grad_hess_int64;
          result_int_data[i] = result_grad_hess;
        }
      } else if (THIS_HIST_BITS == 32 && OTHER_HIST_BITS == 16 && RESULT_HIST_BITS == 16) {
        for (int i = 0; i < meta_->num_bin - meta_->offset; ++i) {
          const int32_t other_grad_hess = static_cast<int32_t>(other_int_data[i]);
          const int64_t this_grad_hess = this_int_data[i];
          const int64_t other_grad_hess_int64 =
            (static_cast<int64_t>(static_cast<int16_t>(other_grad_hess >> 16)) << 32) |
            (static_cast<int64_t>(other_grad_hess & 0x0000ffff));
          const int64_t result_grad_hess = this_grad_hess - other_grad_hess_int64;
          const int32_t result_grad_hess_int32 =
            (static_cast<int32_t>(result_grad_hess >> 32) << 16) |
            static_cast<int32_t>(result_grad_hess & 0x00000000ffffffff);
          result_int_data[i] = result_grad_hess_int32;
        }
      } else {
        for (int i = 0; i < meta_->num_bin - meta_->offset; ++i) {
          result_int_data[i] = this_int_data[i] - other_int_data[i];
        }
      }
    } else {
      for (int i = 0; i < (meta_->num_bin - meta_->offset) * 2; ++i) {
        data_[i] -= other.data_[i];
      }
    }
  }

  void CopyToBuffer(int32_t* buffer) {
    const int64_t* data_ptr = reinterpret_cast<const int64_t*>(data_);
    int64_t* buffer_ptr = reinterpret_cast<int64_t*>(buffer);
    for (int i = 0; i < meta_->num_bin - meta_->offset; ++i) {
      buffer_ptr[i] = data_ptr[i];
    }
  }

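  // Widens each packed 16-bit-per-field entry (read as one int32) into a
  // packed 32-bit-per-field entry (written as one int64); see the layout
  // note above Subtract().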
  void CopyFromInt16ToInt32(char* buffer) {
    const int32_t* int16_data = reinterpret_cast<const int32_t*>(RawDataInt16());
    int64_t* int32_data = reinterpret_cast<int64_t*>(buffer);
    for (int i = 0; i < meta_->num_bin - meta_->offset; ++i) {
      const int32_t int16_val = int16_data[i];
      int32_data[i] = (static_cast<int64_t>(static_cast<int16_t>(int16_val >> 16)) << 32) |
        static_cast<int64_t>(int16_val & 0x0000ffff);
    }
  }

  void FindBestThreshold(double sum_gradient, double sum_hessian,
                         data_size_t num_data,
                         const FeatureConstraint* constraints,
                         double parent_output,
                         SplitInfo* output) {
    output->default_left = true;
    output->gain = kMinScore;
    find_best_threshold_fun_(sum_gradient, sum_hessian + 2 * kEpsilon, num_data,
                             constraints, parent_output, output);
    output->gain *= meta_->penalty;
  }
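  // Note: 2 * kEpsilon is added to sum_hessian before the search, presumably
  // so the hessian total stays positive relative to the per-bin kEpsilon
  // adjustments used below (an inference; not documented in the original).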

  void FindBestThresholdInt(int64_t sum_gradient_and_hessian,
                            double grad_scale, double hess_scale,
                            const uint8_t num_bits_bin,
                            const uint8_t num_bits_acc,
                            data_size_t num_data,
                            const FeatureConstraint* constraints,
                            double parent_output,
                            SplitInfo* output) {
    output->default_left = true;
    output->gain = kMinScore;
    int_find_best_threshold_fun_(sum_gradient_and_hessian, grad_scale, hess_scale,
                                 num_bits_bin, num_bits_acc, num_data,
                                 constraints, parent_output, output);
    output->gain *= meta_->penalty;
  }

  template <bool USE_RAND, bool USE_L1, bool USE_MAX_OUTPUT, bool USE_SMOOTHING>
  double BeforeNumerical(double sum_gradient, double sum_hessian, double parent_output, data_size_t num_data,
                         SplitInfo* output, int* rand_threshold) {
    is_splittable_ = false;
    output->monotone_type = meta_->monotone_type;

    double gain_shift = GetLeafGain<USE_L1, USE_MAX_OUTPUT, USE_SMOOTHING>(
        sum_gradient, sum_hessian, meta_->config->lambda_l1, meta_->config->lambda_l2,
        meta_->config->max_delta_step, meta_->config->path_smooth, num_data, parent_output);
    *rand_threshold = 0;
    if (USE_RAND) {
      if (meta_->num_bin - 2 > 0) {
        *rand_threshold = meta_->rand.NextInt(0, meta_->num_bin - 2);
      }
    }
    return gain_shift + meta_->config->min_gain_to_split;
  }

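  // In the quantized path the leaf's sums arrive packed in one int64_t: the
  // signed gradient sum in the high 32 bits and the unsigned hessian sum in
  // the low 32 bits, each converted back to double via its scale factor.
  // Sketch of the unpacking performed below:
  //   int32_t g = static_cast<int32_t>(sum_gradient_and_hessian >> 32);
  //   uint32_t h = static_cast<uint32_t>(sum_gradient_and_hessian & 0xffffffff);
  //   double sum_gradient = g * grad_scale;
  //   double sum_hessian = h * hess_scale;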
  template <bool USE_RAND, bool USE_L1, bool USE_MAX_OUTPUT, bool USE_SMOOTHING>
  double BeforeNumericalInt(int64_t sum_gradient_and_hessian, double grad_scale, double hess_scale, double parent_output, data_size_t num_data,
                        SplitInfo* output, int* rand_threshold) {
    is_splittable_ = false;
    output->monotone_type = meta_->monotone_type;
    const int32_t int_sum_gradient = static_cast<int32_t>(sum_gradient_and_hessian >> 32);
    const uint32_t int_sum_hessian = static_cast<uint32_t>(sum_gradient_and_hessian & 0x00000000ffffffff);
    const double sum_gradient = static_cast<double>(int_sum_gradient) * grad_scale;
    const double sum_hessian = static_cast<double>(int_sum_hessian) * hess_scale;
    double gain_shift = GetLeafGain<USE_L1, USE_MAX_OUTPUT, USE_SMOOTHING>(
        sum_gradient, sum_hessian, meta_->config->lambda_l1, meta_->config->lambda_l2,
        meta_->config->max_delta_step, meta_->config->path_smooth, num_data, parent_output);
    *rand_threshold = 0;
    if (USE_RAND) {
      if (meta_->num_bin - 2 > 0) {
        *rand_threshold = meta_->rand.NextInt(0, meta_->num_bin - 2);
      }
    }
    return gain_shift + meta_->config->min_gain_to_split;
  }

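  // The FuncForNumrical* chain converts run-time configuration (extra_trees,
  // monotone constraints, lambda_l1, max_delta_step, path_smooth) into
  // compile-time template flags, so each split-search loop is instantiated
  // without per-bin branching. For example, lambda_l1 > 0 with
  // max_delta_step == 0 selects FuncForNumricalL2<USE_RAND, USE_MC, true, false>().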
  void FuncForNumrical() {
    if (meta_->config->extra_trees) {
      if (meta_->config->monotone_constraints.empty()) {
        FuncForNumricalL1<true, false>();
      } else {
        FuncForNumricalL1<true, true>();
      }
    } else {
      if (meta_->config->monotone_constraints.empty()) {
        FuncForNumricalL1<false, false>();
      } else {
        FuncForNumricalL1<false, true>();
      }
    }
  }
  template <bool USE_RAND, bool USE_MC>
  void FuncForNumricalL1() {
    if (meta_->config->lambda_l1 > 0) {
      if (meta_->config->max_delta_step > 0) {
        FuncForNumricalL2<USE_RAND, USE_MC, true, true>();
      } else {
        FuncForNumricalL2<USE_RAND, USE_MC, true, false>();
      }
    } else {
      if (meta_->config->max_delta_step > 0) {
        FuncForNumricalL2<USE_RAND, USE_MC, false, true>();
      } else {
        FuncForNumricalL2<USE_RAND, USE_MC, false, false>();
      }
    }
  }

  template <bool USE_RAND, bool USE_MC, bool USE_L1, bool USE_MAX_OUTPUT>
  void FuncForNumricalL2() {
    if (meta_->config->path_smooth > kEpsilon) {
      FuncForNumricalL3<USE_RAND, USE_MC, USE_L1, USE_MAX_OUTPUT, true>();
    } else {
      FuncForNumricalL3<USE_RAND, USE_MC, USE_L1, USE_MAX_OUTPUT, false>();
    }
  }

  template <bool USE_RAND, bool USE_MC, bool USE_L1, bool USE_MAX_OUTPUT, bool USE_SMOOTHING>
  void FuncForNumricalL3() {
    if (meta_->config->use_quantized_grad) {
#define TEMPLATE_PREFIX_INT USE_RAND, USE_MC, USE_L1, USE_MAX_OUTPUT, USE_SMOOTHING
#define LAMBDA_ARGUMENTS_INT                                         \
  int64_t sum_gradient_and_hessian, double grad_scale, double hess_scale, const uint8_t hist_bits_bin, const uint8_t hist_bits_acc, data_size_t num_data, \
      const FeatureConstraint* constraints, double parent_output, SplitInfo *output
#define BEFORE_ARGUMENTS_INT sum_gradient_and_hessian, grad_scale, hess_scale, parent_output, num_data, output, &rand_threshold
#define FUNC_ARGUMENTS_INT                                                      \
  sum_gradient_and_hessian, grad_scale, hess_scale, num_data, constraints, min_gain_shift, \
      output, rand_threshold, parent_output

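      // Each lambda below dispatches to a FindBestThresholdSequentiallyInt
      // instantiation whose integer types match the bit widths requested at
      // call time: 16-bit bins with a 16-bit accumulator, 32/32, or mixed
      // 16-bit bins with a 32-bit accumulator.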
      if (meta_->num_bin > 2 && meta_->missing_type != MissingType::None) {
        if (meta_->missing_type == MissingType::Zero) {
          int_find_best_threshold_fun_ = [=](LAMBDA_ARGUMENTS_INT) {
            int rand_threshold = 0;
            double min_gain_shift =
                BeforeNumericalInt<USE_RAND, USE_L1, USE_MAX_OUTPUT, USE_SMOOTHING>(
                    BEFORE_ARGUMENTS_INT);
            if (hist_bits_acc <= 16) {
              CHECK_LE(hist_bits_bin, 16);
              FindBestThresholdSequentiallyInt<TEMPLATE_PREFIX_INT, true, true, false, int32_t, int32_t, int16_t, int16_t, 16, 16>(
                  FUNC_ARGUMENTS_INT);
              FindBestThresholdSequentiallyInt<TEMPLATE_PREFIX_INT, false, true, false, int32_t, int32_t, int16_t, int16_t, 16, 16>(
                  FUNC_ARGUMENTS_INT);
            } else {
              if (hist_bits_bin == 32) {
                FindBestThresholdSequentiallyInt<TEMPLATE_PREFIX_INT, true, true, false, int64_t, int64_t, int32_t, int32_t, 32, 32>(
                    FUNC_ARGUMENTS_INT);
                FindBestThresholdSequentiallyInt<TEMPLATE_PREFIX_INT, false, true, false, int64_t, int64_t, int32_t, int32_t, 32, 32>(
                    FUNC_ARGUMENTS_INT);
              } else {
                FindBestThresholdSequentiallyInt<TEMPLATE_PREFIX_INT, true, true, false, int32_t, int64_t, int16_t, int32_t, 16, 32>(
                    FUNC_ARGUMENTS_INT);
                FindBestThresholdSequentiallyInt<TEMPLATE_PREFIX_INT, false, true, false, int32_t, int64_t, int16_t, int32_t, 16, 32>(
                    FUNC_ARGUMENTS_INT);
              }
            }
          };
        } else {
          int_find_best_threshold_fun_ = [=](LAMBDA_ARGUMENTS_INT) {
            int rand_threshold = 0;
            double min_gain_shift =
                BeforeNumericalInt<USE_RAND, USE_L1, USE_MAX_OUTPUT, USE_SMOOTHING>(
                    BEFORE_ARGUMENTS_INT);
            if (hist_bits_acc <= 16) {
              CHECK_LE(hist_bits_bin, 16);
              FindBestThresholdSequentiallyInt<TEMPLATE_PREFIX_INT, true, false, true, int32_t, int32_t, int16_t, int16_t, 16, 16>(
                  FUNC_ARGUMENTS_INT);
              FindBestThresholdSequentiallyInt<TEMPLATE_PREFIX_INT, false, false, true, int32_t, int32_t, int16_t, int16_t, 16, 16>(
                  FUNC_ARGUMENTS_INT);
            } else {
              if (hist_bits_bin == 32) {
                FindBestThresholdSequentiallyInt<TEMPLATE_PREFIX_INT, true, false, true, int64_t, int64_t, int32_t, int32_t, 32, 32>(
                    FUNC_ARGUMENTS_INT);
                FindBestThresholdSequentiallyInt<TEMPLATE_PREFIX_INT, false, false, true, int64_t, int64_t, int32_t, int32_t, 32, 32>(
                    FUNC_ARGUMENTS_INT);
              } else {
                FindBestThresholdSequentiallyInt<TEMPLATE_PREFIX_INT, true, false, true, int32_t, int64_t, int16_t, int32_t, 16, 32>(
                    FUNC_ARGUMENTS_INT);
                FindBestThresholdSequentiallyInt<TEMPLATE_PREFIX_INT, false, false, true, int32_t, int64_t, int16_t, int32_t, 16, 32>(
                    FUNC_ARGUMENTS_INT);
              }
            }
          };
        }
      } else {
        if (meta_->missing_type != MissingType::NaN) {
          int_find_best_threshold_fun_ = [=](LAMBDA_ARGUMENTS_INT) {
            int rand_threshold = 0;
            double min_gain_shift =
                BeforeNumericalInt<USE_RAND, USE_L1, USE_MAX_OUTPUT, USE_SMOOTHING>(
                    BEFORE_ARGUMENTS_INT);
            if (hist_bits_acc <= 16) {
              CHECK_LE(hist_bits_bin, 16);
              FindBestThresholdSequentiallyInt<TEMPLATE_PREFIX_INT, true, false, false, int32_t, int32_t, int16_t, int16_t, 16, 16>(
                  FUNC_ARGUMENTS_INT);
            } else {
              if (hist_bits_bin == 32) {
                FindBestThresholdSequentiallyInt<TEMPLATE_PREFIX_INT, true, false, false, int64_t, int64_t, int32_t, int32_t, 32, 32>(
                    FUNC_ARGUMENTS_INT);
              } else {
                FindBestThresholdSequentiallyInt<TEMPLATE_PREFIX_INT, true, false, false, int32_t, int64_t, int16_t, int32_t, 16, 32>(
                    FUNC_ARGUMENTS_INT);
              }
            }
          };
        } else {
          int_find_best_threshold_fun_ = [=](LAMBDA_ARGUMENTS_INT) {
            int rand_threshold = 0;
            double min_gain_shift =
                BeforeNumericalInt<USE_RAND, USE_L1, USE_MAX_OUTPUT, USE_SMOOTHING>(
                    BEFORE_ARGUMENTS_INT);
            if (hist_bits_acc <= 16) {
              CHECK_LE(hist_bits_bin, 16);
              FindBestThresholdSequentiallyInt<TEMPLATE_PREFIX_INT, true, false, false, int32_t, int32_t, int16_t, int16_t, 16, 16>(
                  FUNC_ARGUMENTS_INT);
            } else {
              if (hist_bits_bin == 32) {
                FindBestThresholdSequentiallyInt<TEMPLATE_PREFIX_INT, true, false, false, int64_t, int64_t, int32_t, int32_t, 32, 32>(
                    FUNC_ARGUMENTS_INT);
              } else {
                FindBestThresholdSequentiallyInt<TEMPLATE_PREFIX_INT, true, false, false, int32_t, int64_t, int16_t, int32_t, 16, 32>(
                    FUNC_ARGUMENTS_INT);
              }
            }
            output->default_left = false;
          };
        }
      }
#undef TEMPLATE_PREFIX_INT
#undef LAMBDA_ARGUMENTS_INT
#undef BEFORE_ARGUMENTS_INT
#undef FUNC_ARGUMENTS_INT
    } else {
#define TEMPLATE_PREFIX USE_RAND, USE_MC, USE_L1, USE_MAX_OUTPUT, USE_SMOOTHING
#define LAMBDA_ARGUMENTS                                         \
  double sum_gradient, double sum_hessian, data_size_t num_data, \
      const FeatureConstraint* constraints, double parent_output, SplitInfo *output
#define BEFORE_ARGUMENTS sum_gradient, sum_hessian, parent_output, num_data, output, &rand_threshold
#define FUNC_ARGUMENTS                                                      \
  sum_gradient, sum_hessian, num_data, constraints, min_gain_shift, \
      output, rand_threshold, parent_output

      if (meta_->num_bin > 2 && meta_->missing_type != MissingType::None) {
        if (meta_->missing_type == MissingType::Zero) {
          find_best_threshold_fun_ = [=](LAMBDA_ARGUMENTS) {
            int rand_threshold = 0;
            double min_gain_shift =
                BeforeNumerical<USE_RAND, USE_L1, USE_MAX_OUTPUT, USE_SMOOTHING>(
                    BEFORE_ARGUMENTS);
            FindBestThresholdSequentially<TEMPLATE_PREFIX, true, true, false>(
                FUNC_ARGUMENTS);
            FindBestThresholdSequentially<TEMPLATE_PREFIX, false, true, false>(
                FUNC_ARGUMENTS);
          };
        } else {
          find_best_threshold_fun_ = [=](LAMBDA_ARGUMENTS) {
            int rand_threshold = 0;
            double min_gain_shift =
                BeforeNumerical<USE_RAND, USE_L1, USE_MAX_OUTPUT, USE_SMOOTHING>(
                    BEFORE_ARGUMENTS);
            FindBestThresholdSequentially<TEMPLATE_PREFIX, true, false, true>(
                FUNC_ARGUMENTS);
            FindBestThresholdSequentially<TEMPLATE_PREFIX, false, false, true>(
                FUNC_ARGUMENTS);
          };
        }
      } else {
        if (meta_->missing_type != MissingType::NaN) {
          find_best_threshold_fun_ = [=](LAMBDA_ARGUMENTS) {
            int rand_threshold = 0;
            double min_gain_shift =
                BeforeNumerical<USE_RAND, USE_L1, USE_MAX_OUTPUT, USE_SMOOTHING>(
                    BEFORE_ARGUMENTS);
            FindBestThresholdSequentially<TEMPLATE_PREFIX, true, false, false>(
                FUNC_ARGUMENTS);
          };
        } else {
          find_best_threshold_fun_ = [=](LAMBDA_ARGUMENTS) {
            int rand_threshold = 0;
            double min_gain_shift =
                BeforeNumerical<USE_RAND, USE_L1, USE_MAX_OUTPUT, USE_SMOOTHING>(
                    BEFORE_ARGUMENTS);
            FindBestThresholdSequentially<TEMPLATE_PREFIX, true, false, false>(
                FUNC_ARGUMENTS);
            output->default_left = false;
          };
        }
      }
#undef TEMPLATE_PREFIX
#undef LAMBDA_ARGUMENTS
#undef BEFORE_ARGUMENTS
#undef FUNC_ARGUMENTS
    }
  }

  void FuncForCategorical() {
    if (meta_->config->extra_trees) {
      if (meta_->config->monotone_constraints.empty()) {
        FuncForCategoricalL1<true, false>();
      } else {
        FuncForCategoricalL1<true, true>();
      }
    } else {
      if (meta_->config->monotone_constraints.empty()) {
        FuncForCategoricalL1<false, false>();
      } else {
        FuncForCategoricalL1<false, true>();
      }
    }
  }

  template <bool USE_RAND, bool USE_MC>
  void FuncForCategoricalL1() {
    if (meta_->config->path_smooth > kEpsilon) {
      FuncForCategoricalL2<USE_RAND, USE_MC, true>();
    } else {
      FuncForCategoricalL2<USE_RAND, USE_MC, false>();
    }
  }

  template <bool USE_RAND, bool USE_MC, bool USE_SMOOTHING>
  void FuncForCategoricalL2() {
#define ARGUMENTS                                                      \
  std::placeholders::_1, std::placeholders::_2, std::placeholders::_3, \
      std::placeholders::_4, std::placeholders::_5, std::placeholders::_6
    if (meta_->config->lambda_l1 > 0) {
      if (meta_->config->max_delta_step > 0) {
        find_best_threshold_fun_ =
            std::bind(&FeatureHistogram::FindBestThresholdCategoricalInner<
                          USE_RAND, USE_MC, true, true, USE_SMOOTHING>,
                      this, ARGUMENTS);
      } else {
        find_best_threshold_fun_ =
            std::bind(&FeatureHistogram::FindBestThresholdCategoricalInner<
                          USE_RAND, USE_MC, true, false, USE_SMOOTHING>,
                      this, ARGUMENTS);
      }
    } else {
      if (meta_->config->max_delta_step > 0) {
        find_best_threshold_fun_ =
            std::bind(&FeatureHistogram::FindBestThresholdCategoricalInner<
                          USE_RAND, USE_MC, false, true, USE_SMOOTHING>,
                      this, ARGUMENTS);
      } else {
        find_best_threshold_fun_ =
            std::bind(&FeatureHistogram::FindBestThresholdCategoricalInner<
                          USE_RAND, USE_MC, false, false, USE_SMOOTHING>,
                      this, ARGUMENTS);
      }
    }
#undef ARGUMENTS
  }

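  // Categorical splits are searched in two modes: one-vs-rest ("one-hot")
  // when the feature has at most max_cat_to_onehot bins, otherwise a subset
  // split over categories ordered by gradient/hessian ratio.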
  template <bool USE_RAND, bool USE_MC, bool USE_L1, bool USE_MAX_OUTPUT, bool USE_SMOOTHING>
  void FindBestThresholdCategoricalInner(double sum_gradient,
                                         double sum_hessian,
                                         data_size_t num_data,
                                         const FeatureConstraint* constraints,
                                         double parent_output,
                                         SplitInfo* output) {
    is_splittable_ = false;
    output->default_left = false;
    double best_gain = kMinScore;
    data_size_t best_left_count = 0;
    double best_sum_left_gradient = 0;
    double best_sum_left_hessian = 0;
    double gain_shift;
    if (USE_MC) {
      constraints->InitCumulativeConstraints(true);
    }
    if (USE_SMOOTHING) {
      gain_shift = GetLeafGainGivenOutput<USE_L1>(
          sum_gradient, sum_hessian, meta_->config->lambda_l1, meta_->config->lambda_l2, parent_output);
    } else {
      // Need special case for no smoothing to preserve existing behaviour. If no smoothing, the parent output is calculated
      // with the larger categorical l2, whereas min_split_gain uses the original l2.
      gain_shift = GetLeafGain<USE_L1, USE_MAX_OUTPUT, false>(sum_gradient, sum_hessian,
          meta_->config->lambda_l1, meta_->config->lambda_l2, meta_->config->max_delta_step, 0,
          num_data, 0);
    }

    double min_gain_shift = gain_shift + meta_->config->min_gain_to_split;
    const int8_t offset = meta_->offset;
    const int bin_start = 1 - offset;
    const int bin_end = meta_->num_bin - offset;
    int used_bin = -1;

    std::vector<int> sorted_idx;
    double l2 = meta_->config->lambda_l2;
    bool use_onehot = meta_->num_bin <= meta_->config->max_cat_to_onehot;
    int best_threshold = -1;
    int best_dir = 1;
    const double cnt_factor = num_data / sum_hessian;
    int rand_threshold = 0;
    if (use_onehot) {
      if (USE_RAND) {
        if (bin_end - bin_start > 0) {
          rand_threshold = meta_->rand.NextInt(bin_start, bin_end);
        }
      }
      for (int t = bin_start; t < bin_end; ++t) {
        const auto grad = GET_GRAD(data_, t);
        const auto hess = GET_HESS(data_, t);
        data_size_t cnt =
            static_cast<data_size_t>(Common::RoundInt(hess * cnt_factor));
        // if data not enough, or sum hessian too small
        if (cnt < meta_->config->min_data_in_leaf ||
            hess < meta_->config->min_sum_hessian_in_leaf) {
          continue;
        }
        data_size_t other_count = num_data - cnt;
        // if data not enough
        if (other_count < meta_->config->min_data_in_leaf) {
          continue;
        }

        double sum_other_hessian = sum_hessian - hess - kEpsilon;
        // if sum hessian too small
        if (sum_other_hessian < meta_->config->min_sum_hessian_in_leaf) {
          continue;
        }

        double sum_other_gradient = sum_gradient - grad;
        if (USE_RAND) {
          if (t != rand_threshold) {
            continue;
          }
        }
        // current split gain
        double current_gain = GetSplitGains<USE_MC, USE_L1, USE_MAX_OUTPUT, USE_SMOOTHING>(
            sum_other_gradient, sum_other_hessian, grad, hess + kEpsilon,
            meta_->config->lambda_l1, l2, meta_->config->max_delta_step,
            constraints, 0, meta_->config->path_smooth, other_count, cnt, parent_output);
        // gain with split is worse than without split
        if (current_gain <= min_gain_shift) {
          continue;
        }

        // mark as able to be split
        is_splittable_ = true;
        // better split point
        if (current_gain > best_gain) {
          best_threshold = t;
          best_sum_left_gradient = grad;
          best_sum_left_hessian = hess + kEpsilon;
          best_left_count = cnt;
          best_gain = current_gain;
        }
      }
    } else {
      for (int i = bin_start; i < bin_end; ++i) {
        if (Common::RoundInt(GET_HESS(data_, i) * cnt_factor) >=
            meta_->config->cat_smooth) {
          sorted_idx.push_back(i);
        }
      }
      used_bin = static_cast<int>(sorted_idx.size());

      l2 += meta_->config->cat_l2;

      auto ctr_fun = [this](double sum_grad, double sum_hess) {
        return (sum_grad) / (sum_hess + meta_->config->cat_smooth);
      };
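      // Order categories by their smoothed gradient/hessian ratio (a
      // CTR-style score); the best subset split can then be found with a
      // single scan from each end of the sorted order.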
      std::stable_sort(
          sorted_idx.begin(), sorted_idx.end(), [this, &ctr_fun](int i, int j) {
            return ctr_fun(GET_GRAD(data_, i), GET_HESS(data_, i)) <
                   ctr_fun(GET_GRAD(data_, j), GET_HESS(data_, j));
          });

      std::vector<int> find_direction(1, 1);
      std::vector<int> start_position(1, 0);
      find_direction.push_back(-1);
      start_position.push_back(used_bin - 1);
      const int max_num_cat =
          std::min(meta_->config->max_cat_threshold, (used_bin + 1) / 2);
      int max_threshold = std::max(std::min(max_num_cat, used_bin) - 1, 0);
      if (USE_RAND) {
        if (max_threshold > 0) {
          rand_threshold = meta_->rand.NextInt(0, max_threshold);
        }
      }

      is_splittable_ = false;
      for (size_t out_i = 0; out_i < find_direction.size(); ++out_i) {
        auto dir = find_direction[out_i];
        auto start_pos = start_position[out_i];
        data_size_t min_data_per_group = meta_->config->min_data_per_group;
        data_size_t cnt_cur_group = 0;
        double sum_left_gradient = 0.0f;
        double sum_left_hessian = kEpsilon;
        data_size_t left_count = 0;
        for (int i = 0; i < used_bin && i < max_num_cat; ++i) {
          auto t = sorted_idx[start_pos];
          start_pos += dir;
          const auto grad = GET_GRAD(data_, t);
          const auto hess = GET_HESS(data_, t);
          data_size_t cnt =
              static_cast<data_size_t>(Common::RoundInt(hess * cnt_factor));

          sum_left_gradient += grad;
          sum_left_hessian += hess;
          left_count += cnt;
          cnt_cur_group += cnt;

          if (left_count < meta_->config->min_data_in_leaf ||
              sum_left_hessian < meta_->config->min_sum_hessian_in_leaf) {
            continue;
          }
          data_size_t right_count = num_data - left_count;
          if (right_count < meta_->config->min_data_in_leaf ||
              right_count < min_data_per_group) {
            break;
          }

          double sum_right_hessian = sum_hessian - sum_left_hessian;
          if (sum_right_hessian < meta_->config->min_sum_hessian_in_leaf) {
            break;
          }

          if (cnt_cur_group < min_data_per_group) {
            continue;
          }

          cnt_cur_group = 0;

          double sum_right_gradient = sum_gradient - sum_left_gradient;
          if (USE_RAND) {
            if (i != rand_threshold) {
              continue;
            }
          }
          double current_gain = GetSplitGains<USE_MC, USE_L1, USE_MAX_OUTPUT, USE_SMOOTHING>(
              sum_left_gradient, sum_left_hessian, sum_right_gradient,
              sum_right_hessian, meta_->config->lambda_l1, l2,
              meta_->config->max_delta_step, constraints, 0, meta_->config->path_smooth,
              left_count, right_count, parent_output);
          if (current_gain <= min_gain_shift) {
            continue;
          }
          is_splittable_ = true;
          if (current_gain > best_gain) {
            best_left_count = left_count;
            best_sum_left_gradient = sum_left_gradient;
            best_sum_left_hessian = sum_left_hessian;
            best_threshold = i;
            best_gain = current_gain;
            best_dir = dir;
          }
        }
      }
    }

    if (is_splittable_) {
      output->left_output = CalculateSplittedLeafOutput<USE_MC, USE_L1, USE_MAX_OUTPUT, USE_SMOOTHING>(
          best_sum_left_gradient, best_sum_left_hessian,
          meta_->config->lambda_l1, l2, meta_->config->max_delta_step,
          constraints->LeftToBasicConstraint(), meta_->config->path_smooth, best_left_count, parent_output);
      output->left_count = best_left_count;
      output->left_sum_gradient = best_sum_left_gradient;
      output->left_sum_hessian = best_sum_left_hessian - kEpsilon;
      output->right_output = CalculateSplittedLeafOutput<USE_MC, USE_L1, USE_MAX_OUTPUT, USE_SMOOTHING>(
          sum_gradient - best_sum_left_gradient,
          sum_hessian - best_sum_left_hessian, meta_->config->lambda_l1, l2,
          meta_->config->max_delta_step, constraints->RightToBasicConstraint(), meta_->config->path_smooth,
          num_data - best_left_count, parent_output);
      output->right_count = num_data - best_left_count;
      output->right_sum_gradient = sum_gradient - best_sum_left_gradient;
      output->right_sum_hessian =
          sum_hessian - best_sum_left_hessian - kEpsilon;
      output->gain = best_gain - min_gain_shift;
      if (use_onehot) {
        output->num_cat_threshold = 1;
        output->cat_threshold =
            std::vector<uint32_t>(1, static_cast<uint32_t>(best_threshold + offset));
      } else {
        output->num_cat_threshold = best_threshold + 1;
        output->cat_threshold =
            std::vector<uint32_t>(output->num_cat_threshold);
        if (best_dir == 1) {
          for (int i = 0; i < output->num_cat_threshold; ++i) {
            auto t = sorted_idx[i] + offset;
            output->cat_threshold[i] = t;
          }
        } else {
          for (int i = 0; i < output->num_cat_threshold; ++i) {
            auto t = sorted_idx[used_bin - 1 - i] + offset;
            output->cat_threshold[i] = t;
          }
        }
      }
      output->monotone_type = 0;
    }
  }

  void GatherInfoForThreshold(double sum_gradient, double sum_hessian,
                              uint32_t threshold, data_size_t num_data,
                              double parent_output, SplitInfo* output) {
    if (meta_->bin_type == BinType::NumericalBin) {
      GatherInfoForThresholdNumerical(sum_gradient, sum_hessian, threshold,
                                      num_data, parent_output, output);
    } else {
      GatherInfoForThresholdCategorical(sum_gradient, sum_hessian, threshold,
                                        num_data, parent_output, output);
    }
  }

  void GatherInfoForThresholdNumerical(double sum_gradient, double sum_hessian,
                                       uint32_t threshold, data_size_t num_data,
                                       double parent_output, SplitInfo* output) {
    bool use_smoothing = meta_->config->path_smooth > kEpsilon;
    if (use_smoothing) {
      GatherInfoForThresholdNumericalInner<true>(sum_gradient, sum_hessian,
                                                 threshold, num_data,
                                                 parent_output, output);
    } else {
      GatherInfoForThresholdNumericalInner<false>(sum_gradient, sum_hessian,
                                                  threshold, num_data,
                                                  parent_output, output);
    }
  }

  template<bool USE_SMOOTHING>
  void GatherInfoForThresholdNumericalInner(double sum_gradient, double sum_hessian,
                                            uint32_t threshold, data_size_t num_data,
                                            double parent_output, SplitInfo* output) {
    double gain_shift = GetLeafGainGivenOutput<true>(
        sum_gradient, sum_hessian, meta_->config->lambda_l1,
        meta_->config->lambda_l2, parent_output);
    double min_gain_shift = gain_shift + meta_->config->min_gain_to_split;

    // accumulate right-side sums for bins above the threshold; the left side
    // is then derived from the totals
    const int8_t offset = meta_->offset;

    double sum_right_gradient = 0.0f;
    double sum_right_hessian = kEpsilon;
    data_size_t right_count = 0;

    // decide how missing values are routed for this feature
    bool use_na_as_missing = false;
    bool skip_default_bin = false;
    if (meta_->missing_type == MissingType::Zero) {
      skip_default_bin = true;
    } else if (meta_->missing_type == MissingType::NaN) {
      use_na_as_missing = true;
    }

    int t = meta_->num_bin - 1 - offset - use_na_as_missing;
    const int t_end = 1 - offset;
    const double cnt_factor = num_data / sum_hessian;
    // from right to left, and we don't need data in bin0
    for (; t >= t_end; --t) {
      if (static_cast<uint32_t>(t + offset) <= threshold) {
        break;
      }

      // need to skip default bin
      if (skip_default_bin &&
          (t + offset) == static_cast<int>(meta_->default_bin)) {
        continue;
      }
      const auto grad = GET_GRAD(data_, t);
      const auto hess = GET_HESS(data_, t);
      data_size_t cnt =
          static_cast<data_size_t>(Common::RoundInt(hess * cnt_factor));
      sum_right_gradient += grad;
      sum_right_hessian += hess;
      right_count += cnt;
    }
    double sum_left_gradient = sum_gradient - sum_right_gradient;
    double sum_left_hessian = sum_hessian - sum_right_hessian;
    data_size_t left_count = num_data - right_count;
    double current_gain =
        GetLeafGain<true, true, USE_SMOOTHING>(
            sum_left_gradient, sum_left_hessian, meta_->config->lambda_l1,
            meta_->config->lambda_l2, meta_->config->max_delta_step,
            meta_->config->path_smooth, left_count, parent_output) +
        GetLeafGain<true, true, USE_SMOOTHING>(
            sum_right_gradient, sum_right_hessian, meta_->config->lambda_l1,
            meta_->config->lambda_l2, meta_->config->max_delta_step,
            meta_->config->path_smooth, right_count, parent_output);

    // gain with split is worse than without split
    if (std::isnan(current_gain) || current_gain <= min_gain_shift) {
      output->gain = kMinScore;
      Log::Warning(
          "'Forced Split' will be ignored since the gain is getting worse.");
      return;
    }

    // update split information
    output->threshold = threshold;
    output->left_output = CalculateSplittedLeafOutput<true, true, USE_SMOOTHING>(
        sum_left_gradient, sum_left_hessian, meta_->config->lambda_l1,
        meta_->config->lambda_l2, meta_->config->max_delta_step,
        meta_->config->path_smooth, left_count, parent_output);
    output->left_count = left_count;
    output->left_sum_gradient = sum_left_gradient;
    output->left_sum_hessian = sum_left_hessian - kEpsilon;
    output->right_output = CalculateSplittedLeafOutput<true, true, USE_SMOOTHING>(
        sum_gradient - sum_left_gradient, sum_hessian - sum_left_hessian,
        meta_->config->lambda_l1, meta_->config->lambda_l2,
        meta_->config->max_delta_step, meta_->config->path_smooth,
        right_count, parent_output);
    output->right_count = num_data - left_count;
    output->right_sum_gradient = sum_gradient - sum_left_gradient;
    output->right_sum_hessian = sum_hessian - sum_left_hessian - kEpsilon;
    output->gain = current_gain - min_gain_shift;
    output->default_left = true;
  }

  void GatherInfoForThresholdCategorical(double sum_gradient,  double sum_hessian,
                                         uint32_t threshold, data_size_t num_data,
                                         double parent_output, SplitInfo* output) {
    bool use_smoothing = meta_->config->path_smooth > kEpsilon;
    if (use_smoothing) {
      GatherInfoForThresholdCategoricalInner<true>(sum_gradient, sum_hessian, threshold,
                                                   num_data, parent_output, output);
    } else {
      GatherInfoForThresholdCategoricalInner<false>(sum_gradient, sum_hessian, threshold,
                                                    num_data, parent_output, output);
    }
  }

  template<bool USE_SMOOTHING>
  void GatherInfoForThresholdCategoricalInner(double sum_gradient,
                                              double sum_hessian, uint32_t threshold,
                                              data_size_t num_data, double parent_output,
                                              SplitInfo* output) {
    // get SplitInfo for a given one-hot categorical split.
    output->default_left = false;
    double gain_shift = GetLeafGainGivenOutput<true>(
        sum_gradient, sum_hessian, meta_->config->lambda_l1, meta_->config->lambda_l2, parent_output);
    double min_gain_shift = gain_shift + meta_->config->min_gain_to_split;
    if (threshold >= static_cast<uint32_t>(meta_->num_bin) || threshold == 0) {
      output->gain = kMinScore;
      Log::Warning("Invalid categorical threshold split");
      return;
    }
    const double cnt_factor = num_data / sum_hessian;
    const auto grad = GET_GRAD(data_, threshold - meta_->offset);
    const auto hess = GET_HESS(data_, threshold - meta_->offset);
    data_size_t cnt =
        static_cast<data_size_t>(Common::RoundInt(hess * cnt_factor));

    double l2 = meta_->config->lambda_l2;
    data_size_t left_count = cnt;
    data_size_t right_count = num_data - left_count;
    double sum_left_hessian = hess + kEpsilon;
    double sum_right_hessian = sum_hessian - sum_left_hessian;
    double sum_left_gradient = grad;
    double sum_right_gradient = sum_gradient - sum_left_gradient;
    // current split gain
    double current_gain =
        GetLeafGain<true, true, USE_SMOOTHING>(sum_right_gradient, sum_right_hessian,
                                      meta_->config->lambda_l1, l2,
                                      meta_->config->max_delta_step,
                                      meta_->config->path_smooth, right_count,
                                      parent_output) +
        GetLeafGain<true, true, USE_SMOOTHING>(sum_left_gradient, sum_left_hessian,
                                      meta_->config->lambda_l1, l2,
                                      meta_->config->max_delta_step,
                                      meta_->config->path_smooth, left_count,
                                      parent_output);
    if (std::isnan(current_gain) || current_gain <= min_gain_shift) {
      output->gain = kMinScore;
      Log::Warning(
          "'Forced Split' will be ignored since the gain is getting worse.");
      return;
    }
    output->left_output = CalculateSplittedLeafOutput<true, true, USE_SMOOTHING>(
        sum_left_gradient, sum_left_hessian, meta_->config->lambda_l1, l2,
        meta_->config->max_delta_step, meta_->config->path_smooth, left_count,
        parent_output);
    output->left_count = left_count;
    output->left_sum_gradient = sum_left_gradient;
    output->left_sum_hessian = sum_left_hessian - kEpsilon;
    output->right_output = CalculateSplittedLeafOutput<true, true, USE_SMOOTHING>(
        sum_right_gradient, sum_right_hessian, meta_->config->lambda_l1, l2,
        meta_->config->max_delta_step, meta_->config->path_smooth, right_count,
        parent_output);
    output->right_count = right_count;
    output->right_sum_gradient = sum_gradient - sum_left_gradient;
    output->right_sum_hessian = sum_right_hessian - kEpsilon;
    output->gain = current_gain - min_gain_shift;
    output->num_cat_threshold = 1;
    output->cat_threshold = std::vector<uint32_t>(1, threshold);
  }

  /*!
   * \brief Binary size of this histogram
   */
  int SizeOfHistgram() const {
    return (meta_->num_bin - meta_->offset) * kHistEntrySize;
  }

  int SizeOfInt32Histgram() const {
    return (meta_->num_bin - meta_->offset) * kInt32HistEntrySize;
  }

  int SizeOfInt16Histgram() const {
    return (meta_->num_bin - meta_->offset) * kInt16HistEntrySize;
  }

  /*!
   * \brief Restore histogram from memory
   */
  void FromMemory(char* memory_data) {
    std::memcpy(data_, memory_data,
                (meta_->num_bin - meta_->offset) * kHistEntrySize);
  }

  void FromMemoryInt32(char* memory_data) {
    std::memcpy(data_, memory_data,
                (meta_->num_bin - meta_->offset) * kInt32HistEntrySize);
  }

  void FromMemoryInt16(char* memory_data) {
    std::memcpy(data_int16_, memory_data,
                (meta_->num_bin - meta_->offset) * kInt16HistEntrySize);
  }

  /*!
   * \brief True if this histogram can be split
   */
  bool is_splittable() { return is_splittable_; }

  /*!
   * \brief Set whether this histogram is splittable
   */
  void set_is_splittable(bool val) { is_splittable_ = val; }

  static double ThresholdL1(double s, double l1) {
    const double reg_s = std::max(0.0, std::fabs(s) - l1);
    return Common::Sign(s) * reg_s;
  }

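  // The two CalculateSplittedLeafOutput overloads below compute the optimal
  // leaf value. Without max_delta_step or smoothing this is
  //   w* = -G / (H + lambda_l2),
  // and with L1 regularization the gradient sum is first soft-thresholded,
  // e.g. ThresholdL1(-3.5, 1.0) == -2.5 and ThresholdL1(0.4, 1.0) == 0.0.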
  template <bool USE_L1, bool USE_MAX_OUTPUT, bool USE_SMOOTHING>
  static double CalculateSplittedLeafOutput(double sum_gradients,
                                            double sum_hessians, double l1,
                                            double l2, double max_delta_step,
                                            double smoothing, data_size_t num_data,
                                            double parent_output) {
    double ret;
    if (USE_L1) {
      ret = -ThresholdL1(sum_gradients, l1) / (sum_hessians + l2);
    } else {
      ret = -sum_gradients / (sum_hessians + l2);
    }
    if (USE_MAX_OUTPUT) {
      if (max_delta_step > 0 && std::fabs(ret) > max_delta_step) {
        ret = Common::Sign(ret) * max_delta_step;
      }
    }
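    // Path smoothing blends the raw leaf value toward the parent's output:
    //   w' = w * n / (n + smoothing) + parent_output * smoothing / (n + smoothing),
    // algebraically equal to the num_data / smoothing form used below.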
    if (USE_SMOOTHING) {
      ret = ret * (num_data / smoothing) / (num_data / smoothing + 1) \
          + parent_output / (num_data / smoothing + 1);
    }
    return ret;
  }

  template <bool USE_MC, bool USE_L1, bool USE_MAX_OUTPUT, bool USE_SMOOTHING>
  static double CalculateSplittedLeafOutput(
      double sum_gradients, double sum_hessians, double l1, double l2,
      double max_delta_step, const BasicConstraint& constraints,
      double smoothing, data_size_t num_data, double parent_output) {
    double ret = CalculateSplittedLeafOutput<USE_L1, USE_MAX_OUTPUT, USE_SMOOTHING>(
        sum_gradients, sum_hessians, l1, l2, max_delta_step, smoothing, num_data, parent_output);
    if (USE_MC) {
      if (ret < constraints.min) {
        ret = constraints.min;
      } else if (ret > constraints.max) {
        ret = constraints.max;
      }
    }
    return ret;
  }

 private:
  template <bool USE_MC, bool USE_L1, bool USE_MAX_OUTPUT, bool USE_SMOOTHING>
  static double GetSplitGains(double sum_left_gradients,
                              double sum_left_hessians,
                              double sum_right_gradients,
                              double sum_right_hessians, double l1, double l2,
                              double max_delta_step,
                              const FeatureConstraint* constraints,
                              int8_t monotone_constraint,
                              double smoothing,
                              data_size_t left_count,
                              data_size_t right_count,
                              double parent_output) {
    if (!USE_MC) {
      return GetLeafGain<USE_L1, USE_MAX_OUTPUT, USE_SMOOTHING>(sum_left_gradients,
                                                                sum_left_hessians, l1, l2,
                                                                max_delta_step, smoothing,
                                                                left_count, parent_output) +
             GetLeafGain<USE_L1, USE_MAX_OUTPUT, USE_SMOOTHING>(sum_right_gradients,
                                                                sum_right_hessians, l1, l2,
                                                                max_delta_step, smoothing,
                                                                right_count, parent_output);
    } else {
      double left_output =
          CalculateSplittedLeafOutput<USE_MC, USE_L1, USE_MAX_OUTPUT, USE_SMOOTHING>(
              sum_left_gradients, sum_left_hessians, l1, l2, max_delta_step,
              constraints->LeftToBasicConstraint(), smoothing, left_count, parent_output);
      double right_output =
          CalculateSplittedLeafOutput<USE_MC, USE_L1, USE_MAX_OUTPUT, USE_SMOOTHING>(
              sum_right_gradients, sum_right_hessians, l1, l2, max_delta_step,
              constraints->RightToBasicConstraint(), smoothing, right_count, parent_output);
      if (((monotone_constraint > 0) && (left_output > right_output)) ||
          ((monotone_constraint < 0) && (left_output < right_output))) {
        return 0;
      }
      return GetLeafGainGivenOutput<USE_L1>(
                 sum_left_gradients, sum_left_hessians, l1, l2, left_output) +
             GetLeafGainGivenOutput<USE_L1>(
                 sum_right_gradients, sum_right_hessians, l1, l2, right_output);
    }
  }

  template <bool USE_L1, bool USE_MAX_OUTPUT, bool USE_SMOOTHING>
  static double GetLeafGain(double sum_gradients, double sum_hessians,
                            double l1, double l2, double max_delta_step,
                            double smoothing, data_size_t num_data, double parent_output) {
    if (!USE_MAX_OUTPUT && !USE_SMOOTHING) {
      if (USE_L1) {
        const double sg_l1 = ThresholdL1(sum_gradients, l1);
        return (sg_l1 * sg_l1) / (sum_hessians + l2);
      } else {
        return (sum_gradients * sum_gradients) / (sum_hessians + l2);
      }
    } else {
      double output = CalculateSplittedLeafOutput<USE_L1, USE_MAX_OUTPUT, USE_SMOOTHING>(
          sum_gradients, sum_hessians, l1, l2, max_delta_step, smoothing, num_data, parent_output);
      return GetLeafGainGivenOutput<USE_L1>(sum_gradients, sum_hessians, l1, l2, output);
    }
  }

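  /*!
   * \brief Gain of a leaf with a fixed output value.
   * Returns -(2 * g * output + (sum_hessians + l2) * output^2), where g is
   * the (optionally L1-thresholded) sum of gradients, i.e. minus twice the
   * regularized objective evaluated at that output.
   */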
  template <bool USE_L1>
  static double GetLeafGainGivenOutput(double sum_gradients,
                                       double sum_hessians, double l1,
                                       double l2, double output) {
    if (USE_L1) {
      const double sg_l1 = ThresholdL1(sum_gradients, l1);
      return -(2.0 * sg_l1 * output + (sum_hessians + l2) * output * output);
    } else {
      return -(2.0 * sum_gradients * output +
               (sum_hessians + l2) * output * output);
    }
  }

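  /*!
   * \brief Scan the histogram bins in order and record the best split.
   * REVERSE scans from the last bin towards the first (so missing values
   * default to the left child), USE_RAND restricts the search to one
   * pre-drawn threshold (extremely randomized trees), USE_MC enables
   * monotone constraints, SKIP_DEFAULT_BIN skips the feature's default bin,
   * and NA_AS_MISSING controls how the missing-value bin is handled.
   */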
  template <bool USE_RAND, bool USE_MC, bool USE_L1, bool USE_MAX_OUTPUT, bool USE_SMOOTHING,
            bool REVERSE, bool SKIP_DEFAULT_BIN, bool NA_AS_MISSING>
  void FindBestThresholdSequentially(double sum_gradient, double sum_hessian,
                                     data_size_t num_data,
                                     const FeatureConstraint* constraints,
                                     double min_gain_shift, SplitInfo* output,
                                     int rand_threshold, double parent_output) {
    const int8_t offset = meta_->offset;
    double best_sum_left_gradient = NAN;
    double best_sum_left_hessian = NAN;
    double best_gain = kMinScore;
    data_size_t best_left_count = 0;
    uint32_t best_threshold = static_cast<uint32_t>(meta_->num_bin);
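    // Bin counts are recovered from the hessian sums: each data point
    // contributes its hessian, so count ~= hess * (num_data / sum_hessian).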
    const double cnt_factor = num_data / sum_hessian;

    BasicConstraint best_right_constraints;
    BasicConstraint best_left_constraints;
    bool constraint_update_necessary =
        USE_MC && constraints->ConstraintDifferentDependingOnThreshold();

    if (USE_MC) {
      constraints->InitCumulativeConstraints(REVERSE);
    }

    if (REVERSE) {
      double sum_right_gradient = 0.0;
      double sum_right_hessian = kEpsilon;
      data_size_t right_count = 0;

      int t = meta_->num_bin - 1 - offset - NA_AS_MISSING;
      const int t_end = 1 - offset;

      // from right to left, and we don't need the data in bin 0
      for (; t >= t_end; --t) {
        // need to skip default bin
        if (SKIP_DEFAULT_BIN) {
          if ((t + offset) == static_cast<int>(meta_->default_bin)) {
            continue;
          }
        }
        const auto grad = GET_GRAD(data_, t);
        const auto hess = GET_HESS(data_, t);
        data_size_t cnt =
            static_cast<data_size_t>(Common::RoundInt(hess * cnt_factor));
        sum_right_gradient += grad;
        sum_right_hessian += hess;
        right_count += cnt;
        // if data not enough, or sum hessian too small
        if (right_count < meta_->config->min_data_in_leaf ||
            sum_right_hessian < meta_->config->min_sum_hessian_in_leaf) {
          continue;
        }
        data_size_t left_count = num_data - right_count;
        // if data not enough
        if (left_count < meta_->config->min_data_in_leaf) {
          break;
        }

        double sum_left_hessian = sum_hessian - sum_right_hessian;
        // if sum hessian too small
        if (sum_left_hessian < meta_->config->min_sum_hessian_in_leaf) {
          break;
        }

        double sum_left_gradient = sum_gradient - sum_right_gradient;
        if (USE_RAND) {
          if (t - 1 + offset != rand_threshold) {
            continue;
          }
        }

        if (USE_MC && constraint_update_necessary) {
          constraints->Update(t + offset);
        }

        // current split gain
        double current_gain = GetSplitGains<USE_MC, USE_L1, USE_MAX_OUTPUT, USE_SMOOTHING>(
            sum_left_gradient, sum_left_hessian, sum_right_gradient,
            sum_right_hessian, meta_->config->lambda_l1,
            meta_->config->lambda_l2, meta_->config->max_delta_step,
            constraints, meta_->monotone_type, meta_->config->path_smooth,
            left_count, right_count, parent_output);
        // gain with split is worse than without split
        if (current_gain <= min_gain_shift) {
          continue;
        }

        // mark as able to be split
        is_splittable_ = true;
        // better split point
        if (current_gain > best_gain) {
          if (USE_MC) {
            best_right_constraints = constraints->RightToBasicConstraint();
            best_left_constraints = constraints->LeftToBasicConstraint();
            if (best_right_constraints.min > best_right_constraints.max ||
                best_left_constraints.min > best_left_constraints.max) {
              continue;
            }
          }
          best_left_count = left_count;
          best_sum_left_gradient = sum_left_gradient;
          best_sum_left_hessian = sum_left_hessian;
          // left is <= threshold, right is > threshold, so the split threshold is t - 1
          best_threshold = static_cast<uint32_t>(t - 1 + offset);
          best_gain = current_gain;
        }
      }
    } else {
      double sum_left_gradient = 0.0;
      double sum_left_hessian = kEpsilon;
      data_size_t left_count = 0;

      int t = 0;
      const int t_end = meta_->num_bin - 2 - offset;

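      // With NAs treated as missing and offset == 1, scanning starts from a
      // virtual position t = -1 where only the most frequent bin (not stored
      // in the histogram, and absorbing the missing values) is on the left;
      // its statistics are recovered by subtracting every stored bin from
      // the totals.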
      if (NA_AS_MISSING) {
        if (offset == 1) {
          sum_left_gradient = sum_gradient;
          sum_left_hessian = sum_hessian - kEpsilon;
          left_count = num_data;
          for (int i = 0; i < meta_->num_bin - offset; ++i) {
            const auto grad = GET_GRAD(data_, i);
            const auto hess = GET_HESS(data_, i);
            data_size_t cnt =
                static_cast<data_size_t>(Common::RoundInt(hess * cnt_factor));
            sum_left_gradient -= grad;
            sum_left_hessian -= hess;
            left_count -= cnt;
          }
          t = -1;
        }
      }

      for (; t <= t_end; ++t) {
        if (SKIP_DEFAULT_BIN) {
          if ((t + offset) == static_cast<int>(meta_->default_bin)) {
            continue;
          }
        }
        if (t >= 0) {
          sum_left_gradient += GET_GRAD(data_, t);
          sum_left_hessian += GET_HESS(data_, t);
          left_count += static_cast<data_size_t>(
              Common::RoundInt(GET_HESS(data_, t) * cnt_factor));
        }
        // if data not enough, or sum hessian too small
        if (left_count < meta_->config->min_data_in_leaf ||
            sum_left_hessian < meta_->config->min_sum_hessian_in_leaf) {
          continue;
        }
        data_size_t right_count = num_data - left_count;
        // if data not enough
        if (right_count < meta_->config->min_data_in_leaf) {
          break;
        }

        double sum_right_hessian = sum_hessian - sum_left_hessian;
        // if sum hessian too small
        if (sum_right_hessian < meta_->config->min_sum_hessian_in_leaf) {
          break;
        }

        double sum_right_gradient = sum_gradient - sum_left_gradient;
        if (USE_RAND) {
          if (t + offset != rand_threshold) {
            continue;
          }
        }
        // current split gain
        double current_gain = GetSplitGains<USE_MC, USE_L1, USE_MAX_OUTPUT, USE_SMOOTHING>(
            sum_left_gradient, sum_left_hessian, sum_right_gradient,
            sum_right_hessian, meta_->config->lambda_l1,
            meta_->config->lambda_l2, meta_->config->max_delta_step,
            constraints, meta_->monotone_type, meta_->config->path_smooth, left_count,
            right_count, parent_output);
        // gain with split is worse than without split
        if (current_gain <= min_gain_shift) {
          continue;
        }

        // mark as able to be split
        is_splittable_ = true;
        // better split point
        if (current_gain > best_gain) {
          if (USE_MC) {
            best_right_constraints = constraints->RightToBasicConstraint();
            best_left_constraints = constraints->LeftToBasicConstraint();
            if (best_right_constraints.min > best_right_constraints.max ||
                best_left_constraints.min > best_left_constraints.max) {
              continue;
            }
          }
          best_left_count = left_count;
          best_sum_left_gradient = sum_left_gradient;
          best_sum_left_hessian = sum_left_hessian;
          best_threshold = static_cast<uint32_t>(t + offset);
          best_gain = current_gain;
        }
      }
    }

    if (is_splittable_ && best_gain > output->gain + min_gain_shift) {
      // update split information
      output->threshold = best_threshold;
      output->left_output =
          CalculateSplittedLeafOutput<USE_MC, USE_L1, USE_MAX_OUTPUT, USE_SMOOTHING>(
              best_sum_left_gradient, best_sum_left_hessian,
              meta_->config->lambda_l1, meta_->config->lambda_l2,
              meta_->config->max_delta_step, best_left_constraints, meta_->config->path_smooth,
              best_left_count, parent_output);
      output->left_count = best_left_count;
      output->left_sum_gradient = best_sum_left_gradient;
      output->left_sum_hessian = best_sum_left_hessian - kEpsilon;
      output->right_output =
          CalculateSplittedLeafOutput<USE_MC, USE_L1, USE_MAX_OUTPUT, USE_SMOOTHING>(
              sum_gradient - best_sum_left_gradient,
              sum_hessian - best_sum_left_hessian, meta_->config->lambda_l1,
              meta_->config->lambda_l2, meta_->config->max_delta_step,
              best_right_constraints, meta_->config->path_smooth, num_data - best_left_count,
              parent_output);
      output->right_count = num_data - best_left_count;
      output->right_sum_gradient = sum_gradient - best_sum_left_gradient;
      output->right_sum_hessian =
          sum_hessian - best_sum_left_hessian - kEpsilon;
      output->gain = best_gain - min_gain_shift;
      output->default_left = REVERSE;
    }
  }

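  /*!
   * \brief Quantized-gradient counterpart of FindBestThresholdSequentially.
   * Each histogram entry packs an integer gradient sum into its high bits
   * and an integer hessian sum into its low bits, so a single integer
   * addition accumulates both; grad_scale and hess_scale convert the
   * integer sums back to floating point before gains are computed.
   * HIST_BITS_BIN and HIST_BITS_ACC are the per-component widths (16 or 32
   * bits) of the bin and accumulator entries.
   */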
  template <bool USE_RAND, bool USE_MC, bool USE_L1, bool USE_MAX_OUTPUT, bool USE_SMOOTHING,
          bool REVERSE, bool SKIP_DEFAULT_BIN, bool NA_AS_MISSING, typename PACKED_HIST_BIN_T, typename PACKED_HIST_ACC_T,
          typename HIST_BIN_T, typename HIST_ACC_T, int HIST_BITS_BIN, int HIST_BITS_ACC>
  void FindBestThresholdSequentiallyInt(int64_t int_sum_gradient_and_hessian,
                                        const double grad_scale, const double hess_scale,
                                        data_size_t num_data,
                                        const FeatureConstraint* constraints,
                                        double min_gain_shift, SplitInfo* output,
                                        int rand_threshold, double parent_output) {
    const int8_t offset = meta_->offset;
    PACKED_HIST_ACC_T best_sum_left_gradient_and_hessian = 0;
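    // The packed total arrives as an int64 (gradient in the high 32 bits,
    // hessian in the low 32 bits); when accumulating in 32-bit words, repack
    // it into 16-bit gradient / 16-bit hessian halves first.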
    PACKED_HIST_ACC_T local_int_sum_gradient_and_hessian =
      HIST_BITS_ACC == 16 ?
      ((static_cast<int32_t>(int_sum_gradient_and_hessian >> 32) << 16) | static_cast<int32_t>(int_sum_gradient_and_hessian & 0x0000ffff)) :
      int_sum_gradient_and_hessian;
    double best_gain = kMinScore;
    uint32_t best_threshold = static_cast<uint32_t>(meta_->num_bin);
    const double cnt_factor = static_cast<double>(num_data) /
      static_cast<double>(static_cast<uint32_t>(int_sum_gradient_and_hessian & 0x00000000ffffffff));

    BasicConstraint best_right_constraints;
    BasicConstraint best_left_constraints;
    bool constraint_update_necessary =
        USE_MC && constraints->ConstraintDifferentDependingOnThreshold();

    if (USE_MC) {
      constraints->InitCumulativeConstraints(REVERSE);
    }

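    // 16-bit histograms are stored in the separate int16 buffer; wider ones
    // reuse data_ reinterpreted as packed integer entries.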
    const PACKED_HIST_BIN_T* data_ptr = nullptr;
    if (HIST_BITS_BIN == 16) {
      data_ptr = reinterpret_cast<const PACKED_HIST_BIN_T*>(data_int16_);
    } else {
      data_ptr = reinterpret_cast<const PACKED_HIST_BIN_T*>(data_);
    }
    if (REVERSE) {
      PACKED_HIST_ACC_T sum_right_gradient_and_hessian = 0;

      int t = meta_->num_bin - 1 - offset - NA_AS_MISSING;
      const int t_end = 1 - offset;

      // from right to left, and we don't need the data in bin 0
      for (; t >= t_end; --t) {
        // need to skip default bin
        if (SKIP_DEFAULT_BIN) {
          if ((t + offset) == static_cast<int>(meta_->default_bin)) {
            continue;
          }
        }
        const PACKED_HIST_BIN_T grad_and_hess = data_ptr[t];
        if (HIST_BITS_ACC != HIST_BITS_BIN) {
          const PACKED_HIST_ACC_T grad_and_hess_acc = HIST_BITS_BIN == 16 ?
            ((static_cast<PACKED_HIST_ACC_T>(static_cast<HIST_BIN_T>(grad_and_hess >> HIST_BITS_BIN)) << HIST_BITS_ACC) |
            (static_cast<PACKED_HIST_ACC_T>(grad_and_hess & 0x0000ffff))) :
            ((static_cast<PACKED_HIST_ACC_T>(static_cast<HIST_BIN_T>(grad_and_hess >> HIST_BITS_BIN)) << HIST_BITS_ACC) |
            (static_cast<PACKED_HIST_ACC_T>(grad_and_hess & 0x00000000ffffffff)));
          sum_right_gradient_and_hessian += grad_and_hess_acc;
        } else {
          sum_right_gradient_and_hessian += grad_and_hess;
        }
        const uint32_t int_sum_right_hessian = HIST_BITS_ACC == 16 ?
          static_cast<uint32_t>(sum_right_gradient_and_hessian & 0x0000ffff) :
          static_cast<uint32_t>(sum_right_gradient_and_hessian & 0x00000000ffffffff);
        data_size_t right_count = Common::RoundInt(int_sum_right_hessian * cnt_factor);
        double sum_right_hessian = int_sum_right_hessian * hess_scale;
        // if data not enough, or sum hessian too small
        if (right_count < meta_->config->min_data_in_leaf ||
            sum_right_hessian < meta_->config->min_sum_hessian_in_leaf) {
          continue;
        }
        data_size_t left_count = num_data - right_count;
        // if data not enough
        if (left_count < meta_->config->min_data_in_leaf) {
          break;
        }

        const PACKED_HIST_ACC_T sum_left_gradient_and_hessian = local_int_sum_gradient_and_hessian - sum_right_gradient_and_hessian;
        const uint32_t int_sum_left_hessian = HIST_BITS_ACC == 16 ?
          static_cast<uint32_t>(sum_left_gradient_and_hessian & 0x0000ffff) :
          static_cast<uint32_t>(sum_left_gradient_and_hessian & 0x00000000ffffffff);
        double sum_left_hessian = int_sum_left_hessian * hess_scale;
        // if sum hessian too small
        if (sum_left_hessian < meta_->config->min_sum_hessian_in_leaf) {
          break;
        }

        double sum_right_gradient = HIST_BITS_ACC == 16 ?
          static_cast<double>(static_cast<int16_t>(sum_right_gradient_and_hessian >> 16)) * grad_scale :
          static_cast<double>(static_cast<int32_t>(sum_right_gradient_and_hessian >> 32)) * grad_scale;
        double sum_left_gradient = HIST_BITS_ACC == 16 ?
          static_cast<double>(static_cast<int16_t>(sum_left_gradient_and_hessian >> 16)) * grad_scale :
          static_cast<double>(static_cast<int32_t>(sum_left_gradient_and_hessian >> 32)) * grad_scale;
        if (USE_RAND) {
          if (t - 1 + offset != rand_threshold) {
            continue;
          }
        }

        if (USE_MC && constraint_update_necessary) {
          constraints->Update(t + offset);
        }

        // current split gain
        double current_gain = GetSplitGains<USE_MC, USE_L1, USE_MAX_OUTPUT, USE_SMOOTHING>(
            sum_left_gradient, sum_left_hessian + kEpsilon, sum_right_gradient,
            sum_right_hessian + kEpsilon, meta_->config->lambda_l1,
            meta_->config->lambda_l2, meta_->config->max_delta_step,
            constraints, meta_->monotone_type, meta_->config->path_smooth,
            left_count, right_count, parent_output);
        // gain with split is worse than without split
        if (current_gain <= min_gain_shift) {
          continue;
        }

        // mark as able to be split
        is_splittable_ = true;
        // better split point
        if (current_gain > best_gain) {
          if (USE_MC) {
            best_right_constraints = constraints->RightToBasicConstraint();
            best_left_constraints = constraints->LeftToBasicConstraint();
            if (best_right_constraints.min > best_right_constraints.max ||
                best_left_constraints.min > best_left_constraints.max) {
              continue;
            }
          }
          best_sum_left_gradient_and_hessian = sum_left_gradient_and_hessian;
          // left is <= threshold, right is > threshold, so the split threshold is t - 1
          best_threshold = static_cast<uint32_t>(t - 1 + offset);
          best_gain = current_gain;
        }
      }
    } else {
      PACKED_HIST_ACC_T sum_left_gradient_and_hessian = 0;

      int t = 0;
      const int t_end = meta_->num_bin - 2 - offset;

      if (NA_AS_MISSING) {
        if (offset == 1) {
          sum_left_gradient_and_hessian = local_int_sum_gradient_and_hessian;
          for (int i = 0; i < meta_->num_bin - offset; ++i) {
            const PACKED_HIST_BIN_T grad_and_hess = data_ptr[i];
            if (HIST_BITS_ACC != HIST_BITS_BIN) {
              const PACKED_HIST_ACC_T grad_and_hess_acc = HIST_BITS_BIN == 16 ?
                ((static_cast<PACKED_HIST_ACC_T>(static_cast<HIST_BIN_T>(grad_and_hess >> HIST_BITS_BIN)) << HIST_BITS_ACC) |
                (static_cast<PACKED_HIST_ACC_T>(grad_and_hess & 0x0000ffff))) :
                ((static_cast<PACKED_HIST_ACC_T>(static_cast<HIST_BIN_T>(grad_and_hess >> HIST_BITS_BIN)) << HIST_BITS_ACC) |
                (static_cast<PACKED_HIST_ACC_T>(grad_and_hess & 0x00000000ffffffff)));
              sum_left_gradient_and_hessian -= grad_and_hess_acc;
            } else {
              sum_left_gradient_and_hessian -= grad_and_hess;
            }
          }
          t = -1;
        }
      }

      for (; t <= t_end; ++t) {
        if (SKIP_DEFAULT_BIN) {
          if ((t + offset) == static_cast<int>(meta_->default_bin)) {
            continue;
          }
        }
        if (t >= 0) {
          const PACKED_HIST_BIN_T grad_and_hess = data_ptr[t];
          if (HIST_BITS_ACC != HIST_BITS_BIN) {
            const PACKED_HIST_ACC_T grad_and_hess_acc = HIST_BITS_BIN == 16 ?
              ((static_cast<PACKED_HIST_ACC_T>(static_cast<HIST_BIN_T>(grad_and_hess >> HIST_BITS_BIN)) << HIST_BITS_ACC) |
              (static_cast<PACKED_HIST_ACC_T>(grad_and_hess & 0x0000ffff))) :
              ((static_cast<PACKED_HIST_ACC_T>(static_cast<HIST_BIN_T>(grad_and_hess >> HIST_BITS_BIN)) << HIST_BITS_ACC) |
              (static_cast<PACKED_HIST_ACC_T>(grad_and_hess & 0x00000000ffffffff)));
            sum_left_gradient_and_hessian += grad_and_hess_acc;
          } else {
            sum_left_gradient_and_hessian += grad_and_hess;
          }
        }
        // if data not enough, or sum hessian too small
        const uint32_t int_sum_left_hessian = HIST_BITS_ACC == 16 ?
          static_cast<uint32_t>(sum_left_gradient_and_hessian & 0x0000ffff) :
          static_cast<uint32_t>(sum_left_gradient_and_hessian & 0x00000000ffffffff);
        const data_size_t left_count = Common::RoundInt(static_cast<double>(int_sum_left_hessian) * cnt_factor);
        const double sum_left_hessian = static_cast<double>(int_sum_left_hessian) * hess_scale;
        if (left_count < meta_->config->min_data_in_leaf ||
            sum_left_hessian < meta_->config->min_sum_hessian_in_leaf) {
          continue;
        }
        data_size_t right_count = num_data - left_count;
        // if data not enough
        if (right_count < meta_->config->min_data_in_leaf) {
          break;
        }

        const PACKED_HIST_ACC_T sum_right_gradient_and_hessian = local_int_sum_gradient_and_hessian - sum_left_gradient_and_hessian;
        const uint32_t int_sum_right_hessian = HIST_BITS_ACC == 16 ?
          static_cast<uint32_t>(sum_right_gradient_and_hessian & 0x0000ffff) :
          static_cast<uint32_t>(sum_right_gradient_and_hessian & 0x00000000ffffffff);
        const double sum_right_hessian = static_cast<double>(int_sum_right_hessian) * hess_scale;
        // if sum hessian too small
        if (sum_right_hessian < meta_->config->min_sum_hessian_in_leaf) {
          break;
        }

        double sum_right_gradient = HIST_BITS_ACC == 16 ?
          static_cast<double>(static_cast<int16_t>(sum_right_gradient_and_hessian >> 16)) * grad_scale :
          static_cast<double>(static_cast<int32_t>(sum_right_gradient_and_hessian >> 32)) * grad_scale;
        double sum_left_gradient = HIST_BITS_ACC == 16 ?
          static_cast<double>(static_cast<int16_t>(sum_left_gradient_and_hessian >> 16)) * grad_scale :
          static_cast<double>(static_cast<int32_t>(sum_left_gradient_and_hessian >> 32)) * grad_scale;
        if (USE_RAND) {
          if (t + offset != rand_threshold) {
            continue;
          }
        }
        // current split gain
        double current_gain = GetSplitGains<USE_MC, USE_L1, USE_MAX_OUTPUT, USE_SMOOTHING>(
            sum_left_gradient, sum_left_hessian + kEpsilon, sum_right_gradient,
            sum_right_hessian + kEpsilon, meta_->config->lambda_l1,
            meta_->config->lambda_l2, meta_->config->max_delta_step,
            constraints, meta_->monotone_type, meta_->config->path_smooth, left_count,
            right_count, parent_output);
        // gain with split is worse than without split
        if (current_gain <= min_gain_shift) {
          continue;
        }

        // mark as able to be split
        is_splittable_ = true;
        // better split point
        if (current_gain > best_gain) {
          if (USE_MC) {
            best_right_constraints = constraints->RightToBasicConstraint();
            best_left_constraints = constraints->LeftToBasicConstraint();
            if (best_right_constraints.min > best_right_constraints.max ||
                best_left_constraints.min > best_left_constraints.max) {
              continue;
            }
          }
          best_sum_left_gradient_and_hessian = sum_left_gradient_and_hessian;
          best_threshold = static_cast<uint32_t>(t + offset);
          best_gain = current_gain;
        }
      }
    }

    if (is_splittable_ && best_gain > output->gain + min_gain_shift) {
      const int32_t int_best_sum_left_gradient = HIST_BITS_ACC == 16 ?
        static_cast<int32_t>(static_cast<int16_t>(best_sum_left_gradient_and_hessian >> 16)) :
        static_cast<int32_t>(best_sum_left_gradient_and_hessian >> 32);
      const uint32_t int_best_sum_left_hessian = HIST_BITS_ACC == 16 ?
        static_cast<uint32_t>(best_sum_left_gradient_and_hessian & 0x0000ffff) :
        static_cast<uint32_t>(best_sum_left_gradient_and_hessian & 0x00000000ffffffff);
      const double best_sum_left_gradient = static_cast<double>(int_best_sum_left_gradient) * grad_scale;
      const double best_sum_left_hessian = static_cast<double>(int_best_sum_left_hessian) * hess_scale;
      const int64_t best_sum_left_gradient_and_hessian_int64 = HIST_BITS_ACC == 16 ?
          ((static_cast<int64_t>(static_cast<int16_t>(best_sum_left_gradient_and_hessian >> 16)) << 32) |
          static_cast<int64_t>(best_sum_left_gradient_and_hessian & 0x0000ffff)) :
          best_sum_left_gradient_and_hessian;
      const int64_t best_sum_right_gradient_and_hessian = int_sum_gradient_and_hessian - best_sum_left_gradient_and_hessian_int64;
      const int32_t int_best_sum_right_gradient = static_cast<int32_t>(best_sum_right_gradient_and_hessian >> 32);
      const uint32_t int_best_sum_right_hessian = static_cast<uint32_t>(best_sum_right_gradient_and_hessian & 0x00000000ffffffff);
      const double best_sum_right_gradient = static_cast<double>(int_best_sum_right_gradient) * grad_scale;
      const double best_sum_right_hessian = static_cast<double>(int_best_sum_right_hessian) * hess_scale;
      const data_size_t best_left_count = Common::RoundInt(static_cast<double>(int_best_sum_left_hessian) * cnt_factor);
      const data_size_t best_right_count = Common::RoundInt(static_cast<double>(int_best_sum_right_hessian) * cnt_factor);
      // update split information
      output->threshold = best_threshold;
      output->left_output =
          CalculateSplittedLeafOutput<USE_MC, USE_L1, USE_MAX_OUTPUT, USE_SMOOTHING>(
              best_sum_left_gradient, best_sum_left_hessian,
              meta_->config->lambda_l1, meta_->config->lambda_l2,
              meta_->config->max_delta_step, best_left_constraints, meta_->config->path_smooth,
              best_left_count, parent_output);
      output->left_count = best_left_count;
      output->left_sum_gradient = best_sum_left_gradient;
      output->left_sum_hessian = best_sum_left_hessian;
      output->left_sum_gradient_and_hessian = best_sum_left_gradient_and_hessian_int64;
      output->right_output =
          CalculateSplittedLeafOutput<USE_MC, USE_L1, USE_MAX_OUTPUT, USE_SMOOTHING>(
              best_sum_right_gradient,
              best_sum_right_hessian, meta_->config->lambda_l1,
              meta_->config->lambda_l2, meta_->config->max_delta_step,
              best_right_constraints, meta_->config->path_smooth, best_right_count,
              parent_output);
      output->right_count = best_right_count;
      output->right_sum_gradient = best_sum_right_gradient;
      output->right_sum_hessian = best_sum_right_hessian;
      output->right_sum_gradient_and_hessian = best_sum_right_gradient_and_hessian;
      output->gain = best_gain - min_gain_shift;
      output->default_left = REVERSE;
    }
  }

  const FeatureMetainfo* meta_;
  /*! \brief histogram data: per-bin sums of gradients and hessians */
  hist_t* data_;
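  /*! \brief histogram data in packed 16-bit form, used with quantized gradients */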
  int16_t* data_int16_;
  bool is_splittable_ = true;

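  /*! \brief split-finding function bound by ResetFunc() for the feature's
   *  bin type and config */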
  std::function<void(double, double, data_size_t, const FeatureConstraint*,
                     double, SplitInfo*)>
      find_best_threshold_fun_;

  std::function<void(int64_t, double, double, const uint8_t, const uint8_t, data_size_t, const FeatureConstraint*,
                     double, SplitInfo*)>
      int_find_best_threshold_fun_;
};

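/*!
 * \brief HistogramPool caches FeatureHistogram arrays for tree leaves,
 * evicting the least-recently-used entry when the cache is smaller than the
 * number of leaves.
 */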
class HistogramPool {
 public:
  /*!
   * \brief Constructor
   */
  HistogramPool() {
    cache_size_ = 0;
    total_size_ = 0;
  }

  /*!
   * \brief Destructor
   */
  ~HistogramPool() {}

  /*!
   * \brief Reset pool size
   * \param cache_size Max cache size
   * \param total_size Total size that will be used
   */
  void Reset(int cache_size, int total_size) {
    cache_size_ = cache_size;
    // need at least 2 buckets, one for the smaller leaf and one for the larger leaf
    CHECK_GE(cache_size_, 2);
    total_size_ = total_size;
    if (cache_size_ > total_size_) {
      cache_size_ = total_size_;
    }
    is_enough_ = (cache_size_ == total_size_);
    if (!is_enough_) {
      mapper_.resize(total_size_);
      inverse_mapper_.resize(cache_size_);
      last_used_time_.resize(cache_size_);
      ResetMap();
    }
  }

  /*!
   * \brief Reset mapper
   */
  void ResetMap() {
    if (!is_enough_) {
      cur_time_ = 0;
      std::fill(mapper_.begin(), mapper_.end(), -1);
      std::fill(inverse_mapper_.begin(), inverse_mapper_.end(), -1);
      std::fill(last_used_time_.begin(), last_used_time_.end(), 0);
    }
  }
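  /*!
   * \brief Fill the per-feature meta information.
   * USE_DATA refreshes fields derived from the dataset (bin counts, default
   * bins, missing types); USE_CONFIG refreshes fields derived from the
   * config (monotone constraints, penalties, random seeds), so a
   * config-only reset can skip the dataset-dependent part.
   */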
  template <bool USE_DATA, bool USE_CONFIG>
  static void SetFeatureInfo(const Dataset* train_data, const Config* config,
                             std::vector<FeatureMetainfo>* feature_meta) {
    auto& ref_feature_meta = *feature_meta;
    const int num_feature = train_data->num_features();
    ref_feature_meta.resize(num_feature);
#pragma omp parallel for schedule(static, 512) if (num_feature >= 1024)
    for (int i = 0; i < num_feature; ++i) {
      if (USE_DATA) {
        ref_feature_meta[i].num_bin = train_data->FeatureNumBin(i);
        ref_feature_meta[i].default_bin =
            train_data->FeatureBinMapper(i)->GetDefaultBin();
        ref_feature_meta[i].missing_type =
            train_data->FeatureBinMapper(i)->missing_type();
        if (train_data->FeatureBinMapper(i)->GetMostFreqBin() == 0) {
          ref_feature_meta[i].offset = 1;
        } else {
          ref_feature_meta[i].offset = 0;
        }
        ref_feature_meta[i].bin_type =
            train_data->FeatureBinMapper(i)->bin_type();
      }
      if (USE_CONFIG) {
        const int real_fidx = train_data->RealFeatureIndex(i);
        if (!config->monotone_constraints.empty()) {
          ref_feature_meta[i].monotone_type =
              config->monotone_constraints[real_fidx];
        } else {
          ref_feature_meta[i].monotone_type = 0;
        }
        if (!config->feature_contri.empty()) {
          ref_feature_meta[i].penalty = config->feature_contri[real_fidx];
        } else {
          ref_feature_meta[i].penalty = 1.0;
        }
        ref_feature_meta[i].rand = Random(config->extra_seed + i);
      }
      ref_feature_meta[i].config = config;
    }
  }

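  /*!
   * \brief Grow the pool to a new cache/total size and (re)bind histograms.
   * Without quantized gradients each bin occupies two hist_t entries
   * (gradient and hessian sums); with quantized gradients each bin occupies
   * one hist_t, accessed through a packed 16-bit integer view.
   */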
  void DynamicChangeSize(const Dataset* train_data, int num_total_bin,
                        const std::vector<uint32_t>& offsets, const Config* config,
                        int cache_size, int total_size) {
    if (feature_metas_.empty()) {
      SetFeatureInfo<true, true>(train_data, config, &feature_metas_);
      uint64_t bin_cnt_over_features = 0;
      for (int i = 0; i < train_data->num_features(); ++i) {
        bin_cnt_over_features +=
            static_cast<uint64_t>(feature_metas_[i].num_bin);
      }
      Log::Info("Total Bins %d", bin_cnt_over_features);
    }
    int old_cache_size = static_cast<int>(pool_.size());
    Reset(cache_size, total_size);

    if (cache_size > old_cache_size) {
      pool_.resize(cache_size);
      data_.resize(cache_size);
    }

    if (config->use_quantized_grad) {
      OMP_INIT_EX();
      #pragma omp parallel for schedule(static)
      for (int i = old_cache_size; i < cache_size; ++i) {
        OMP_LOOP_EX_BEGIN();
        pool_[i].reset(new FeatureHistogram[train_data->num_features()]);
        data_[i].resize(num_total_bin);
        for (int j = 0; j < train_data->num_features(); ++j) {
          int16_t* data_ptr = reinterpret_cast<int16_t*>(data_[i].data());
          pool_[i][j].Init(data_[i].data() + offsets[j], data_ptr + 2 * offsets[j], &feature_metas_[j]);
        }
        OMP_LOOP_EX_END();
      }
      OMP_THROW_EX();
    } else {
      OMP_INIT_EX();
      #pragma omp parallel for schedule(static)
      for (int i = old_cache_size; i < cache_size; ++i) {
        OMP_LOOP_EX_BEGIN();
        pool_[i].reset(new FeatureHistogram[train_data->num_features()]);
        data_[i].resize(num_total_bin * 2);
        for (int j = 0; j < train_data->num_features(); ++j) {
          pool_[i][j].Init(data_[i].data() + offsets[j] * 2, &feature_metas_[j]);
        }
        OMP_LOOP_EX_END();
      }
      OMP_THROW_EX();
    }
  }

  void ResetConfig(const Dataset* train_data, const Config* config) {
    CHECK_GT(train_data->num_features(), 0);
    const Config* old_config = feature_metas_[0].config;
    SetFeatureInfo<false, true>(train_data, config, &feature_metas_);
    // if need to reset the function pointers
    if (old_config->lambda_l1 != config->lambda_l1 ||
        old_config->monotone_constraints != config->monotone_constraints ||
        old_config->extra_trees != config->extra_trees ||
        old_config->max_delta_step != config->max_delta_step ||
        old_config->path_smooth != config->path_smooth) {
#pragma omp parallel for schedule(static)
      for (int i = 0; i < cache_size_; ++i) {
        for (int j = 0; j < train_data->num_features(); ++j) {
          pool_[i][j].ResetFunc();
        }
      }
    }
  }

  /*!
   * \brief Get data for the specified index
   * \param idx Index of the data to get
   * \param out Output pointer; the requested data will be stored here
   * \return True if this index is in the pool, false otherwise
   */
  bool Get(int idx, FeatureHistogram** out) {
    if (is_enough_) {
      *out = pool_[idx].get();
      return true;
    } else if (mapper_[idx] >= 0) {
      int slot = mapper_[idx];
      *out = pool_[slot].get();
      last_used_time_[slot] = ++cur_time_;
      return true;
    } else {
      // choose the least used slot
      int slot = static_cast<int>(ArrayArgs<int>::ArgMin(last_used_time_));
      *out = pool_[slot].get();
      last_used_time_[slot] = ++cur_time_;

      // reset previous mapper
      if (inverse_mapper_[slot] >= 0) mapper_[inverse_mapper_[slot]] = -1;

      // update current mapper
      mapper_[idx] = slot;
      inverse_mapper_[slot] = idx;
      return false;
    }
  }

  /*!
   * \brief Move data from one index to another
   * \param src_idx Source index
   * \param dst_idx Destination index
   */
  void Move(int src_idx, int dst_idx) {
    if (is_enough_) {
      std::swap(pool_[src_idx], pool_[dst_idx]);
      return;
    }
    if (mapper_[src_idx] < 0) {
      return;
    }
    // get slot of src idx
    int slot = mapper_[src_idx];
    // reset src_idx
    mapper_[src_idx] = -1;

    // move to dst idx
    mapper_[dst_idx] = slot;
    last_used_time_[slot] = ++cur_time_;
    inverse_mapper_[slot] = dst_idx;
  }

 private:
  std::vector<std::unique_ptr<FeatureHistogram[]>> pool_;
  std::vector<
      std::vector<hist_t, Common::AlignmentAllocator<hist_t, kAlignedSize>>>
      data_;
  std::vector<FeatureMetainfo> feature_metas_;
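  /*! \brief LRU bookkeeping: mapper_ maps a leaf index to its cache slot,
   *  inverse_mapper_ is the reverse map, and last_used_time_ with cur_time_
   *  drive least-recently-used eviction */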
  int cache_size_;
  int total_size_;
  bool is_enough_ = false;
  std::vector<int> mapper_;
  std::vector<int> inverse_mapper_;
  std::vector<int> last_used_time_;
  int cur_time_ = 0;
};

}  // namespace LightGBM
#endif  // LIGHTGBM_TREELEARNER_FEATURE_HISTOGRAM_HPP_