binary_metric.hpp 7.27 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
3
4
#ifndef LIGHTGBM_METRIC_BINARY_METRIC_HPP_
#define LIGHTGBM_METRIC_BINARY_METRIC_HPP_

#include <LightGBM/utils/log.h>
5
#include <LightGBM/utils/common.h>
Guolin Ke's avatar
Guolin Ke committed
6
7
8
9
10

#include <LightGBM/metric.h>

#include <algorithm>
#include <vector>
11
#include <sstream>
Guolin Ke's avatar
Guolin Ke committed
12
13
14
15
16
17
18
19
20
21

namespace LightGBM {

/*!
* \brief Metric for binary classification task.
* Use static class "PointWiseLossCalculator" to calculate loss point-wise
*/
template<typename PointWiseLossCalculator>
class BinaryMetric: public Metric {
public:
22
23
  explicit BinaryMetric(const MetricConfig&) {

Guolin Ke's avatar
Guolin Ke committed
24
25
26
27
28
29
  }

  virtual ~BinaryMetric() {

  }

Guolin Ke's avatar
Guolin Ke committed
30
31
  void Init(const Metadata& metadata, data_size_t num_data) override {
    name_.emplace_back(PointWiseLossCalculator::Name());
32

Guolin Ke's avatar
Guolin Ke committed
33
34
35
36
37
38
39
40
    num_data_ = num_data;
    // get label
    label_ = metadata.label();

    // get weights
    weights_ = metadata.weights();

    if (weights_ == nullptr) {
41
      sum_weights_ = static_cast<double>(num_data_);
Guolin Ke's avatar
Guolin Ke committed
42
43
44
45
46
47
48
49
    } else {
      sum_weights_ = 0.0f;
      for (data_size_t i = 0; i < num_data; ++i) {
        sum_weights_ += weights_[i];
      }
    }
  }

Guolin Ke's avatar
Guolin Ke committed
50
  const std::vector<std::string>& GetName() const override {
51
    return name_;
52
53
  }

54
  double factor_to_bigger_better() const override {
55
    return -1.0f;
56
57
  }

Guolin Ke's avatar
Guolin Ke committed
58
  std::vector<double> Eval(const double* score, const ObjectiveFunction* objective) const override {
59
    double sum_loss = 0.0f;
60
61
62
63
64
65
66
67
68
69
70
71
72
    if (objective == nullptr) {
      if (weights_ == nullptr) {
        #pragma omp parallel for schedule(static) reduction(+:sum_loss)
        for (data_size_t i = 0; i < num_data_; ++i) {
          // add loss
          sum_loss += PointWiseLossCalculator::LossOnPoint(label_[i], score[i]);
        }
      } else {
        #pragma omp parallel for schedule(static) reduction(+:sum_loss)
        for (data_size_t i = 0; i < num_data_; ++i) {
          // add loss
          sum_loss += PointWiseLossCalculator::LossOnPoint(label_[i], score[i]) * weights_[i];
        }
Guolin Ke's avatar
Guolin Ke committed
73
      }
74
    } else {
75
76
77
      if (weights_ == nullptr) {
        #pragma omp parallel for schedule(static) reduction(+:sum_loss)
        for (data_size_t i = 0; i < num_data_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
78
79
          double prob = 0;
          objective->ConvertOutput(&score[i], &prob);
80
81
82
83
84
85
          // add loss
          sum_loss += PointWiseLossCalculator::LossOnPoint(label_[i], prob);
        }
      } else {
        #pragma omp parallel for schedule(static) reduction(+:sum_loss)
        for (data_size_t i = 0; i < num_data_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
86
87
          double prob = 0;
          objective->ConvertOutput(&score[i], &prob);
88
89
90
          // add loss
          sum_loss += PointWiseLossCalculator::LossOnPoint(label_[i], prob) * weights_[i];
        }
wxchan's avatar
wxchan committed
91
      }
Guolin Ke's avatar
Guolin Ke committed
92
    }
93
94
    double loss = sum_loss / sum_weights_;
    return std::vector<double>(1, loss);
Guolin Ke's avatar
Guolin Ke committed
95
96
97
98
99
100
  }

private:
  /*! \brief Number of data */
  data_size_t num_data_;
  /*! \brief Pointer of label */
101
  const label_t* label_;
Guolin Ke's avatar
Guolin Ke committed
102
  /*! \brief Pointer of weighs */
103
  const label_t* weights_;
Guolin Ke's avatar
Guolin Ke committed
104
  /*! \brief Sum weights */
105
  double sum_weights_;
Guolin Ke's avatar
Guolin Ke committed
106
  /*! \brief Name of test set */
107
  std::vector<std::string> name_;
Guolin Ke's avatar
Guolin Ke committed
108
109
110
111
112
113
114
115
116
};

/*!
* \brief Log loss metric for binary classification task.
*
* Per-point loss is -log(p), where p is the probability assigned to the
* observed class: 1 - prob when label <= 0, prob otherwise. When p is not
* greater than kEpsilon, -log(kEpsilon) is returned instead so the loss
* stays finite.
*/
class BinaryLoglossMetric: public BinaryMetric<BinaryLoglossMetric> {
public:
  explicit BinaryLoglossMetric(const MetricConfig& config)
    : BinaryMetric<BinaryLoglossMetric>(config) {}

