// multiclass_metric.hpp
#ifndef LIGHTGBM_METRIC_MULTICLASS_METRIC_HPP_
#define LIGHTGBM_METRIC_MULTICLASS_METRIC_HPP_

#include <LightGBM/metric.h>
#include <LightGBM/utils/log.h>

#include <cmath>
#include <sstream>
#include <string>
#include <vector>

namespace LightGBM {
/*!
* \brief Metric for multiclass task.
* Use static class "PointWiseLossCalculator" to calculate loss point-wise
*/
template<typename PointWiseLossCalculator>
class MulticlassMetric: public Metric {
public:
  explicit MulticlassMetric(const MetricConfig& config) {
      num_class_ = config.num_class;
  }

  virtual ~MulticlassMetric() {

  }

  void Init(const char* test_name, const Metadata& metadata, data_size_t num_data) override {
    std::stringstream str_buf;
28
29
    str_buf << test_name << " : " << PointWiseLossCalculator::Name();
    name_.emplace_back(str_buf.str());
30
31
32
33
34
35
    num_data_ = num_data;
    // get label
    label_ = metadata.label();
    // get weights
    weights_ = metadata.weights();
    if (weights_ == nullptr) {
36
      sum_weights_ = static_cast<double>(num_data_);
37
38
39
40
41
42
43
44
    } else {
      sum_weights_ = 0.0f;
      for (data_size_t i = 0; i < num_data_; ++i) {
        sum_weights_ += weights_[i];
      }
    }
  }
  
45
46
  std::vector<std::string> GetName() const override {
    return name_;
47
48
  }

49
50
  score_t factor_to_bigger_better() const override {
    return -1.0f;
51
52
  }
  
53
54
  std::vector<double> Eval(const score_t* score) const override {
    double sum_loss = 0.0;
55
56
57
    if (weights_ == nullptr) {
      #pragma omp parallel for schedule(static) reduction(+:sum_loss)
      for (data_size_t i = 0; i < num_data_; ++i) {
58
        std::vector<double> rec(num_class_);
59
        for (int k = 0; k < num_class_; ++k) {
60
          rec[k] = static_cast<double>(score[k * num_data_ + i]);
61
62
63
64
65
66
67
        }
        // add loss
        sum_loss += PointWiseLossCalculator::LossOnPoint(label_[i], rec);
      }
    } else {
      #pragma omp parallel for schedule(static) reduction(+:sum_loss)
      for (data_size_t i = 0; i < num_data_; ++i) {
68
        std::vector<double> rec(num_class_);
69
        for (int k = 0; k < num_class_; ++k) {
70
          rec[k] = static_cast<double>(score[k * num_data_ + i]);
71
72
73
74
75
        }
        // add loss
        sum_loss += PointWiseLossCalculator::LossOnPoint(label_[i], rec) * weights_[i];
      }
    }
76
77
    double loss = sum_loss / sum_weights_;
    return std::vector<double>(1, loss);
78
79
80
81
82
83
84
85
86
87
88
89
90
91
  }

private:
  /*! \brief Output frequency */
  int output_freq_;
  /*! \brief Number of data */
  data_size_t num_data_;
  /*! \brief Number of classes */
  int num_class_;
  /*! \brief Pointer of label */
  const float* label_;
  /*! \brief Pointer of weighs */
  const float* weights_;
  /*! \brief Sum weights */
92
  double sum_weights_;
93
  /*! \brief Name of this test set */
94
  std::vector<std::string> name_;
95
96
97
98
99
100
101
};

/*!
* \brief Classification error for multiclass task.
* A point is counted as an error (loss 1) when some other class scores
* strictly higher than the labelled class; otherwise its loss is 0.
*/
class MultiErrorMetric: public MulticlassMetric<MultiErrorMetric> {
public:
  explicit MultiErrorMetric(const MetricConfig& config) :MulticlassMetric<MultiErrorMetric>(config) {}

  /*!
  * \brief 0/1 misclassification loss for one point.
  * \param label Index of the true class, stored as a float
  * \param score Per-class scores of the point (not modified)
  * \return 1 if the point is misclassified, 0 otherwise. Ties with the
  *         labelled class are not counted as errors.
  */
  inline static score_t LossOnPoint(float label, const std::vector<double>& score) {
    // Compare as float before casting: converting a negative, NaN or huge
    // float to size_t is undefined. A label with no score slot can never be
    // predicted correctly, so it counts as an error.
    if (!(label >= 0.0f) || label >= static_cast<float>(score.size())) {
      return 1.0f;
    }
    const size_t k = static_cast<size_t>(label);
    for (size_t i = 0; i < score.size(); ++i) {
      if (i != k && score[i] > score[k]) {
        return 1.0f;
      }
    }
    return 0.0f;
  }

  inline static const char* Name() {
    return "multi error";
  }
};

/*! \brief Logloss for multiclass task */
class MultiLoglossMetric: public MulticlassMetric<MultiLoglossMetric> {
public:
  explicit MultiLoglossMetric(const MetricConfig& config) :MulticlassMetric<MultiLoglossMetric>(config) {}

  /*!
  * \brief Negative log of the softmax probability assigned to the labelled class.
  * \param label Index of the true class, stored as a float
  * \param score Raw per-class scores; taken by value because Common::Softmax
  *        normalises it in place
  * \return -log(p[label]), with p[label] clamped from below at kEpsilon
  */
  inline static score_t LossOnPoint(float label, std::vector<double> score) {
    // Compare as float before casting: converting a negative, NaN or huge
    // float to size_t is undefined. A label with no score slot gets the
    // maximum (clamped) penalty, as if its probability were zero.
    if (!(label >= 0.0f) || label >= static_cast<float>(score.size())) {
      return static_cast<score_t>(-std::log(kEpsilon));
    }
    const size_t k = static_cast<size_t>(label);
    Common::Softmax(&score);
    if (score[k] > kEpsilon) {
      return static_cast<score_t>(-std::log(score[k]));
    } else {
      return static_cast<score_t>(-std::log(kEpsilon));
    }
  }

  inline static const char* Name() {
    return "multi logloss";
  }
};

}  // namespace LightGBM
#endif   // LIGHTGBM_METRIC_MULTICLASS_METRIC_HPP_