prediction_early_stop.cpp 2.46 KB
Newer Older
1
2
3
4
/*!
 * Copyright (c) 2017 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
cbecker's avatar
cbecker committed
5
#include <LightGBM/prediction_early_stop.h>

#include <LightGBM/utils/log.h>

#include <algorithm>
#include <cmath>
#include <limits>
#include <stdexcept>
#include <string>
#include <vector>
cbecker's avatar
cbecker committed
13

14
namespace {
cbecker's avatar
cbecker committed
15

16
using namespace LightGBM;
cbecker's avatar
cbecker committed
17

18
19
20
21
22
PredictionEarlyStopInstance CreateNone(const PredictionEarlyStopConfig&) {
  return PredictionEarlyStopInstance{
    [](const double*, int) {
    return false;
  },
23
    std::numeric_limits<int>::max()  // make sure the lambda is almost never called
24
  };
cbecker's avatar
cbecker committed
25
26
}

27
28
29
30
31
32
33
34
PredictionEarlyStopInstance CreateMulticlass(const PredictionEarlyStopConfig& config) {
  // margin_threshold will be captured by value
  const double margin_threshold = config.margin_threshold;

  return PredictionEarlyStopInstance{
    [margin_threshold](const double* pred, int sz) {
    if (sz < 2) {
      Log::Fatal("Multiclass early stopping needs predictions to be of length two or larger");
cbecker's avatar
cbecker committed
35
    }
36
37
38
39
40

    // copy and sort
    std::vector<double> votes(static_cast<size_t>(sz));
    for (int i = 0; i < sz; ++i) {
      votes[i] = pred[i];
cbecker's avatar
cbecker committed
41
    }
42
43
44
45
46
47
    std::partial_sort(votes.begin(), votes.begin() + 2, votes.end(), std::greater<double>());

    const auto margin = votes[0] - votes[1];

    if (margin > margin_threshold) {
      return true;
cbecker's avatar
cbecker committed
48
    }
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68

    return false;
  },
    config.round_period
  };
}

PredictionEarlyStopInstance CreateBinary(const PredictionEarlyStopConfig& config) {
  // margin_threshold will be captured by value
  const double margin_threshold = config.margin_threshold;

  return PredictionEarlyStopInstance{
    [margin_threshold](const double* pred, int sz) {
    if (sz != 1) {
      Log::Fatal("Binary early stopping needs predictions to be of length one");
    }
    const auto margin = 2.0 * fabs(pred[0]);

    if (margin > margin_threshold) {
      return true;
cbecker's avatar
cbecker committed
69
    }
70
71
72
73
74
75
76

    return false;
  },
    config.round_period
  };
}

77
}  // namespace
78
79
80
81
82
83
84
85
86
87
88
89
90

namespace LightGBM {

/*!
 * \brief Select a prediction early-stopping strategy by name.
 * \param type One of "none", "multiclass" or "binary".
 * \param config Threshold and period forwarded to the chosen factory.
 * \return The instance produced by the matching factory.
 * \throws std::runtime_error if `type` names no known strategy.
 */
PredictionEarlyStopInstance CreatePredictionEarlyStopInstance(const std::string& type,
                                                              const PredictionEarlyStopConfig& config) {
  if (type == "binary") {
    return CreateBinary(config);
  }
  if (type == "multiclass") {
    return CreateMulticlass(config);
  }
  if (type == "none") {
    return CreateNone(config);
  }
  // No factory matched: report the offending name to the caller.
  throw std::runtime_error("Unknown early stopping type: " + type);
}

}  // namespace LightGBM