  inline static double LossOnPoint(label_t label, double prob) {
    // probability the model gave to the label that was actually observed
    const double prob_of_truth = (label <= 0) ? (1.0f - prob) : prob;
    if (prob_of_truth > kEpsilon) {
      return -std::log(prob_of_truth);
    }
    // p too small (or NaN): clamp rather than return +inf
    return -std::log(kEpsilon);
  }

  inline static const char* Name() {
    return "binary_logloss";
  }
};
/*!
* \brief Error rate metric for binary classification task.
*
* The prediction is treated as negative when prob <= 0.5 and as positive
* otherwise. The per-point loss is 1 for a negative prediction with
* label > 0, or a positive prediction with label <= 0; otherwise it is 0.
*/
class BinaryErrorMetric: public BinaryMetric<BinaryErrorMetric> {
public:
  explicit BinaryErrorMetric(const MetricConfig& config)
    : BinaryMetric<BinaryErrorMetric>(config) {}

  inline static double LossOnPoint(label_t label, double prob) {
    const bool predicted_negative = prob <= 0.5f;
    const bool misclassified = predicted_negative ? (label > 0) : (label <= 0);
    return misclassified ? 1.0 : 0.0;
  }

  inline static const char* Name() {
    return "binary_error";
  }
};

/*!
* \brief Auc Metric for binary classification task.
*/
class AUCMetric: public Metric {
public:
159
160
  explicit AUCMetric(const MetricConfig&) {

Guolin Ke's avatar
Guolin Ke committed
161
162
163
164
165
  }

