rank_xendcg_objective.hpp 4.1 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
/*!
 * Copyright (c) 2019 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
#ifndef LIGHTGBM_OBJECTIVE_RANK_XENDCG_OBJECTIVE_HPP_
#define LIGHTGBM_OBJECTIVE_RANK_XENDCG_OBJECTIVE_HPP_

#include <LightGBM/objective_function.h>
#include <LightGBM/utils/common.h>
#include <LightGBM/utils/random.h>

#include <algorithm>
#include <cmath>
#include <string>
#include <vector>

namespace LightGBM {
/*!
* \brief Implementation of the learning-to-rank objective function, XE_NDCG [arxiv.org/abs/1911.09798].
*/
class RankXENDCG: public ObjectiveFunction {
 public:
  explicit RankXENDCG(const Config& config) {
    rand_ = new Random(config.objective_seed);
  }

  explicit RankXENDCG(const std::vector<std::string>&) {
    rand_ = new Random();
  }

  ~RankXENDCG() {
  }
  void Init(const Metadata& metadata, data_size_t) override {
    // get label
    label_ = metadata.label();
    // get boundries
    query_boundaries_ = metadata.query_boundaries();
    if (query_boundaries_ == nullptr) {
      Log::Fatal("RankXENDCG tasks require query information");
    }
    num_queries_ = metadata.num_queries();
  }

  void GetGradients(const double* score, score_t* gradients,
                    score_t* hessians) const override {
    #pragma omp parallel for schedule(guided)
    for (data_size_t i = 0; i < num_queries_; ++i) {
      GetGradientsForOneQuery(score, gradients, hessians, i);
    }
  }

  inline void GetGradientsForOneQuery(
      const double* score,
      score_t* lambdas, score_t* hessians, data_size_t query_id) const {
    // get doc boundary for current query
    const data_size_t start = query_boundaries_[query_id];
    const data_size_t cnt =
      query_boundaries_[query_id + 1] - query_boundaries_[query_id];
    // add pointers with offset
    const label_t* label = label_ + start;
    score += start;
    lambdas += start;
    hessians += start;

    // Turn scores into a probability distribution using Softmax.
    std::vector<double> rho(cnt);
    Common::Softmax(score, &rho[0], cnt);

    // Prepare a vector of gammas, a parameter of the loss.
    std::vector<double> gammas(cnt);
    for (data_size_t i = 0; i < cnt; ++i) {
      gammas[i] = rand_->NextFloat();
    }

    // Skip query if sum of labels is 0.
    float sum_labels = 0;
    for (data_size_t i = 0; i < cnt; ++i) {
76
      sum_labels += static_cast<float>(phi(label[i], gammas[i]));
77
    }
78
    if (std::fabs(sum_labels) < kEpsilon) {
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
      return;
    }

    // Approximate gradients and inverse Hessian.
    // First order terms.
    std::vector<double> L1s(cnt);
    for (data_size_t i = 0; i < cnt; ++i) {
      L1s[i] = -phi(label[i], gammas[i])/sum_labels + rho[i];
    }
    // Second-order terms.
    std::vector<double> L2s(cnt);
    for (data_size_t i = 0; i < cnt; ++i) {
      for (data_size_t j = 0; j < cnt; ++j) {
        if (i == j) continue;
        L2s[i] += L1s[j] / (1 - rho[j]);
      }
    }
    // Third-order terms.
    std::vector<double> L3s(cnt);
    for (data_size_t i = 0; i < cnt; ++i) {
      for (data_size_t j = 0; j < cnt; ++j) {
        if (i == j) continue;
        L3s[i] += rho[j] * L2s[j] / (1 - rho[j]);
      }
    }

    // Finally, prepare lambdas and hessians.
    for (data_size_t i = 0; i < cnt; ++i) {
      lambdas[i] = static_cast<score_t>(
          L1s[i] + rho[i]*L2s[i] + rho[i]*L3s[i]);
      hessians[i] = static_cast<score_t>(rho[i] * (1.0 - rho[i]));
    }
  }

  double phi(const label_t l, double g) const {
114
    return Common::Pow(2, static_cast<int>(l)) - g;
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
  }

  const char* GetName() const override {
    return "rank_xendcg";
  }

  std::string ToString() const override {
    std::stringstream str_buf;
    str_buf << GetName();
    return str_buf.str();
  }

  bool NeedAccuratePrediction() const override { return false; }

 private:
  /*! \brief Number of queries */
  data_size_t num_queries_;
  /*! \brief Pointer of label */
  const label_t* label_;
  /*! \brief Query boundries */
  const data_size_t* query_boundaries_;
  /*! \brief Pseudo-random number generator */
  Random* rand_;
};

}  // namespace LightGBM
#endif   // LightGBM_OBJECTIVE_RANK_XENDCG_OBJECTIVE_HPP_