// File: regression_metric.hpp (committed by Guolin Ke)
#ifndef LIGHTGBM_METRIC_REGRESSION_METRIC_HPP_
#define LIGHTGBM_METRIC_REGRESSION_METRIC_HPP_

#include <LightGBM/utils/log.h>

#include <LightGBM/metric.h>

#include <cmath>
#include <string>
#include <vector>

namespace LightGBM {
/*!
* \brief Metric for regression task.
* Use static class "PointWiseLossCalculator" to calculate loss point-wise
*/
template<typename PointWiseLossCalculator>
class RegressionMetric: public Metric {
public:
18
19
  explicit RegressionMetric(const MetricConfig&) {

Guolin Ke's avatar
Guolin Ke committed
20
21
22
23
24
25
  }

  virtual ~RegressionMetric() {

  }

26
27
28
29
30
31
32
33
  const char* GetName() const override {
    return name_.c_str();
  }

  bool is_bigger_better() const override {
    return false;
  }

Guolin Ke's avatar
Guolin Ke committed
34
  void Init(const char* test_name, const Metadata& metadata, data_size_t num_data) override {
35
36
37
38
    std::stringstream str_buf;
    str_buf << test_name << "'s " << PointWiseLossCalculator::Name();
    name_ = str_buf.str();

Guolin Ke's avatar
Guolin Ke committed
39
40
41
42
43
44
    num_data_ = num_data;
    // get label
    label_ = metadata.label();
    // get weights
    weights_ = metadata.weights();
    if (weights_ == nullptr) {
45
      sum_weights_ = static_cast<float>(num_data_);
Guolin Ke's avatar
Guolin Ke committed
46
47
48
49
50
51
52
    } else {
      sum_weights_ = 0.0f;
      for (data_size_t i = 0; i < num_data_; ++i) {
        sum_weights_ += weights_[i];
      }
    }
  }
53
54
55
56
57
58
59
60

  std::vector<float> Eval(const score_t* score) const override {
    score_t sum_loss = 0.0f;
    if (weights_ == nullptr) {
#pragma omp parallel for schedule(static) reduction(+:sum_loss)
      for (data_size_t i = 0; i < num_data_; ++i) {
        // add loss
        sum_loss += PointWiseLossCalculator::LossOnPoint(label_[i], score[i]);
Guolin Ke's avatar
Guolin Ke committed
61
      }
62
63
64
65
66
    } else {
#pragma omp parallel for schedule(static) reduction(+:sum_loss)
      for (data_size_t i = 0; i < num_data_; ++i) {
        // add loss
        sum_loss += PointWiseLossCalculator::LossOnPoint(label_[i], score[i]) * weights_[i];
wxchan's avatar
wxchan committed
67
      }
Guolin Ke's avatar
Guolin Ke committed
68
    }
69
70
71
    score_t loss = PointWiseLossCalculator::AverageLoss(sum_loss, sum_weights_);
    return std::vector<float>(1, loss);

Guolin Ke's avatar
Guolin Ke committed
72
73
74
75
76
77
78
79
80
81
82
83
84
85
  }

  inline static score_t AverageLoss(score_t sum_loss, score_t sum_weights) {
    return sum_loss / sum_weights;
  }

private:
  /*! \brief Number of data */
  data_size_t num_data_;
  /*! \brief Pointer of label */
  const float* label_;
  /*! \brief Pointer of weighs */
  const float* weights_;
  /*! \brief Sum weights */
86
  float sum_weights_;
Guolin Ke's avatar
Guolin Ke committed
87
  /*! \brief Name of this test set */
88
  std::string name_;
Guolin Ke's avatar
Guolin Ke committed
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
};

/*!
* \brief L2 loss for regression task.
* Point-wise loss is the squared residual; the reported value is the square root
* of the weighted mean of those squares.
*/
class L2Metric: public RegressionMetric<L2Metric> {
public:
  explicit L2Metric(const MetricConfig& config) : RegressionMetric<L2Metric>(config) {}

  /*! \brief Squared residual of one data point */
  inline static score_t LossOnPoint(float label, score_t score) {
    const score_t residual = score - label;
    return residual * residual;
  }

  /*! \brief Shadows the base aggregation: take the root of the weighted mean */
  inline static score_t AverageLoss(score_t sum_loss, score_t sum_weights) {
    const score_t mean_of_squares = sum_loss / sum_weights;
    return std::sqrt(mean_of_squares);
  }

  /*! \brief Loss name appended to the test-set name by RegressionMetric::Init */
  inline static const char* Name() {
    return "l2 loss";
  }
};

/*!
* \brief L1 loss for regression task.
* Point-wise loss is the absolute residual; it is aggregated by the inherited
* RegressionMetric::AverageLoss (weighted mean).
*/
class L1Metric: public RegressionMetric<L1Metric> {
public:
  explicit L1Metric(const MetricConfig& config) : RegressionMetric<L1Metric>(config) {}

  /*! \brief Absolute residual of one data point */
  inline static score_t LossOnPoint(float label, score_t score) {
    const score_t residual = score - label;
    return std::fabs(residual);
  }

  /*! \brief Loss name appended to the test-set name by RegressionMetric::Init */
  inline static const char* Name() {
    return "l1 loss";
  }
};

}  // namespace LightGBM
#endif   // LIGHTGBM_METRIC_REGRESSION_METRIC_HPP_