  virtual ~AUCMetric() {
  }

Guolin Ke's avatar
Guolin Ke committed
166
  const std::vector<std::string>& GetName() const override {
167
    return name_;
168
169
  }

170
  double factor_to_bigger_better() const override {
171
    return 1.0f;
172
173
  }

Guolin Ke's avatar
Guolin Ke committed
174
  void Init(const Metadata& metadata, data_size_t num_data) override {
175
    name_.emplace_back("auc");
176

Guolin Ke's avatar
Guolin Ke committed
177
178
179
180
181
182
183
    num_data_ = num_data;
    // get label
    label_ = metadata.label();
    // get weights
    weights_ = metadata.weights();

    if (weights_ == nullptr) {
184
      sum_weights_ = static_cast<double>(num_data_);
Guolin Ke's avatar
Guolin Ke committed
185
186
187
188
189
190
191
192
    } else {
      sum_weights_ = 0.0f;
      for (data_size_t i = 0; i < num_data; ++i) {
        sum_weights_ += weights_[i];
      }
    }
  }

Guolin Ke's avatar
Guolin Ke committed
193
  std::vector<double> Eval(const double* score, const ObjectiveFunction*) const override {
194
195
196
197
198
    // get indices sorted by score, descent order
    std::vector<data_size_t> sorted_idx;
    for (data_size_t i = 0; i < num_data_; ++i) {
      sorted_idx.emplace_back(i);
    }
199
    Common::ParallelSort(sorted_idx.begin(), sorted_idx.end(), [score](data_size_t a, data_size_t b) {return score[a] > score[b]; });
200
    // temp sum of postive label
201
    double cur_pos = 0.0f;
202
    // total sum of postive label
203
    double sum_pos = 0.0f;
204
    // accumlate of auc
205
    double accum = 0.0f;
206
    // temp sum of negative label
207
    double cur_neg = 0.0f;
208
    double threshold = score[sorted_idx[0]];
209
    if (weights_ == nullptr) {  // no weights
Guolin Ke's avatar
Guolin Ke committed
210
      for (data_size_t i = 0; i < num_data_; ++i) {
211
        const label_t cur_label = label_[sorted_idx[i]];
212
        const double cur_score = score[sorted_idx[i]];
213
214
215
216
217
218
219
220
        // new threshold
        if (cur_score != threshold) {
          threshold = cur_score;
          // accmulate
          accum += cur_neg*(cur_pos * 0.5f + sum_pos);
          sum_pos += cur_pos;
          // reset
          cur_neg = cur_pos = 0.0f;
Guolin Ke's avatar
Guolin Ke committed
221
        }
Guolin Ke's avatar
Guolin Ke committed
222
223
        cur_neg += (cur_label <= 0);
        cur_pos += (cur_label > 0);
Guolin Ke's avatar
Guolin Ke committed
224
      }
225
226
    } else {  // has weights
      for (data_size_t i = 0; i < num_data_; ++i) {
227
        const label_t cur_label = label_[sorted_idx[i]];
228
        const double cur_score = score[sorted_idx[i]];
229
        const label_t cur_weight = weights_[sorted_idx[i]];
230
231
232
233
234
235
236
237
238
        // new threshold
        if (cur_score != threshold) {
          threshold = cur_score;
          // accmulate
          accum += cur_neg*(cur_pos * 0.5f + sum_pos);
          sum_pos += cur_pos;
          // reset
          cur_neg = cur_pos = 0.0f;
        }
Guolin Ke's avatar
Guolin Ke committed
239
240
        cur_neg += (cur_label <= 0)*cur_weight;
        cur_pos += (cur_label > 0)*cur_weight;
wxchan's avatar
wxchan committed
241
      }
Guolin Ke's avatar
Guolin Ke committed
242
    }
243
244
    accum += cur_neg*(cur_pos * 0.5f + sum_pos);
    sum_pos += cur_pos;
245
    double auc = 1.0f;
246
247
248
    if (sum_pos > 0.0f && sum_pos != sum_weights_) {
      auc = accum / (sum_pos *(sum_weights_ - sum_pos));
    }
249
    return std::vector<double>(1, auc);
Guolin Ke's avatar
Guolin Ke committed
250
251
252
253
254
255
  }

private:
  /*! \brief Number of data */
  data_size_t num_data_;
  /*! \brief Pointer of label */
256
  const label_t* label_;
Guolin Ke's avatar
Guolin Ke committed
257
  /*! \brief Pointer of weighs */
258
  const label_t* weights_;
Guolin Ke's avatar
Guolin Ke committed
259
  /*! \brief Sum weights */
260
  double sum_weights_;
Guolin Ke's avatar
Guolin Ke committed
261
  /*! \brief Name of test set */
262
  std::vector<std::string> name_;
Guolin Ke's avatar
Guolin Ke committed
263
264
265
};

}  // namespace LightGBM
Guolin Ke's avatar
Guolin Ke committed
266
#endif   // LIGHTGBM_METRIC_BINARY_METRIC_HPP_