"...AutoBuildImmortalWrt.git" did not exist on "a84f9f627f9088f4f773cf9ad289554ef08af18d"
Unverified Commit 4580393f authored by ashok-ponnuswami-msft's avatar ashok-ponnuswami-msft Committed by GitHub
Browse files

Range check for DCG position discount lookup (#4069)

* Add a check to prevent out-of-range index lookups in the position discount table. Add debug logging to report the number of queries found in the data.

* Change debug logging location so that we can print the data file name as well.

* Revert "Change debug logging location so that we can print the data file name as well."

This reverts commit 3981b34bd6e0530f89c4733e78e6b6603bf50d48.

* Add data file name to debug logging.

* Move log line to a place where it is output even when query IDs are read from a separate file.

* Also add the out-of-range check to rank metrics.

* Perform check after number of queries is initialized.

* Update
parent e9f50a59
...@@ -103,6 +103,14 @@ class DCGCalculator { ...@@ -103,6 +103,14 @@ class DCGCalculator {
static double CalMaxDCGAtK(data_size_t k, static double CalMaxDCGAtK(data_size_t k,
const label_t* label, data_size_t num_data); const label_t* label, data_size_t num_data);
/*!
* \brief Check the query metadata for NDCG and lambdarank: each query may
*        contain at most kMaxPosition rows, and a query exceeding that
*        limit is reported through Log::Fatal
* \param metadata Metadata that supplies the query boundaries; nothing is
*        checked when it has no query boundaries
* \param num_queries Number of queries covered by the query boundaries;
*        nothing is checked when this is not positive
*/
static void CheckMetadata(const Metadata& metadata, data_size_t num_queries);
/*! /*!
* \brief Check the label range for NDCG and lambdarank * \brief Check the label range for NDCG and lambdarank
* \param label Pointer of label * \param label Pointer of label
......
...@@ -277,6 +277,10 @@ void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector<data ...@@ -277,6 +277,10 @@ void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector<data
// re-load query weight // re-load query weight
LoadQueryWeights(); LoadQueryWeights();
} }
if (num_queries_ > 0) {
Log::Debug("Number of queries in %s: %i. Average number of rows per query: %f.",
data_filename_.c_str(), static_cast<int>(num_queries_), static_cast<double>(num_data_) / num_queries_);
}
} }
void Metadata::SetInitScore(const double* init_score, data_size_t len) { void Metadata::SetInitScore(const double* init_score, data_size_t len) {
......
...@@ -152,6 +152,19 @@ void DCGCalculator::CalDCG(const std::vector<data_size_t>& ks, const label_t* la ...@@ -152,6 +152,19 @@ void DCGCalculator::CalDCG(const std::vector<data_size_t>& ks, const label_t* la
} }
} }
void DCGCalculator::CheckMetadata(const Metadata& metadata, data_size_t num_queries) {
const data_size_t* query_boundaries = metadata.query_boundaries();
if (num_queries > 0 && query_boundaries != nullptr) {
for (data_size_t i = 0; i < num_queries; i++) {
data_size_t num_rows = query_boundaries[i + 1] - query_boundaries[i];
if (num_rows > kMaxPosition) {
Log::Fatal("Number of rows %i exceeds upper limit of %i for a query", static_cast<int>(num_rows), static_cast<int>(kMaxPosition));
}
}
}
}
void DCGCalculator::CheckLabel(const label_t* label, data_size_t num_data) { void DCGCalculator::CheckLabel(const label_t* label, data_size_t num_data) {
for (data_size_t i = 0; i < num_data; ++i) { for (data_size_t i = 0; i < num_data; ++i) {
label_t delta = std::fabs(label[i] - static_cast<int>(label[i])); label_t delta = std::fabs(label[i] - static_cast<int>(label[i]));
......
...@@ -37,13 +37,14 @@ class NDCGMetric:public Metric { ...@@ -37,13 +37,14 @@ class NDCGMetric:public Metric {
num_data_ = num_data; num_data_ = num_data;
// get label // get label
label_ = metadata.label(); label_ = metadata.label();
num_queries_ = metadata.num_queries();
DCGCalculator::CheckMetadata(metadata, num_queries_);
DCGCalculator::CheckLabel(label_, num_data_); DCGCalculator::CheckLabel(label_, num_data_);
// get query boundaries // get query boundaries
query_boundaries_ = metadata.query_boundaries(); query_boundaries_ = metadata.query_boundaries();
if (query_boundaries_ == nullptr) { if (query_boundaries_ == nullptr) {
Log::Fatal("The NDCG metric requires query information"); Log::Fatal("The NDCG metric requires query information");
} }
num_queries_ = metadata.num_queries();
// get query weights // get query weights
query_weights_ = metadata.query_weights(); query_weights_ = metadata.query_weights();
if (query_weights_ == nullptr) { if (query_weights_ == nullptr) {
......
...@@ -120,6 +120,7 @@ class LambdarankNDCG : public RankingObjective { ...@@ -120,6 +120,7 @@ class LambdarankNDCG : public RankingObjective {
void Init(const Metadata& metadata, data_size_t num_data) override { void Init(const Metadata& metadata, data_size_t num_data) override {
RankingObjective::Init(metadata, num_data); RankingObjective::Init(metadata, num_data);
DCGCalculator::CheckMetadata(metadata, num_queries_);
DCGCalculator::CheckLabel(label_, num_data_); DCGCalculator::CheckLabel(label_, num_data_);
inverse_max_dcgs_.resize(num_queries_); inverse_max_dcgs_.resize(num_queries_);
#pragma omp parallel for schedule(static) #pragma omp parallel for schedule(static)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment