/*!
 * Copyright (c) 2017 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
#include <LightGBM/prediction_early_stop.h>

#include <LightGBM/utils/log.h>

#include <algorithm>
#include <cmath>
#include <functional>
#include <limits>
#include <stdexcept>
#include <string>
#include <vector>
namespace LightGBM {
/*!
 * \brief Build an early-stop instance that never stops prediction early.
 * \param (unnamed) Early-stop configuration; ignored by this variant.
 * \return Instance whose callback always answers "do not stop" and whose
 *         round period is the largest representable int.
 */
PredictionEarlyStopInstance CreateNone(const PredictionEarlyStopConfig&) {
  // Predicate that unconditionally declines to stop, whatever it is given.
  auto never_stop = [](const double*, int) { return false; };
  // INT_MAX as the period makes sure the predicate is almost never consulted.
  const int period = std::numeric_limits<int>::max();
  return PredictionEarlyStopInstance{never_stop, period};
}

/*!
 * \brief Build an early-stop instance for multiclass models.
 *
 * The callback receives one prediction per class and requests a stop once
 * the gap between the largest and the second-largest prediction exceeds
 * `config.margin_threshold`.
 * \param config Supplies `margin_threshold` (captured by value) and
 *        `round_period`, which is forwarded unchanged to the instance.
 * \return Instance whose callback calls Log::Fatal when given fewer than two
 *         predictions.
 */
PredictionEarlyStopInstance CreateMulticlass(const PredictionEarlyStopConfig& config) {
  // margin_threshold will be captured by value
  const double margin_threshold = config.margin_threshold;

  return PredictionEarlyStopInstance{
    [margin_threshold](const double* pred, int sz) {
      if (sz < 2) {
        Log::Fatal("Multiclass early stopping needs predictions to be of length two or larger");
      }

      // Find the two largest predictions in one pass instead of copying them
      // into a freshly allocated vector and partially sorting it. This
      // callback is presumably invoked repeatedly (every round_period rounds),
      // so a per-call heap allocation is avoidable overhead. Duplicate
      // maxima are kept, so they yield a margin of 0 exactly like the
      // sorted form did.
      double top = std::max(pred[0], pred[1]);
      double runner_up = std::min(pred[0], pred[1]);
      for (int i = 2; i < sz; ++i) {
        const double score = pred[i];
        if (score > top) {
          runner_up = top;
          top = score;
        } else if (score > runner_up) {
          runner_up = score;
        }
      }

      const double margin = top - runner_up;
      return margin > margin_threshold;
    },
    config.round_period
  };
}

/*!
 * \brief Build an early-stop instance for binary models.
 *
 * The callback expects exactly one prediction and requests a stop once twice
 * its absolute value exceeds `config.margin_threshold`.
 * \param config Supplies `margin_threshold` (captured by value) and
 *        `round_period`, which is forwarded unchanged to the instance.
 * \return Instance whose callback calls Log::Fatal unless given exactly one
 *         prediction.
 */
PredictionEarlyStopInstance CreateBinary(const PredictionEarlyStopConfig& config) {
  // margin_threshold will be captured by value
  const double margin_threshold = config.margin_threshold;

  return PredictionEarlyStopInstance{
    [margin_threshold](const double* pred, int sz) {
      if (sz != 1) {
        Log::Fatal("Binary early stopping needs predictions to be of length one");
      }
      // NOTE(review): the factor 2 presumably treats the single score s as
      // +s vs. -s for the two classes, making it comparable to the multiclass
      // largest-minus-second-largest margin -- confirm against the header.
      // std::fabs rather than ::fabs: <cmath> only guarantees the std:: name.
      const double margin = 2.0 * std::fabs(pred[0]);
      return margin > margin_threshold;
    },
    config.round_period
  };
}

/*!
 * \brief Create the early-stop instance selected by `type`.
 * \param type One of "none", "multiclass" or "binary".
 * \param config Forwarded unchanged to the selected factory.
 * \return The instance built by CreateBinary, CreateMulticlass or CreateNone.
 * \throws std::runtime_error if `type` is not one of the recognized names.
 */
PredictionEarlyStopInstance CreatePredictionEarlyStopInstance(const std::string& type,
                                                              const PredictionEarlyStopConfig& config) {
  // The names are mutually exclusive, so the order of these checks is free.
  if (type == "binary") {
    return CreateBinary(config);
  }
  if (type == "multiclass") {
    return CreateMulticlass(config);
  }
  if (type == "none") {
    return CreateNone(config);
  }
  throw std::runtime_error("Unknown early stopping type: " + type);
}
}  // namespace LightGBM