/*!
 * Copyright (c) 2017 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
#include <limits>
#include <algorithm>
#include <cmath>
#include <vector>
#include <LightGBM/prediction_early_stop.h>

#include <LightGBM/utils/log.h>

namespace LightGBM {
// Builds an instance whose callback never requests early stopping.
// The configuration is ignored. The round period is set to the largest int so
// that the (always-false) callback is almost never invoked at all.
PredictionEarlyStopInstance CreateNone(const PredictionEarlyStopConfig&) {
  const auto never_stop = [](const double*, int) { return false; };
  const int round_period = std::numeric_limits<int>::max();
  return PredictionEarlyStopInstance{never_stop, round_period};
}

// Builds a multiclass early-stopping instance.
//
// Every `config.round_period` iterations the callback receives the current
// per-class scores (`pred`, `sz` entries). It requests early stopping
// (returns true) when the largest score exceeds the second-largest score by
// more than `config.margin_threshold`.
//
// The callback calls Log::Fatal when fewer than two scores are supplied.
PredictionEarlyStopInstance CreateMulticlass(const PredictionEarlyStopConfig& config) {
  // Capture the threshold by value so the callback does not refer back to `config`.
  const double margin_threshold = config.margin_threshold;

  return PredictionEarlyStopInstance{
      [margin_threshold](const double* pred, int sz) {
        if (sz < 2) {
          Log::Fatal("Multiclass early stopping needs predictions to be of length two or larger");
        }

        // Find the two largest scores in a single pass. This callback runs
        // during prediction, so avoid a per-call heap allocation and partial
        // sort. Plain comparisons also stay well-defined if a score is NaN,
        // unlike a sort comparator.
        double top1 = std::max(pred[0], pred[1]);
        double top2 = std::min(pred[0], pred[1]);
        for (int i = 2; i < sz; ++i) {
          const double score = pred[i];
          if (score > top1) {
            top2 = top1;
            top1 = score;
          } else if (score > top2) {
            top2 = score;
          }
        }

        const double margin = top1 - top2;
        return margin > margin_threshold;
      },
      config.round_period};
}

// Builds a binary early-stopping instance.
//
// Every `config.round_period` iterations the callback receives a single raw
// score s (`sz` must be 1). It requests early stopping (returns true) when
// 2 * |s| exceeds `config.margin_threshold`. The factor of two presumably
// models the gap between the two class scores s and -s.
//
// The callback calls Log::Fatal when `sz` is not exactly 1.
PredictionEarlyStopInstance CreateBinary(const PredictionEarlyStopConfig& config) {
  // Capture the threshold by value so the callback does not refer back to `config`.
  const double margin_threshold = config.margin_threshold;

  return PredictionEarlyStopInstance{
      [margin_threshold](const double* pred, int sz) {
        if (sz != 1) {
          Log::Fatal("Binary early stopping needs predictions to be of length one");
        }
        // Use std::fabs: <cmath> only guarantees the std-qualified name, so the
        // unqualified global ::fabs is not guaranteed to be declared.
        const double margin = 2.0 * std::fabs(pred[0]);
        return margin > margin_threshold;
      },
      config.round_period};
}

// Factory that selects an early-stopping strategy by name.
//
// Supported `type` values are "none", "multiclass" and "binary". Each is
// forwarded to the matching Create* builder together with `config`. Any other
// value is reported through Log::Fatal.
PredictionEarlyStopInstance CreatePredictionEarlyStopInstance(const std::string& type,
                                                              const PredictionEarlyStopConfig& config) {
  if (type == "none") {
    return CreateNone(config);
  }
  if (type == "multiclass") {
    return CreateMulticlass(config);
  }
  if (type == "binary") {
    return CreateBinary(config);
  }

  Log::Fatal("Unknown early stopping type: %s", type.c_str());

  // Log::Fatal is expected not to return (NOTE(review): confirm in
  // utils/log.h). This fallback only silences "control reaches end of
  // non-void function" warnings.
  return CreateNone(config);
}
}  // namespace LightGBM