// metric.h
#ifndef LIGHTGBM_METRIC_H_
#define LIGHTGBM_METRIC_H_

#include <LightGBM/meta.h>
#include <LightGBM/config.h>
#include <LightGBM/dataset.h>

#include <string>
#include <vector>

namespace LightGBM {

/*!
* \brief The interface of metric.
*        Metric is used to calculate and output metric result on training / validation data.
*/
class Metric {
public:
  /*! \brief virtual destructor */
  virtual ~Metric() {}

  /*!
  * \brief Initialize
  * \param test_name Specific name for this metric, will output on log
  * \param metadata Label data
  * \param num_data Number of data
  */
  virtual void Init(const char* test_name,
    const Metadata& metadata, data_size_t num_data) = 0;

  /*!
Qiwei Ye's avatar
Qiwei Ye committed
31
  * \brief Calcaluting and printing metric result
Guolin Ke's avatar
Guolin Ke committed
32
33
34
  * \param iter Current iteration
  * \param score Current prediction score
  */
wxchan's avatar
wxchan committed
35
  virtual score_t PrintAndGetLoss(int iter, const score_t* score) const = 0;
Guolin Ke's avatar
Guolin Ke committed
36
37
38
39
40
41
42

  /*!
  * \brief Create object of metrics
  * \param type Specific type of metric
  * \param config Config for metric
  */
  static Metric* CreateMetric(const std::string& type, const MetricConfig& config);
wxchan's avatar
wxchan committed
43
44
45

  bool the_bigger_the_better = false;
  int early_stopping_round_ = 0;
Guolin Ke's avatar
Guolin Ke committed
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
};

/*!
* \brief Static class, used to calculate DCG score
*/
class DCGCalculator {
public:
  /*!
  * \brief Initialize the calculator's static state.
  *        is_inited_ exists to avoid initializing multiple times.
  * \param label_gain Gain for labels, default is 2^i - 1
  */
  static void Init(std::vector<double> label_gain);

  /*!
  * \brief Calculate the DCG score at position k
  * \param k The position to evaluate
  * \param label Pointer of label (presumably num_data entries — confirm)
  * \param score Pointer of score (presumably num_data entries — confirm)
  * \param num_data Number of data
  * \return The DCG score
  */
  static double CalDCGAtK(data_size_t k, const float* label,
    const score_t* score, data_size_t num_data);

  /*!
  * \brief Calculate the DCG score at multiple positions
  * \param ks The positions to evaluate
  * \param label Pointer of label
  * \param score Pointer of score
  * \param num_data Number of data
  * \param out Output result (presumably one DCG value per entry of ks — confirm)
  */
  static void CalDCG(const std::vector<data_size_t>& ks,
    const float* label, const score_t* score,
    data_size_t num_data, std::vector<double>* out);

  /*!
  * \brief Calculate the Max DCG score at position k
  * \param k The position to evaluate
  * \param label Pointer of label
  * \param num_data Number of data
  * \return The max DCG score
  */
  static double CalMaxDCGAtK(data_size_t k,
    const float* label, data_size_t num_data);

  /*!
  * \brief Calculate the Max DCG score at multiple positions
  * \param ks The positions to evaluate
  * \param label Pointer of label
  * \param num_data Number of data
  * \param out Output result (presumably one max DCG value per entry of ks — confirm)
  */
  static void CalMaxDCG(const std::vector<data_size_t>& ks,
    const float* label, data_size_t num_data, std::vector<double>* out);

  /*!
  * \brief Get discount score of position k
  * \param k The position; NOT bounds-checked, so k must be a valid index into discount_
  *          (presumably below kMaxPosition once Init has run — confirm)
  * \return The discount of this position
  */
  inline static double GetDiscount(data_size_t k) { return discount_[k]; }

private:
  /*! \brief True if inited, avoid init multi times */
  static bool is_inited_;
  /*! \brief store gains for different label */
  static std::vector<double> label_gain_;
  /*! \brief store discount score for different position */
  static std::vector<double> discount_;
  /*! \brief max position for eval */
  static const data_size_t kMaxPosition;
};


}  // namespace LightGBM


#endif   // LightGBM_METRIC_H_