// dcg_calculator.cpp
#include <LightGBM/metric.h>

#include <LightGBM/utils/log.h>

#include <cmath>

#include <vector>
#include <algorithm>

namespace LightGBM {

/*! \brief Declaration for some static members */
13
14
std::vector<double> DCGCalculator::label_gain_;
std::vector<double> DCGCalculator::discount_;
Guolin Ke's avatar
Guolin Ke committed
15
16
const data_size_t DCGCalculator::kMaxPosition = 10000;

17
void DCGCalculator::Init(std::vector<double> input_label_gain) {
Guolin Ke's avatar
Guolin Ke committed
18
  label_gain_.resize(input_label_gain.size());
19
  for(size_t i = 0;i < input_label_gain.size();++i){
Guolin Ke's avatar
Guolin Ke committed
20
    label_gain_[i] = static_cast<double>(input_label_gain[i]);
21
  }
Guolin Ke's avatar
Guolin Ke committed
22
  discount_.resize(kMaxPosition);
Guolin Ke's avatar
Guolin Ke committed
23
  for (data_size_t i = 0; i < kMaxPosition; ++i) {
Guolin Ke's avatar
Guolin Ke committed
24
    discount_[i] = 1.0f / std::log2(2.0f + i);
Guolin Ke's avatar
Guolin Ke committed
25
26
27
  }
}

28
double DCGCalculator::CalMaxDCGAtK(data_size_t k, const label_t* label, data_size_t num_data) {
29
  double ret = 0.0f;
Guolin Ke's avatar
Guolin Ke committed
30
31
32
33
34
  // counts for all labels
  std::vector<data_size_t> label_cnt(label_gain_.size(), 0);
  for (data_size_t i = 0; i < num_data; ++i) {
    ++label_cnt[static_cast<int>(label[i])];
  }
35
  int top_label = static_cast<int>(label_gain_.size()) - 1;
Guolin Ke's avatar
Guolin Ke committed
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52

  if (k > num_data) { k = num_data; }
  //  start from top label, and accumulate DCG
  for (data_size_t j = 0; j < k; ++j) {
    while (top_label > 0 && label_cnt[top_label] <= 0) {
      top_label -= 1;
    }
    if (top_label < 0) {
      break;
    }
    ret += discount_[j] * label_gain_[top_label];
    label_cnt[top_label] -= 1;
  }
  return ret;
}

void DCGCalculator::CalMaxDCG(const std::vector<data_size_t>& ks,
53
                              const label_t* label,
Guolin Ke's avatar
Guolin Ke committed
54
                              data_size_t num_data,
55
                              std::vector<double>* out) {
Guolin Ke's avatar
Guolin Ke committed
56
57
58
59
60
  std::vector<data_size_t> label_cnt(label_gain_.size(), 0);
  // counts for all labels
  for (data_size_t i = 0; i < num_data; ++i) {
    ++label_cnt[static_cast<int>(label[i])];
  }
61
  double cur_result = 0.0f;
Guolin Ke's avatar
Guolin Ke committed
62
  data_size_t cur_left = 0;
63
  int top_label = static_cast<int>(label_gain_.size()) - 1;
Guolin Ke's avatar
Guolin Ke committed
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
  // calculate k Max DCG by one pass
  for (size_t i = 0; i < ks.size(); ++i) {
    data_size_t cur_k = ks[i];
    if (cur_k > num_data) { cur_k = num_data; }
    for (data_size_t j = cur_left; j < cur_k; ++j) {
      while (top_label > 0 && label_cnt[top_label] <= 0) {
        top_label -= 1;
      }
      if (top_label < 0) {
        break;
      }
      cur_result += discount_[j] * label_gain_[top_label];
      label_cnt[top_label] -= 1;
    }
    (*out)[i] = cur_result;
    cur_left = cur_k;
  }
}


84
double DCGCalculator::CalDCGAtK(data_size_t k, const label_t* label,
85
                                const double* score, data_size_t num_data) {
Guolin Ke's avatar
Guolin Ke committed
86
  // get sorted indices by score
87
  std::vector<data_size_t> sorted_idx(num_data);
Guolin Ke's avatar
Guolin Ke committed
88
  for (data_size_t i = 0; i < num_data; ++i) {
89
    sorted_idx[i] = i;
Guolin Ke's avatar
Guolin Ke committed
90
91
92
93
94
  }
  std::sort(sorted_idx.begin(), sorted_idx.end(),
           [score](data_size_t a, data_size_t b) {return score[a] > score[b]; });

  if (k > num_data) { k = num_data; }
95
  double dcg = 0.0f;
Guolin Ke's avatar
Guolin Ke committed
96
97
98
99
100
101
102
103
  // calculate dcg
  for (data_size_t i = 0; i < k; ++i) {
    data_size_t idx = sorted_idx[i];
    dcg += label_gain_[static_cast<int>(label[idx])] * discount_[i];
  }
  return dcg;
}

104
void DCGCalculator::CalDCG(const std::vector<data_size_t>& ks, const label_t* label,
105
                           const double * score, data_size_t num_data, std::vector<double>* out) {
Guolin Ke's avatar
Guolin Ke committed
106
  // get sorted indices by score
107
  std::vector<data_size_t> sorted_idx(num_data);
Guolin Ke's avatar
Guolin Ke committed
108
  for (data_size_t i = 0; i < num_data; ++i) {
109
    sorted_idx[i] = i;
Guolin Ke's avatar
Guolin Ke committed
110
111
112
113
  }
  std::sort(sorted_idx.begin(), sorted_idx.end(),
            [score](data_size_t a, data_size_t b) {return score[a] > score[b]; });

114
  double cur_result = 0.0f;
Guolin Ke's avatar
Guolin Ke committed
115
116
117
118
119
120
121
122
123
124
125
126
127
128
  data_size_t cur_left = 0;
  // calculate multi dcg by one pass
  for (size_t i = 0; i < ks.size(); ++i) {
    data_size_t cur_k = ks[i];
    if (cur_k > num_data) { cur_k = num_data; }
    for (data_size_t j = cur_left; j < cur_k; ++j) {
      data_size_t idx = sorted_idx[j];
      cur_result += label_gain_[static_cast<int>(label[idx])] * discount_[j];
    }
    (*out)[i] = cur_result;
    cur_left = cur_k;
  }
}

129
void DCGCalculator::CheckLabel(const label_t* label, data_size_t num_data) {
130
  for (data_size_t i = 0; i < num_data; ++i) {
131
    label_t delta = std::fabs(label[i] - static_cast<int>(label[i]));
132
    if (delta > kEpsilon) {
133
134
      Log::Fatal("label should be int type (met %f) for ranking task,\n"
                 "for the gain of label, please set the label_gain parameter", label[i]);
135
136
137
138
139
140
141
    }
    if (static_cast<size_t>(label[i]) >= label_gain_.size() || label[i] < 0) {
      Log::Fatal("label (%d) excel the max range %d", label[i], label_gain_.size());
    }
  }
}

Guolin Ke's avatar
Guolin Ke committed
142
}  // namespace LightGBM