Commit 486f5db4 authored by Guolin Ke's avatar Guolin Ke
Browse files

fix map metric

parent 2f0fb204
...@@ -56,6 +56,15 @@ public: ...@@ -56,6 +56,15 @@ public:
sum_query_weights_ += query_weights_[i]; sum_query_weights_ += query_weights_[i];
} }
} }
npos_per_query_.resize(num_queries_, 0);
for (data_size_t i = 0; i < num_queries_; ++i) {
for (data_size_t j = query_boundaries_[i]; j < query_boundaries_[i + 1]; ++j) {
if (label_[j] > 0.5f) {
++npos_per_query_[i];
}
}
}
} }
const std::vector<std::string>& GetName() const override { const std::vector<std::string>& GetName() const override {
...@@ -66,7 +75,7 @@ public: ...@@ -66,7 +75,7 @@ public:
return 1.0f; return 1.0f;
} }
void CalMapAtK(std::vector<int> ks, const float* label, void CalMapAtK(std::vector<int> ks, data_size_t npos, const float* label,
const double* score, data_size_t num_data, std::vector<double>* out) const { const double* score, data_size_t num_data, std::vector<double>* out) const {
// get sorted indices by score // get sorted indices by score
std::vector<data_size_t> sorted_idx; std::vector<data_size_t> sorted_idx;
...@@ -80,7 +89,7 @@ public: ...@@ -80,7 +89,7 @@ public:
double sum_ap = 0.0f; double sum_ap = 0.0f;
data_size_t cur_left = 0; data_size_t cur_left = 0;
for (size_t i = 0; i < ks.size(); ++i) { for (size_t i = 0; i < ks.size(); ++i) {
data_size_t cur_k = ks[i]; data_size_t cur_k = static_cast<data_size_t>(ks[i]);
if (cur_k > num_data) { cur_k = num_data; } if (cur_k > num_data) { cur_k = num_data; }
for (data_size_t j = cur_left; j < cur_k; ++j) { for (data_size_t j = cur_left; j < cur_k; ++j) {
data_size_t idx = sorted_idx[j]; data_size_t idx = sorted_idx[j];
...@@ -89,7 +98,11 @@ public: ...@@ -89,7 +98,11 @@ public:
sum_ap += static_cast<double>(num_hit) / (j + 1.0f); sum_ap += static_cast<double>(num_hit) / (j + 1.0f);
} }
} }
(*out)[i] = sum_ap / cur_k; if (npos > 0) {
(*out)[i] = sum_ap / std::min(npos, cur_k);
} else {
(*out)[i] = 1.0f;
}
cur_left = cur_k; cur_left = cur_k;
} }
} }
...@@ -104,7 +117,7 @@ public: ...@@ -104,7 +117,7 @@ public:
#pragma omp parallel for schedule(guided) firstprivate(tmp_map) #pragma omp parallel for schedule(guided) firstprivate(tmp_map)
for (data_size_t i = 0; i < num_queries_; ++i) { for (data_size_t i = 0; i < num_queries_; ++i) {
const int tid = omp_get_thread_num(); const int tid = omp_get_thread_num();
CalMapAtK(eval_at_, label_ + query_boundaries_[i], CalMapAtK(eval_at_, npos_per_query_[i], label_ + query_boundaries_[i],
score + query_boundaries_[i], query_boundaries_[i + 1] - query_boundaries_[i], &tmp_map); score + query_boundaries_[i], query_boundaries_[i + 1] - query_boundaries_[i], &tmp_map);
for (size_t j = 0; j < eval_at_.size(); ++j) { for (size_t j = 0; j < eval_at_.size(); ++j) {
result_buffer_[tid][j] += tmp_map[j]; result_buffer_[tid][j] += tmp_map[j];
...@@ -114,7 +127,7 @@ public: ...@@ -114,7 +127,7 @@ public:
#pragma omp parallel for schedule(guided) firstprivate(tmp_map) #pragma omp parallel for schedule(guided) firstprivate(tmp_map)
for (data_size_t i = 0; i < num_queries_; ++i) { for (data_size_t i = 0; i < num_queries_; ++i) {
const int tid = omp_get_thread_num(); const int tid = omp_get_thread_num();
CalMapAtK(eval_at_, label_ + query_boundaries_[i], CalMapAtK(eval_at_, npos_per_query_[i], label_ + query_boundaries_[i],
score + query_boundaries_[i], query_boundaries_[i + 1] - query_boundaries_[i], &tmp_map); score + query_boundaries_[i], query_boundaries_[i + 1] - query_boundaries_[i], &tmp_map);
for (size_t j = 0; j < eval_at_.size(); ++j) { for (size_t j = 0; j < eval_at_.size(); ++j) {
result_buffer_[tid][j] += tmp_map[j] * query_weights_[i]; result_buffer_[tid][j] += tmp_map[j] * query_weights_[i];
...@@ -150,6 +163,7 @@ private: ...@@ -150,6 +163,7 @@ private:
/*! \brief Number of threads */ /*! \brief Number of threads */
int num_threads_; int num_threads_;
std::vector<std::string> name_; std::vector<std::string> name_;
std::vector<data_size_t> npos_per_query_;
}; };
} // namespace LightGBM } // namespace LightGBM
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment