// prediction_early_stop.cpp
#include <LightGBM/prediction_early_stop.h>
#include <LightGBM/utils/log.h>

#include <algorithm>
#include <cmath>
#include <limits>
#include <stdexcept>
#include <string>
#include <vector>

9
namespace {
cbecker's avatar
cbecker committed
10

11
using namespace LightGBM;
cbecker's avatar
cbecker committed
12

13
14
15
16
17
18
19
PredictionEarlyStopInstance CreateNone(const PredictionEarlyStopConfig&) {
  return PredictionEarlyStopInstance{
    [](const double*, int) {
    return false;
  },
    std::numeric_limits<int>::max() // make sure the lambda is almost never called
  };
cbecker's avatar
cbecker committed
20
21
}

22
23
24
25
26
27
28
29
PredictionEarlyStopInstance CreateMulticlass(const PredictionEarlyStopConfig& config) {
  // margin_threshold will be captured by value
  const double margin_threshold = config.margin_threshold;

  return PredictionEarlyStopInstance{
    [margin_threshold](const double* pred, int sz) {
    if (sz < 2) {
      Log::Fatal("Multiclass early stopping needs predictions to be of length two or larger");
cbecker's avatar
cbecker committed
30
    }
31
32
33
34
35

    // copy and sort
    std::vector<double> votes(static_cast<size_t>(sz));
    for (int i = 0; i < sz; ++i) {
      votes[i] = pred[i];
cbecker's avatar
cbecker committed
36
    }
37
38
39
40
41
42
    std::partial_sort(votes.begin(), votes.begin() + 2, votes.end(), std::greater<double>());

    const auto margin = votes[0] - votes[1];

    if (margin > margin_threshold) {
      return true;
cbecker's avatar
cbecker committed
43
    }
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63

    return false;
  },
    config.round_period
  };
}

PredictionEarlyStopInstance CreateBinary(const PredictionEarlyStopConfig& config) {
  // margin_threshold will be captured by value
  const double margin_threshold = config.margin_threshold;

  return PredictionEarlyStopInstance{
    [margin_threshold](const double* pred, int sz) {
    if (sz != 1) {
      Log::Fatal("Binary early stopping needs predictions to be of length one");
    }
    const auto margin = 2.0 * fabs(pred[0]);

    if (margin > margin_threshold) {
      return true;
cbecker's avatar
cbecker committed
64
    }
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85

    return false;
  },
    config.round_period
  };
}

}

namespace LightGBM {

/// Creates a prediction early-stopping instance selected by name.
///
/// @param type   Strategy name: "none", "multiclass" or "binary".
/// @param config Threshold and period settings forwarded to the chosen
///               strategy factory.
/// @return The configured early-stop instance.
/// @throws std::runtime_error if `type` is not one of the names above.
PredictionEarlyStopInstance CreatePredictionEarlyStopInstance(const std::string& type,
                                                              const PredictionEarlyStopConfig& config) {
  if (type == "none") {
    return CreateNone(config);
  }
  if (type == "multiclass") {
    return CreateMulticlass(config);
  }
  if (type == "binary") {
    return CreateBinary(config);
  }
  throw std::runtime_error("Unknown early stopping type: " + type);
}

}  // namespace LightGBM