multiclass_metric.hpp 3.79 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
#ifndef LIGHTGBM_METRIC_MULTICLASS_METRIC_HPP_
#define LIGHTGBM_METRIC_MULTICLASS_METRIC_HPP_

#include <LightGBM/utils/log.h>

#include <LightGBM/metric.h>

#include <cmath>
#include <string>
#include <vector>

namespace LightGBM {
/*!
* \brief Metric for multiclass task.
* Use static class "PointWiseLossCalculator" to calculate loss point-wise
*/
template<typename PointWiseLossCalculator>
class MulticlassMetric: public Metric {
public:
  explicit MulticlassMetric(const MetricConfig& config) {
      num_class_ = config.num_class;
  }

  virtual ~MulticlassMetric() {

  }

  void Init(const char* test_name, const Metadata& metadata, data_size_t num_data) override {
    std::stringstream str_buf;
    str_buf << test_name << "'s " << PointWiseLossCalculator::Name();
    name_ = str_buf.str();
    num_data_ = num_data;
    // get label
    label_ = metadata.label();
    // get weights
    weights_ = metadata.weights();
    if (weights_ == nullptr) {
      sum_weights_ = static_cast<float>(num_data_);
    } else {
      sum_weights_ = 0.0f;
      for (data_size_t i = 0; i < num_data_; ++i) {
        sum_weights_ += weights_[i];
      }
    }
  }
  
  const char* GetName() const override {
    return name_.c_str();
  }

  bool is_bigger_better() const override {
    return false;
  }
  
  std::vector<score_t> Eval(const score_t* score) const override {
    score_t sum_loss = 0.0;
    if (weights_ == nullptr) {
      #pragma omp parallel for schedule(static) reduction(+:sum_loss)
      for (data_size_t i = 0; i < num_data_; ++i) {
        std::vector<score_t> rec(num_class_);
        for (int k = 0; k < num_class_; ++k) {
          rec[k] = score[k * num_data_ + i];
        }
        // add loss
        sum_loss += PointWiseLossCalculator::LossOnPoint(label_[i], rec);
      }
    } else {
      #pragma omp parallel for schedule(static) reduction(+:sum_loss)
      for (data_size_t i = 0; i < num_data_; ++i) {
        std::vector<score_t> rec(num_class_);
        for (int k = 0; k < num_class_; ++k) {
          rec[k] = score[k * num_data_ + i];
        }
        // add loss
        sum_loss += PointWiseLossCalculator::LossOnPoint(label_[i], rec) * weights_[i];
      }
    }
    score_t loss = sum_loss / sum_weights_;
    return std::vector<score_t>(1, loss);
  }

private:
  /*! \brief Output frequency */
  int output_freq_;
  /*! \brief Number of data */
  data_size_t num_data_;
  /*! \brief Number of classes */
  int num_class_;
  /*! \brief Pointer of label */
  const float* label_;
  /*! \brief Pointer of weighs */
  const float* weights_;
  /*! \brief Sum weights */
  float sum_weights_;
  /*! \brief Name of this test set */
  std::string name_;
};

/*!
* \brief Classification error rate for multiclass task.
* A point counts as an error (loss 1) when some other class scores strictly
* higher than the true class, and as correct (loss 0) otherwise, so the mean
* over the data is the fraction of misclassified points.
*/
class MultiErrorMetric: public MulticlassMetric<MultiErrorMetric> {
public:
  explicit MultiErrorMetric(const MetricConfig& config) :MulticlassMetric<MultiErrorMetric>(config) {}

  /*!
  * \brief Error of a single point.
  * \param label True class index (stored as float).
  * \param score Raw scores of this point, one per class.
  * \return 1 if the point is misclassified, 0 if the true class has the
  *         highest score (ties with the true class count as correct).
  */
  inline static score_t LossOnPoint(float label, const std::vector<score_t>& score) {
    // A label outside [0, num_class) can never be predicted correctly; count
    // it as an error rather than reading past the end of score. Comparing as
    // float also avoids UB when casting a negative/huge label to size_t and
    // rejects NaN.
    if (!(label >= 0.0f && label < static_cast<float>(score.size()))) {
      return 1.0f;
    }
    const size_t k = static_cast<size_t>(label);
    for (size_t i = 0; i < score.size(); ++i) {
      if (i != k && score[i] > score[k]) {
        return 1.0f;
      }
    }
    return 0.0f;
  }

  inline static const char* Name() {
    return "multi error";
  }
};

/*! \brief Logloss for multiclass task */
class MultiLoglossMetric: public MulticlassMetric<MultiLoglossMetric> {
public:
  explicit MultiLoglossMetric(const MetricConfig& config) :MulticlassMetric<MultiLoglossMetric>(config) {}

  /*!
  * \brief Negative log of the softmax probability of the true class.
  * \param label True class index (stored as float).
  * \param score Raw per-class scores of this point; taken by value because
  *        Common::Softmax converts it to probabilities in place.
  * \return -log(p), where p is the true-class probability clamped below at kEpsilon.
  */
  inline static score_t LossOnPoint(float label, std::vector<score_t> score) {
    // A label outside [0, num_class) gets no probability mass; give it the
    // clamped (maximal) loss instead of indexing past the end of score.
    // Comparing as float also avoids UB when casting a negative/huge label
    // to size_t and rejects NaN.
    if (!(label >= 0.0f && label < static_cast<float>(score.size()))) {
      return -std::log(kEpsilon);
    }
    const size_t k = static_cast<size_t>(label);
    Common::Softmax(&score);
    if (score[k] > kEpsilon) {
      return -std::log(score[k]);
    }
    // Clamp to avoid log(0) = -inf.
    return -std::log(kEpsilon);
  }

  inline static const char* Name() {
    return "multi logloss";
  }
};

}  // namespace LightGBM
#endif   // LightGBM_METRIC_MULTICLASS_METRIC_HPP_