Unverified Commit 7e81d9e2 authored by sbruch's avatar sbruch Committed by GitHub
Browse files

Optimize the computation of the cross-entropy ranking loss (#3080)

* Fix loss computation

* fix test

* Optimize ranking loss computation.
parent dea2391b
...@@ -305,46 +305,53 @@ class RankXENDCG : public RankingObjective { ...@@ -305,46 +305,53 @@ class RankXENDCG : public RankingObjective {
const label_t* label, const double* score, const label_t* label, const double* score,
score_t* lambdas, score_t* lambdas,
score_t* hessians) const override { score_t* hessians) const override {
// Skip groups with too few items.
if (cnt <= 1) {
for (data_size_t i = 0; i < cnt; ++i) {
lambdas[i] = 0.0f;
hessians[i] = 0.0f;
}
return;
}
// Turn scores into a probability distribution using Softmax. // Turn scores into a probability distribution using Softmax.
std::vector<double> rho(cnt, 0.0); std::vector<double> rho(cnt, 0.0);
Common::Softmax(score, rho.data(), cnt); Common::Softmax(score, rho.data(), cnt);
// used for Phi and L1 // An auxiliary buffer of parameters used to form the ground-truth
std::vector<double> l1s(cnt); // distribution and compute the loss.
double sum_labels = 0; std::vector<double> params(cnt);
double inv_denominator = 0;
for (data_size_t i = 0; i < cnt; ++i) { for (data_size_t i = 0; i < cnt; ++i) {
l1s[i] = Phi(label[i], rands_[query_id].NextFloat()); params[i] = Phi(label[i], rands_[query_id].NextFloat());
sum_labels += l1s[i]; inv_denominator += params[i];
} }
// sum_labels will always be positive number // sum_labels will always be positive number
sum_labels = std::max<double>(kEpsilon, sum_labels); inv_denominator = 1. / std::max<double>(kEpsilon, inv_denominator);
// Approximate gradients and inverse Hessian. // Approximate gradients and inverse Hessian.
// First order terms. // First order terms.
double sum_l1 = 0.0f; double sum_l1 = 0.0;
for (data_size_t i = 0; i < cnt; ++i) { for (data_size_t i = 0; i < cnt; ++i) {
l1s[i] = -l1s[i] / sum_labels + rho[i]; double term = -params[i] * inv_denominator + rho[i];
sum_l1 += l1s[i] / (1. - rho[i]); lambdas[i] = static_cast<score_t>(term);
// Params will now store terms needed to compute second-order terms.
params[i] = term / (1. - rho[i]);
sum_l1 += params[i];
} }
if (cnt <= 1) { // Second order terms.
// when cnt <= 1, the l2 and l3 are zeros double sum_l2 = 0.0;
for (data_size_t i = 0; i < cnt; ++i) { for (data_size_t i = 0; i < cnt; ++i) {
lambdas[i] = static_cast<score_t>(l1s[i]); double term = rho[i] * (sum_l1 - params[i]);
hessians[i] = static_cast<score_t>(rho[i] * (1.0 - rho[i])); lambdas[i] += static_cast<score_t>(term);
} // Params will now store terms needed to compute third-order terms.
} else { params[i] = term / (1. - rho[i]);
// Second order terms. sum_l2 += params[i];
std::vector<double> l2s(cnt, 0.0); }
double sum_l2 = 0.0; for (data_size_t i = 0; i < cnt; ++i) {
for (data_size_t i = 0; i < cnt; ++i) { lambdas[i] += static_cast<score_t>(rho[i] * (sum_l2 - params[i]));
l2s[i] = sum_l1 - (l1s[i] / (1. - rho[i])); hessians[i] = static_cast<score_t>(rho[i] * (1.0 - rho[i]));
sum_l2 += l2s[i] * rho[i] / (1. - rho[i]);
}
for (data_size_t i = 0; i < cnt; ++i) {
auto l3 = sum_l2 - (l2s[i] * rho[i] / (1. - rho[i]));
lambdas[i] = static_cast<score_t>(l1s[i] + rho[i] * l2s[i] +
rho[i] * l3);
hessians[i] = static_cast<score_t>(rho[i] * (1.0 - rho[i]));
}
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment