Commit 486f5db4 authored by Guolin Ke
Browse files

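Fix MAP metric: normalize the summed precision at k by min(number of positives, k), and report 1 for queries with no positives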

parent 2f0fb204
......@@ -56,6 +56,15 @@ public:
sum_query_weights_ += query_weights_[i];
}
}
npos_per_query_.resize(num_queries_, 0);
for (data_size_t i = 0; i < num_queries_; ++i) {
for (data_size_t j = query_boundaries_[i]; j < query_boundaries_[i + 1]; ++j) {
if (label_[j] > 0.5f) {
++npos_per_query_[i];
}
}
}
}
const std::vector<std::string>& GetName() const override {
......@@ -66,7 +75,7 @@ public:
return 1.0f;
}
void CalMapAtK(std::vector<int> ks, const float* label,
void CalMapAtK(std::vector<int> ks, data_size_t npos, const float* label,
const double* score, data_size_t num_data, std::vector<double>* out) const {
// get sorted indices by score
std::vector<data_size_t> sorted_idx;
......@@ -80,7 +89,7 @@ public:
double sum_ap = 0.0f;
data_size_t cur_left = 0;
for (size_t i = 0; i < ks.size(); ++i) {
data_size_t cur_k = ks[i];
data_size_t cur_k = static_cast<data_size_t>(ks[i]);
if (cur_k > num_data) { cur_k = num_data; }
for (data_size_t j = cur_left; j < cur_k; ++j) {
data_size_t idx = sorted_idx[j];
......@@ -89,7 +98,11 @@ public:
sum_ap += static_cast<double>(num_hit) / (j + 1.0f);
}
}
(*out)[i] = sum_ap / cur_k;
if (npos > 0) {
(*out)[i] = sum_ap / std::min(npos, cur_k);
} else {
(*out)[i] = 1.0f;
}
cur_left = cur_k;
}
}
......@@ -104,7 +117,7 @@ public:
#pragma omp parallel for schedule(guided) firstprivate(tmp_map)
for (data_size_t i = 0; i < num_queries_; ++i) {
const int tid = omp_get_thread_num();
CalMapAtK(eval_at_, label_ + query_boundaries_[i],
CalMapAtK(eval_at_, npos_per_query_[i], label_ + query_boundaries_[i],
score + query_boundaries_[i], query_boundaries_[i + 1] - query_boundaries_[i], &tmp_map);
for (size_t j = 0; j < eval_at_.size(); ++j) {
result_buffer_[tid][j] += tmp_map[j];
......@@ -114,7 +127,7 @@ public:
#pragma omp parallel for schedule(guided) firstprivate(tmp_map)
for (data_size_t i = 0; i < num_queries_; ++i) {
const int tid = omp_get_thread_num();
CalMapAtK(eval_at_, label_ + query_boundaries_[i],
CalMapAtK(eval_at_, npos_per_query_[i], label_ + query_boundaries_[i],
score + query_boundaries_[i], query_boundaries_[i + 1] - query_boundaries_[i], &tmp_map);
for (size_t j = 0; j < eval_at_.size(); ++j) {
result_buffer_[tid][j] += tmp_map[j] * query_weights_[i];
......@@ -150,6 +163,7 @@ private:
/*! \brief Number of threads */
int num_threads_;
std::vector<std::string> name_;
std::vector<data_size_t> npos_per_query_;
};
} // namespace LightGBM
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment