// regression_metric.hpp
#ifndef LIGHTGBM_METRIC_REGRESSION_METRIC_HPP_
#define LIGHTGBM_METRIC_REGRESSION_METRIC_HPP_

#include <LightGBM/utils/log.h>

#include <LightGBM/metric.h>

#include <cmath>

namespace LightGBM {
/*!
* \brief Metric for regression task.
* Use static class "PointWiseLossCalculator" to calculate loss point-wise
*/
template<typename PointWiseLossCalculator>
class RegressionMetric: public Metric {
public:
  explicit RegressionMetric(const MetricConfig& config) {
wxchan's avatar
wxchan committed
19
    early_stopping_round_ = config.early_stopping_round;
Guolin Ke's avatar
Guolin Ke committed
20
    output_freq_ = config.output_freq;
wxchan's avatar
wxchan committed
21
    the_bigger_the_better = false;
Guolin Ke's avatar
Guolin Ke committed
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
  }

  virtual ~RegressionMetric() {

  }

  void Init(const char* test_name, const Metadata& metadata, data_size_t num_data) override {
    name = test_name;
    num_data_ = num_data;
    // get label
    label_ = metadata.label();
    // get weights
    weights_ = metadata.weights();
    if (weights_ == nullptr) {
      sum_weights_ = static_cast<double>(num_data_);
    } else {
      sum_weights_ = 0.0f;
      for (data_size_t i = 0; i < num_data_; ++i) {
        sum_weights_ += weights_[i];
      }
    }
  }
wxchan's avatar
wxchan committed
44
  
wxchan's avatar
wxchan committed
45
  score_t PrintAndGetLoss(int iter, const score_t* score) const override {
wxchan's avatar
wxchan committed
46
    if (early_stopping_round_ > 0 || (output_freq_ > 0 && iter % output_freq_ == 0)) {
Guolin Ke's avatar
Guolin Ke committed
47
48
49
50
51
52
53
54
55
56
57
58
59
60
      score_t sum_loss = 0.0;
      if (weights_ == nullptr) {
        #pragma omp parallel for schedule(static) reduction(+:sum_loss)
        for (data_size_t i = 0; i < num_data_; ++i) {
          // add loss
          sum_loss += PointWiseLossCalculator::LossOnPoint(label_[i], score[i]);
        }
      } else {
        #pragma omp parallel for schedule(static) reduction(+:sum_loss)
        for (data_size_t i = 0; i < num_data_; ++i) {
          // add loss
          sum_loss += PointWiseLossCalculator::LossOnPoint(label_[i], score[i]) * weights_[i];
        }
      }
wxchan's avatar
wxchan committed
61
      score_t loss = PointWiseLossCalculator::AverageLoss(sum_loss, sum_weights_);
wxchan's avatar
wxchan committed
62
      if (output_freq_ > 0 && iter % output_freq_ == 0){
Qiwei Ye's avatar
Qiwei Ye committed
63
        Log::Info("Iteration:%d, %s's %s : %f", iter, name, PointWiseLossCalculator::Name(), loss);
wxchan's avatar
wxchan committed
64
      }
wxchan's avatar
wxchan committed
65
      return loss;
Guolin Ke's avatar
Guolin Ke committed
66
    }
wxchan's avatar
wxchan committed
67
    return 0.0f;
Guolin Ke's avatar
Guolin Ke committed
68
69
70
71
72
73
74
  }

  inline static score_t AverageLoss(score_t sum_loss, score_t sum_weights) {
    return sum_loss / sum_weights;
  }

private:
Hui Xue's avatar
Hui Xue committed
75
  /*! \brief Output frequency */
Guolin Ke's avatar
Guolin Ke committed
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
  int output_freq_;
  /*! \brief Number of data */
  data_size_t num_data_;
  /*! \brief Pointer of label */
  const float* label_;
  /*! \brief Pointer of weighs */
  const float* weights_;
  /*! \brief Sum weights */
  double sum_weights_;
  /*! \brief Name of this test set */
  const char* name;
};

/*! \brief L2 loss for regression task */
class L2Metric: public RegressionMetric<L2Metric> {
public:
  explicit L2Metric(const MetricConfig& config) :RegressionMetric<L2Metric>(config) {}

  /*! \brief Squared error for a single data point */
  inline static score_t LossOnPoint(float label, score_t score) {
    const score_t delta = score - label;
    return delta * delta;
  }

  /*! \brief Average the summed squared error, then take the square root */
  inline static score_t AverageLoss(score_t sum_loss, score_t sum_weights) {
    // need sqrt the result for L2 loss
    return std::sqrt(sum_loss / sum_weights);
  }

  /*! \brief Display name of this metric */
  inline static const char* Name() {
    return "l2 loss";
  }
};

/*! \brief L1 loss for regression task */
class L1Metric: public RegressionMetric<L1Metric> {
public:
  explicit L1Metric(const MetricConfig& config) :RegressionMetric<L1Metric>(config) {}

  /*! \brief Absolute error for a single data point */
  inline static score_t LossOnPoint(float label, score_t score) {
    return std::fabs(score - label);
  }

  /*! \brief Display name of this metric */
  inline static const char* Name() {
    return "l1 loss";
  }
};

}  // namespace LightGBM
#endif   // LIGHTGBM_METRIC_REGRESSION_METRIC_HPP_