Unverified Commit aa78a6b9 authored by Guolin Ke's avatar Guolin Ke Committed by GitHub
Browse files

support label as double type (#1120)

parent 162509ae
...@@ -82,9 +82,9 @@ public: ...@@ -82,9 +82,9 @@ public:
void CheckOrPartition(data_size_t num_all_data, void CheckOrPartition(data_size_t num_all_data,
const std::vector<data_size_t>& used_data_indices); const std::vector<data_size_t>& used_data_indices);
void SetLabel(const float* label, data_size_t len); void SetLabel(const label_t* label, data_size_t len);
void SetWeights(const float* weights, data_size_t len); void SetWeights(const label_t* weights, data_size_t len);
void SetQuery(const data_size_t* query, data_size_t len); void SetQuery(const data_size_t* query, data_size_t len);
...@@ -110,14 +110,14 @@ public: ...@@ -110,14 +110,14 @@ public:
* \brief Get pointer of label * \brief Get pointer of label
* \return Pointer of label * \return Pointer of label
*/ */
inline const float* label() const { return label_.data(); } inline const label_t* label() const { return label_.data(); }
/*! /*!
* \brief Set label for one record * \brief Set label for one record
* \param idx Index of this record * \param idx Index of this record
* \param value Label value of this record * \param value Label value of this record
*/ */
inline void SetLabelAt(data_size_t idx, float value) inline void SetLabelAt(data_size_t idx, label_t value)
{ {
label_[idx] = value; label_[idx] = value;
} }
...@@ -127,7 +127,7 @@ public: ...@@ -127,7 +127,7 @@ public:
* \param idx Index of this record * \param idx Index of this record
* \param value Weight value of this record * \param value Weight value of this record
*/ */
inline void SetWeightAt(data_size_t idx, float value) inline void SetWeightAt(data_size_t idx, label_t value)
{ {
weights_[idx] = value; weights_[idx] = value;
} }
...@@ -146,7 +146,7 @@ public: ...@@ -146,7 +146,7 @@ public:
* \brief Get weights, if not exists, will return nullptr * \brief Get weights, if not exists, will return nullptr
* \return Pointer of weights * \return Pointer of weights
*/ */
inline const float* weights() const { inline const label_t* weights() const {
if (!weights_.empty()) { if (!weights_.empty()) {
return weights_.data(); return weights_.data();
} else { } else {
...@@ -179,7 +179,7 @@ public: ...@@ -179,7 +179,7 @@ public:
* \brief Get weights for queries, if not exists, will return nullptr * \brief Get weights for queries, if not exists, will return nullptr
* \return Pointer of weights for queries * \return Pointer of weights for queries
*/ */
inline const float* query_weights() const { inline const label_t* query_weights() const {
if (!query_weights_.empty()) { if (!query_weights_.empty()) {
return query_weights_.data(); return query_weights_.data();
} else { } else {
...@@ -225,13 +225,13 @@ private: ...@@ -225,13 +225,13 @@ private:
/*! \brief Number of weights, used to check correct weight file */ /*! \brief Number of weights, used to check correct weight file */
data_size_t num_weights_; data_size_t num_weights_;
/*! \brief Label data */ /*! \brief Label data */
std::vector<float> label_; std::vector<label_t> label_;
/*! \brief Weights data */ /*! \brief Weights data */
std::vector<float> weights_; std::vector<label_t> weights_;
/*! \brief Query boundaries */ /*! \brief Query boundaries */
std::vector<data_size_t> query_boundaries_; std::vector<data_size_t> query_boundaries_;
/*! \brief Query weights */ /*! \brief Query weights */
std::vector<float> query_weights_; std::vector<label_t> query_weights_;
/*! \brief Number of queries */ /*! \brief Number of queries */
data_size_t num_queries_; data_size_t num_queries_;
/*! \brief Number of Initial score, used to check correct weight file */ /*! \brief Number of Initial score, used to check correct weight file */
......
...@@ -12,8 +12,26 @@ namespace LightGBM { ...@@ -12,8 +12,26 @@ namespace LightGBM {
/*! \brief Type of data size, it is better to use signed type*/ /*! \brief Type of data size, it is better to use signed type*/
typedef int32_t data_size_t; typedef int32_t data_size_t;
// Enable the following macro to use double for score_t
// #define SCORE_T_USE_DOUBLE
// Enable the following macro to use double for label_t
// #define LABEL_T_USE_DOUBLE
/*! \brief Type of score, and gradients */ /*! \brief Type of score, and gradients */
#ifdef SCORE_T_USE_DOUBLE
typedef double score_t;
#else
typedef float score_t; typedef float score_t;
#endif
/*! \brief Type of metadata, including weights and labels */
#ifdef LABEL_T_USE_DOUBLE
typedef double label_t;
#else
typedef float label_t;
#endif
const score_t kMinScore = -std::numeric_limits<score_t>::infinity(); const score_t kMinScore = -std::numeric_limits<score_t>::infinity();
......
...@@ -70,7 +70,7 @@ public: ...@@ -70,7 +70,7 @@ public:
* \param num_data Number of data * \param num_data Number of data
* \return The DCG score * \return The DCG score
*/ */
static double CalDCGAtK(data_size_t k, const float* label, static double CalDCGAtK(data_size_t k, const label_t* label,
const double* score, data_size_t num_data); const double* score, data_size_t num_data);
/*! /*!
...@@ -82,7 +82,7 @@ public: ...@@ -82,7 +82,7 @@ public:
* \param out Output result * \param out Output result
*/ */
static void CalDCG(const std::vector<data_size_t>& ks, static void CalDCG(const std::vector<data_size_t>& ks,
const float* label, const double* score, const label_t* label, const double* score,
data_size_t num_data, std::vector<double>* out); data_size_t num_data, std::vector<double>* out);
/*! /*!
...@@ -93,14 +93,14 @@ public: ...@@ -93,14 +93,14 @@ public:
* \return The max DCG score * \return The max DCG score
*/ */
static double CalMaxDCGAtK(data_size_t k, static double CalMaxDCGAtK(data_size_t k,
const float* label, data_size_t num_data); const label_t* label, data_size_t num_data);
/*! /*!
* \brief Check the label range for NDCG and lambdarank * \brief Check the label range for NDCG and lambdarank
* \param label Pointer of label * \param label Pointer of label
* \param num_data Number of data * \param num_data Number of data
*/ */
static void CheckLabel(const float* label, data_size_t num_data); static void CheckLabel(const label_t* label, data_size_t num_data);
/*! /*!
* \brief Calculate the Max DCG score at multi position * \brief Calculate the Max DCG score at multi position
...@@ -110,7 +110,7 @@ public: ...@@ -110,7 +110,7 @@ public:
* \param out Output result * \param out Output result
*/ */
static void CalMaxDCG(const std::vector<data_size_t>& ks, static void CalMaxDCG(const std::vector<data_size_t>& ks,
const float* label, data_size_t num_data, std::vector<double>* out); const label_t* label, data_size_t num_data, std::vector<double>* out);
/*! /*!
* \brief Get discount score of position k * \brief Get discount score of position k
......
...@@ -295,7 +295,7 @@ void GBDT::Bagging(int iter) { ...@@ -295,7 +295,7 @@ void GBDT::Bagging(int iter) {
* (i) and (ii) could be selected as say "auto_init_score" = 0 or 1 etc.. * (i) and (ii) could be selected as say "auto_init_score" = 0 or 1 etc..
* *
*/ */
double ObtainAutomaticInitialScore(const ObjectiveFunction* fobj, const float* label, data_size_t num_data) { double ObtainAutomaticInitialScore(const ObjectiveFunction* fobj, const label_t* label, data_size_t num_data) {
double init_score = 0.0f; double init_score = 0.0f;
bool got_custom = false; bool got_custom = false;
if (fobj != nullptr) { if (fobj != nullptr) {
......
...@@ -164,7 +164,7 @@ public: ...@@ -164,7 +164,7 @@ public:
return boosting_->TrainOneIter(nullptr, nullptr); return boosting_->TrainOneIter(nullptr, nullptr);
} }
bool TrainOneIter(const float* gradients, const float* hessians) { bool TrainOneIter(const score_t* gradients, const score_t* hessians) {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
return boosting_->TrainOneIter(gradients, hessians); return boosting_->TrainOneIter(gradients, hessians);
} }
...@@ -904,11 +904,15 @@ int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle, ...@@ -904,11 +904,15 @@ int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle,
int* is_finished) { int* is_finished) {
API_BEGIN(); API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
#ifdef SCORE_T_USE_DOUBLE
Log::Fatal("Don't support Custom loss function when enable SCORE_T_USE_DOUBLE.");
#else
if (ref_booster->TrainOneIter(grad, hess)) { if (ref_booster->TrainOneIter(grad, hess)) {
*is_finished = 1; *is_finished = 1;
} else { } else {
*is_finished = 0; *is_finished = 0;
} }
#endif
API_END(); API_END();
} }
......
...@@ -423,9 +423,17 @@ bool Dataset::SetFloatField(const char* field_name, const float* field_data, dat ...@@ -423,9 +423,17 @@ bool Dataset::SetFloatField(const char* field_name, const float* field_data, dat
std::string name(field_name); std::string name(field_name);
name = Common::Trim(name); name = Common::Trim(name);
if (name == std::string("label") || name == std::string("target")) { if (name == std::string("label") || name == std::string("target")) {
#ifdef LABEL_T_USE_DOUBLE
Log::Fatal("Don't Support LABEL_T_USE_DOUBLE.");
#else
metadata_.SetLabel(field_data, num_element); metadata_.SetLabel(field_data, num_element);
#endif
} else if (name == std::string("weight") || name == std::string("weights")) { } else if (name == std::string("weight") || name == std::string("weights")) {
#ifdef LABEL_T_USE_DOUBLE
Log::Fatal("Don't Support LABEL_T_USE_DOUBLE.");
#else
metadata_.SetWeights(field_data, num_element); metadata_.SetWeights(field_data, num_element);
#endif
} else { } else {
return false; return false;
} }
...@@ -458,11 +466,19 @@ bool Dataset::GetFloatField(const char* field_name, data_size_t* out_len, const ...@@ -458,11 +466,19 @@ bool Dataset::GetFloatField(const char* field_name, data_size_t* out_len, const
std::string name(field_name); std::string name(field_name);
name = Common::Trim(name); name = Common::Trim(name);
if (name == std::string("label") || name == std::string("target")) { if (name == std::string("label") || name == std::string("target")) {
#ifdef LABEL_T_USE_DOUBLE
Log::Fatal("Don't Support LABEL_T_USE_DOUBLE.");
#else
*out_ptr = metadata_.label(); *out_ptr = metadata_.label();
*out_len = num_data_; *out_len = num_data_;
#endif
} else if (name == std::string("weight") || name == std::string("weights")) { } else if (name == std::string("weight") || name == std::string("weights")) {
#ifdef LABEL_T_USE_DOUBLE
Log::Fatal("Don't Support LABEL_T_USE_DOUBLE.");
#else
*out_ptr = metadata_.weights(); *out_ptr = metadata_.weights();
*out_len = num_data_; *out_len = num_data_;
#endif
} else { } else {
return false; return false;
} }
......
...@@ -921,7 +921,7 @@ void DatasetLoader::ExtractFeaturesFromMemory(std::vector<std::string>& text_dat ...@@ -921,7 +921,7 @@ void DatasetLoader::ExtractFeaturesFromMemory(std::vector<std::string>& text_dat
// parser // parser
parser->ParseOneLine(text_data[i].c_str(), &oneline_features, &tmp_label); parser->ParseOneLine(text_data[i].c_str(), &oneline_features, &tmp_label);
// set label // set label
dataset->metadata_.SetLabelAt(i, static_cast<float>(tmp_label)); dataset->metadata_.SetLabelAt(i, static_cast<label_t>(tmp_label));
// free processed line: // free processed line:
text_data[i].clear(); text_data[i].clear();
// shrink_to_fit will be very slow in linux, and seems not free memory, disable for now // shrink_to_fit will be very slow in linux, and seems not free memory, disable for now
...@@ -937,7 +937,7 @@ void DatasetLoader::ExtractFeaturesFromMemory(std::vector<std::string>& text_dat ...@@ -937,7 +937,7 @@ void DatasetLoader::ExtractFeaturesFromMemory(std::vector<std::string>& text_dat
dataset->feature_groups_[group]->PushData(tid, sub_feature, i, inner_data.second); dataset->feature_groups_[group]->PushData(tid, sub_feature, i, inner_data.second);
} else { } else {
if (inner_data.first == weight_idx_) { if (inner_data.first == weight_idx_) {
dataset->metadata_.SetWeightAt(i, static_cast<float>(inner_data.second)); dataset->metadata_.SetWeightAt(i, static_cast<label_t>(inner_data.second));
} else if (inner_data.first == group_idx_) { } else if (inner_data.first == group_idx_) {
dataset->metadata_.SetQueryAt(i, static_cast<data_size_t>(inner_data.second)); dataset->metadata_.SetQueryAt(i, static_cast<data_size_t>(inner_data.second));
} }
...@@ -964,7 +964,7 @@ void DatasetLoader::ExtractFeaturesFromMemory(std::vector<std::string>& text_dat ...@@ -964,7 +964,7 @@ void DatasetLoader::ExtractFeaturesFromMemory(std::vector<std::string>& text_dat
init_score[k * dataset->num_data_ + i] = static_cast<double>(oneline_init_score[k]); init_score[k * dataset->num_data_ + i] = static_cast<double>(oneline_init_score[k]);
} }
// set label // set label
dataset->metadata_.SetLabelAt(i, static_cast<float>(tmp_label)); dataset->metadata_.SetLabelAt(i, static_cast<label_t>(tmp_label));
// free processed line: // free processed line:
text_data[i].clear(); text_data[i].clear();
// shrink_to_fit will be very slow in linux, and seems not free memory, disable for now // shrink_to_fit will be very slow in linux, and seems not free memory, disable for now
...@@ -980,7 +980,7 @@ void DatasetLoader::ExtractFeaturesFromMemory(std::vector<std::string>& text_dat ...@@ -980,7 +980,7 @@ void DatasetLoader::ExtractFeaturesFromMemory(std::vector<std::string>& text_dat
dataset->feature_groups_[group]->PushData(tid, sub_feature, i, inner_data.second); dataset->feature_groups_[group]->PushData(tid, sub_feature, i, inner_data.second);
} else { } else {
if (inner_data.first == weight_idx_) { if (inner_data.first == weight_idx_) {
dataset->metadata_.SetWeightAt(i, static_cast<float>(inner_data.second)); dataset->metadata_.SetWeightAt(i, static_cast<label_t>(inner_data.second));
} else if (inner_data.first == group_idx_) { } else if (inner_data.first == group_idx_) {
dataset->metadata_.SetQueryAt(i, static_cast<data_size_t>(inner_data.second)); dataset->metadata_.SetQueryAt(i, static_cast<data_size_t>(inner_data.second));
} }
...@@ -1025,7 +1025,7 @@ void DatasetLoader::ExtractFeaturesFromFile(const char* filename, const Parser* ...@@ -1025,7 +1025,7 @@ void DatasetLoader::ExtractFeaturesFromFile(const char* filename, const Parser*
} }
} }
// set label // set label
dataset->metadata_.SetLabelAt(start_idx + i, static_cast<float>(tmp_label)); dataset->metadata_.SetLabelAt(start_idx + i, static_cast<label_t>(tmp_label));
// push data // push data
for (auto& inner_data : oneline_features) { for (auto& inner_data : oneline_features) {
if (inner_data.first >= dataset->num_total_features_) { continue; } if (inner_data.first >= dataset->num_total_features_) { continue; }
...@@ -1037,7 +1037,7 @@ void DatasetLoader::ExtractFeaturesFromFile(const char* filename, const Parser* ...@@ -1037,7 +1037,7 @@ void DatasetLoader::ExtractFeaturesFromFile(const char* filename, const Parser*
dataset->feature_groups_[group]->PushData(tid, sub_feature, start_idx + i, inner_data.second); dataset->feature_groups_[group]->PushData(tid, sub_feature, start_idx + i, inner_data.second);
} else { } else {
if (inner_data.first == weight_idx_) { if (inner_data.first == weight_idx_) {
dataset->metadata_.SetWeightAt(start_idx + i, static_cast<float>(inner_data.second)); dataset->metadata_.SetWeightAt(start_idx + i, static_cast<label_t>(inner_data.second));
} else if (inner_data.first == group_idx_) { } else if (inner_data.first == group_idx_) {
dataset->metadata_.SetQueryAt(start_idx + i, static_cast<data_size_t>(inner_data.second)); dataset->metadata_.SetQueryAt(start_idx + i, static_cast<data_size_t>(inner_data.second));
} }
......
...@@ -31,13 +31,13 @@ Metadata::~Metadata() { ...@@ -31,13 +31,13 @@ Metadata::~Metadata() {
void Metadata::Init(data_size_t num_data, int weight_idx, int query_idx) { void Metadata::Init(data_size_t num_data, int weight_idx, int query_idx) {
num_data_ = num_data; num_data_ = num_data;
label_ = std::vector<float>(num_data_); label_ = std::vector<label_t>(num_data_);
if (weight_idx >= 0) { if (weight_idx >= 0) {
if (!weights_.empty()) { if (!weights_.empty()) {
Log::Info("Using weights in data file, ignoring the additional weights file"); Log::Info("Using weights in data file, ignoring the additional weights file");
weights_.clear(); weights_.clear();
} }
weights_ = std::vector<float>(num_data_); weights_ = std::vector<label_t>(num_data_);
num_weights_ = num_data_; num_weights_ = num_data_;
#pragma omp parallel for schedule(static) #pragma omp parallel for schedule(static)
for (data_size_t i = 0; i < num_weights_; ++i) { for (data_size_t i = 0; i < num_weights_; ++i) {
...@@ -63,14 +63,14 @@ void Metadata::Init(data_size_t num_data, int weight_idx, int query_idx) { ...@@ -63,14 +63,14 @@ void Metadata::Init(data_size_t num_data, int weight_idx, int query_idx) {
void Metadata::Init(const Metadata& fullset, const data_size_t* used_indices, data_size_t num_used_indices) { void Metadata::Init(const Metadata& fullset, const data_size_t* used_indices, data_size_t num_used_indices) {
num_data_ = num_used_indices; num_data_ = num_used_indices;
label_ = std::vector<float>(num_used_indices); label_ = std::vector<label_t>(num_used_indices);
#pragma omp parallel for schedule(static) #pragma omp parallel for schedule(static)
for (data_size_t i = 0; i < num_used_indices; i++) { for (data_size_t i = 0; i < num_used_indices; i++) {
label_[i] = fullset.label_[used_indices[i]]; label_[i] = fullset.label_[used_indices[i]];
} }
if (!fullset.weights_.empty()) { if (!fullset.weights_.empty()) {
weights_ = std::vector<float>(num_used_indices); weights_ = std::vector<label_t>(num_used_indices);
num_weights_ = num_used_indices; num_weights_ = num_used_indices;
#pragma omp parallel for schedule(static) #pragma omp parallel for schedule(static)
for (data_size_t i = 0; i < num_used_indices; i++) { for (data_size_t i = 0; i < num_used_indices; i++) {
...@@ -134,7 +134,7 @@ void Metadata::PartitionLabel(const std::vector<data_size_t>& used_indices) { ...@@ -134,7 +134,7 @@ void Metadata::PartitionLabel(const std::vector<data_size_t>& used_indices) {
} }
auto old_label = label_; auto old_label = label_;
num_data_ = static_cast<data_size_t>(used_indices.size()); num_data_ = static_cast<data_size_t>(used_indices.size());
label_ = std::vector<float>(num_data_); label_ = std::vector<label_t>(num_data_);
#pragma omp parallel for schedule(static) #pragma omp parallel for schedule(static)
for (data_size_t i = 0; i < num_data_; ++i) { for (data_size_t i = 0; i < num_data_; ++i) {
label_[i] = old_label[used_indices[i]]; label_[i] = old_label[used_indices[i]];
...@@ -205,7 +205,7 @@ void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector<data ...@@ -205,7 +205,7 @@ void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector<data
if (!weights_.empty()) { if (!weights_.empty()) {
auto old_weights = weights_; auto old_weights = weights_;
num_weights_ = num_data_; num_weights_ = num_data_;
weights_ = std::vector<float>(num_data_); weights_ = std::vector<label_t>(num_data_);
#pragma omp parallel for schedule(static) #pragma omp parallel for schedule(static)
for (int i = 0; i < static_cast<int>(used_data_indices.size()); ++i) { for (int i = 0; i < static_cast<int>(used_data_indices.size()); ++i) {
weights_[i] = old_weights[used_data_indices[i]]; weights_[i] = old_weights[used_data_indices[i]];
...@@ -302,7 +302,7 @@ void Metadata::SetInitScore(const double* init_score, data_size_t len) { ...@@ -302,7 +302,7 @@ void Metadata::SetInitScore(const double* init_score, data_size_t len) {
init_score_load_from_file_ = false; init_score_load_from_file_ = false;
} }
void Metadata::SetLabel(const float* label, data_size_t len) { void Metadata::SetLabel(const label_t* label, data_size_t len) {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
if (label == nullptr) { if (label == nullptr) {
Log::Fatal("label cannot be nullptr"); Log::Fatal("label cannot be nullptr");
...@@ -311,14 +311,14 @@ void Metadata::SetLabel(const float* label, data_size_t len) { ...@@ -311,14 +311,14 @@ void Metadata::SetLabel(const float* label, data_size_t len) {
Log::Fatal("len of label is not same with #data"); Log::Fatal("len of label is not same with #data");
} }
if (!label_.empty()) { label_.clear(); } if (!label_.empty()) { label_.clear(); }
label_ = std::vector<float>(num_data_); label_ = std::vector<label_t>(num_data_);
#pragma omp parallel for schedule(static) #pragma omp parallel for schedule(static)
for (data_size_t i = 0; i < num_data_; ++i) { for (data_size_t i = 0; i < num_data_; ++i) {
label_[i] = label[i]; label_[i] = label[i];
} }
} }
void Metadata::SetWeights(const float* weights, data_size_t len) { void Metadata::SetWeights(const label_t* weights, data_size_t len) {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
// save to nullptr // save to nullptr
if (weights == nullptr || len == 0) { if (weights == nullptr || len == 0) {
...@@ -331,7 +331,7 @@ void Metadata::SetWeights(const float* weights, data_size_t len) { ...@@ -331,7 +331,7 @@ void Metadata::SetWeights(const float* weights, data_size_t len) {
} }
if (!weights_.empty()) { weights_.clear(); } if (!weights_.empty()) { weights_.clear(); }
num_weights_ = num_data_; num_weights_ = num_data_;
weights_ = std::vector<float>(num_weights_); weights_ = std::vector<label_t>(num_weights_);
#pragma omp parallel for schedule(static) #pragma omp parallel for schedule(static)
for (data_size_t i = 0; i < num_weights_; ++i) { for (data_size_t i = 0; i < num_weights_; ++i) {
weights_[i] = weights[i]; weights_[i] = weights[i];
...@@ -379,12 +379,12 @@ void Metadata::LoadWeights() { ...@@ -379,12 +379,12 @@ void Metadata::LoadWeights() {
} }
Log::Info("Loading weights..."); Log::Info("Loading weights...");
num_weights_ = static_cast<data_size_t>(reader.Lines().size()); num_weights_ = static_cast<data_size_t>(reader.Lines().size());
weights_ = std::vector<float>(num_weights_); weights_ = std::vector<label_t>(num_weights_);
#pragma omp parallel for schedule(static) #pragma omp parallel for schedule(static)
for (data_size_t i = 0; i < num_weights_; ++i) { for (data_size_t i = 0; i < num_weights_; ++i) {
double tmp_weight = 0.0f; double tmp_weight = 0.0f;
Common::Atof(reader.Lines()[i].c_str(), &tmp_weight); Common::Atof(reader.Lines()[i].c_str(), &tmp_weight);
weights_[i] = static_cast<float>(tmp_weight); weights_[i] = static_cast<label_t>(tmp_weight);
} }
weight_load_from_file_ = true; weight_load_from_file_ = true;
} }
...@@ -463,7 +463,7 @@ void Metadata::LoadQueryWeights() { ...@@ -463,7 +463,7 @@ void Metadata::LoadQueryWeights() {
} }
query_weights_.clear(); query_weights_.clear();
Log::Info("Loading query weights..."); Log::Info("Loading query weights...");
query_weights_ = std::vector<float>(num_queries_); query_weights_ = std::vector<label_t>(num_queries_);
for (data_size_t i = 0; i < num_queries_; ++i) { for (data_size_t i = 0; i < num_queries_; ++i) {
query_weights_[i] = 0.0f; query_weights_[i] = 0.0f;
for (data_size_t j = query_boundaries_[i]; j < query_boundaries_[i + 1]; ++j) { for (data_size_t j = query_boundaries_[i]; j < query_boundaries_[i + 1]; ++j) {
...@@ -484,15 +484,15 @@ void Metadata::LoadFromMemory(const void* memory) { ...@@ -484,15 +484,15 @@ void Metadata::LoadFromMemory(const void* memory) {
mem_ptr += sizeof(num_queries_); mem_ptr += sizeof(num_queries_);
if (!label_.empty()) { label_.clear(); } if (!label_.empty()) { label_.clear(); }
label_ = std::vector<float>(num_data_); label_ = std::vector<label_t>(num_data_);
std::memcpy(label_.data(), mem_ptr, sizeof(float)*num_data_); std::memcpy(label_.data(), mem_ptr, sizeof(label_t)*num_data_);
mem_ptr += sizeof(float)*num_data_; mem_ptr += sizeof(label_t)*num_data_;
if (num_weights_ > 0) { if (num_weights_ > 0) {
if (!weights_.empty()) { weights_.clear(); } if (!weights_.empty()) { weights_.clear(); }
weights_ = std::vector<float>(num_weights_); weights_ = std::vector<label_t>(num_weights_);
std::memcpy(weights_.data(), mem_ptr, sizeof(float)*num_weights_); std::memcpy(weights_.data(), mem_ptr, sizeof(label_t)*num_weights_);
mem_ptr += sizeof(float)*num_weights_; mem_ptr += sizeof(label_t)*num_weights_;
weight_load_from_file_ = true; weight_load_from_file_ = true;
} }
if (num_queries_ > 0) { if (num_queries_ > 0) {
...@@ -509,9 +509,9 @@ void Metadata::SaveBinaryToFile(FILE* file) const { ...@@ -509,9 +509,9 @@ void Metadata::SaveBinaryToFile(FILE* file) const {
fwrite(&num_data_, sizeof(num_data_), 1, file); fwrite(&num_data_, sizeof(num_data_), 1, file);
fwrite(&num_weights_, sizeof(num_weights_), 1, file); fwrite(&num_weights_, sizeof(num_weights_), 1, file);
fwrite(&num_queries_, sizeof(num_queries_), 1, file); fwrite(&num_queries_, sizeof(num_queries_), 1, file);
fwrite(label_.data(), sizeof(float), num_data_, file); fwrite(label_.data(), sizeof(label_t), num_data_, file);
if (!weights_.empty()) { if (!weights_.empty()) {
fwrite(weights_.data(), sizeof(float), num_weights_, file); fwrite(weights_.data(), sizeof(label_t), num_weights_, file);
} }
if (!query_boundaries_.empty()) { if (!query_boundaries_.empty()) {
fwrite(query_boundaries_.data(), sizeof(data_size_t), num_queries_ + 1, file); fwrite(query_boundaries_.data(), sizeof(data_size_t), num_queries_ + 1, file);
...@@ -522,9 +522,9 @@ void Metadata::SaveBinaryToFile(FILE* file) const { ...@@ -522,9 +522,9 @@ void Metadata::SaveBinaryToFile(FILE* file) const {
size_t Metadata::SizesInByte() const { size_t Metadata::SizesInByte() const {
size_t size = sizeof(num_data_) + sizeof(num_weights_) size_t size = sizeof(num_data_) + sizeof(num_weights_)
+ sizeof(num_queries_); + sizeof(num_queries_);
size += sizeof(float) * num_data_; size += sizeof(label_t) * num_data_;
if (!weights_.empty()) { if (!weights_.empty()) {
size += sizeof(float) * num_weights_; size += sizeof(label_t) * num_weights_;
} }
if (!query_boundaries_.empty()) { if (!query_boundaries_.empty()) {
size += sizeof(data_size_t) * (num_queries_ + 1); size += sizeof(data_size_t) * (num_queries_ + 1);
......
...@@ -98,9 +98,9 @@ private: ...@@ -98,9 +98,9 @@ private:
/*! \brief Number of data */ /*! \brief Number of data */
data_size_t num_data_; data_size_t num_data_;
/*! \brief Pointer of label */ /*! \brief Pointer of label */
const float* label_; const label_t* label_;
/*! \brief Pointer of weights */ /*! \brief Pointer of weights */
const float* weights_; const label_t* weights_;
/*! \brief Sum weights */ /*! \brief Sum weights */
double sum_weights_; double sum_weights_;
/*! \brief Name of test set */ /*! \brief Name of test set */
...@@ -114,7 +114,7 @@ class BinaryLoglossMetric: public BinaryMetric<BinaryLoglossMetric> { ...@@ -114,7 +114,7 @@ class BinaryLoglossMetric: public BinaryMetric<BinaryLoglossMetric> {
public: public:
explicit BinaryLoglossMetric(const MetricConfig& config) :BinaryMetric<BinaryLoglossMetric>(config) {} explicit BinaryLoglossMetric(const MetricConfig& config) :BinaryMetric<BinaryLoglossMetric>(config) {}
inline static double LossOnPoint(float label, double prob) { inline static double LossOnPoint(label_t label, double prob) {
if (label <= 0) { if (label <= 0) {
if (1.0f - prob > kEpsilon) { if (1.0f - prob > kEpsilon) {
return -std::log(1.0f - prob); return -std::log(1.0f - prob);
...@@ -138,7 +138,7 @@ class BinaryErrorMetric: public BinaryMetric<BinaryErrorMetric> { ...@@ -138,7 +138,7 @@ class BinaryErrorMetric: public BinaryMetric<BinaryErrorMetric> {
public: public:
explicit BinaryErrorMetric(const MetricConfig& config) :BinaryMetric<BinaryErrorMetric>(config) {} explicit BinaryErrorMetric(const MetricConfig& config) :BinaryMetric<BinaryErrorMetric>(config) {}
inline static double LossOnPoint(float label, double prob) { inline static double LossOnPoint(label_t label, double prob) {
if (prob <= 0.5f) { if (prob <= 0.5f) {
return label > 0; return label > 0;
} else { } else {
...@@ -208,7 +208,7 @@ public: ...@@ -208,7 +208,7 @@ public:
double threshold = score[sorted_idx[0]]; double threshold = score[sorted_idx[0]];
if (weights_ == nullptr) { // no weights if (weights_ == nullptr) { // no weights
for (data_size_t i = 0; i < num_data_; ++i) { for (data_size_t i = 0; i < num_data_; ++i) {
const float cur_label = label_[sorted_idx[i]]; const label_t cur_label = label_[sorted_idx[i]];
const double cur_score = score[sorted_idx[i]]; const double cur_score = score[sorted_idx[i]];
// new threshold // new threshold
if (cur_score != threshold) { if (cur_score != threshold) {
...@@ -224,9 +224,9 @@ public: ...@@ -224,9 +224,9 @@ public:
} }
} else { // has weights } else { // has weights
for (data_size_t i = 0; i < num_data_; ++i) { for (data_size_t i = 0; i < num_data_; ++i) {
const float cur_label = label_[sorted_idx[i]]; const label_t cur_label = label_[sorted_idx[i]];
const double cur_score = score[sorted_idx[i]]; const double cur_score = score[sorted_idx[i]];
const float cur_weight = weights_[sorted_idx[i]]; const label_t cur_weight = weights_[sorted_idx[i]];
// new threshold // new threshold
if (cur_score != threshold) { if (cur_score != threshold) {
threshold = cur_score; threshold = cur_score;
...@@ -253,9 +253,9 @@ private: ...@@ -253,9 +253,9 @@ private:
/*! \brief Number of data */ /*! \brief Number of data */
data_size_t num_data_; data_size_t num_data_;
/*! \brief Pointer of label */ /*! \brief Pointer of label */
const float* label_; const label_t* label_;
/*! \brief Pointer of weights */ /*! \brief Pointer of weights */
const float* weights_; const label_t* weights_;
/*! \brief Sum weights */ /*! \brief Sum weights */
double sum_weights_; double sum_weights_;
/*! \brief Name of test set */ /*! \brief Name of test set */
......
...@@ -25,7 +25,7 @@ void DCGCalculator::Init(std::vector<double> input_label_gain) { ...@@ -25,7 +25,7 @@ void DCGCalculator::Init(std::vector<double> input_label_gain) {
} }
} }
double DCGCalculator::CalMaxDCGAtK(data_size_t k, const float* label, data_size_t num_data) { double DCGCalculator::CalMaxDCGAtK(data_size_t k, const label_t* label, data_size_t num_data) {
double ret = 0.0f; double ret = 0.0f;
// counts for all labels // counts for all labels
std::vector<data_size_t> label_cnt(label_gain_.size(), 0); std::vector<data_size_t> label_cnt(label_gain_.size(), 0);
...@@ -50,7 +50,7 @@ double DCGCalculator::CalMaxDCGAtK(data_size_t k, const float* label, data_size_ ...@@ -50,7 +50,7 @@ double DCGCalculator::CalMaxDCGAtK(data_size_t k, const float* label, data_size_
} }
void DCGCalculator::CalMaxDCG(const std::vector<data_size_t>& ks, void DCGCalculator::CalMaxDCG(const std::vector<data_size_t>& ks,
const float* label, const label_t* label,
data_size_t num_data, data_size_t num_data,
std::vector<double>* out) { std::vector<double>* out) {
std::vector<data_size_t> label_cnt(label_gain_.size(), 0); std::vector<data_size_t> label_cnt(label_gain_.size(), 0);
...@@ -81,7 +81,7 @@ void DCGCalculator::CalMaxDCG(const std::vector<data_size_t>& ks, ...@@ -81,7 +81,7 @@ void DCGCalculator::CalMaxDCG(const std::vector<data_size_t>& ks,
} }
double DCGCalculator::CalDCGAtK(data_size_t k, const float* label, double DCGCalculator::CalDCGAtK(data_size_t k, const label_t* label,
const double* score, data_size_t num_data) { const double* score, data_size_t num_data) {
// get sorted indices by score // get sorted indices by score
std::vector<data_size_t> sorted_idx(num_data); std::vector<data_size_t> sorted_idx(num_data);
...@@ -101,7 +101,7 @@ double DCGCalculator::CalDCGAtK(data_size_t k, const float* label, ...@@ -101,7 +101,7 @@ double DCGCalculator::CalDCGAtK(data_size_t k, const float* label,
return dcg; return dcg;
} }
void DCGCalculator::CalDCG(const std::vector<data_size_t>& ks, const float* label, void DCGCalculator::CalDCG(const std::vector<data_size_t>& ks, const label_t* label,
const double * score, data_size_t num_data, std::vector<double>* out) { const double * score, data_size_t num_data, std::vector<double>* out) {
// get sorted indices by score // get sorted indices by score
std::vector<data_size_t> sorted_idx(num_data); std::vector<data_size_t> sorted_idx(num_data);
...@@ -126,9 +126,9 @@ void DCGCalculator::CalDCG(const std::vector<data_size_t>& ks, const float* labe ...@@ -126,9 +126,9 @@ void DCGCalculator::CalDCG(const std::vector<data_size_t>& ks, const float* labe
} }
} }
void DCGCalculator::CheckLabel(const float* label, data_size_t num_data) { void DCGCalculator::CheckLabel(const label_t* label, data_size_t num_data) {
for (data_size_t i = 0; i < num_data; ++i) { for (data_size_t i = 0; i < num_data; ++i) {
float delta = std::fabs(label[i] - static_cast<int>(label[i])); label_t delta = std::fabs(label[i] - static_cast<int>(label[i]));
if (delta > kEpsilon) { if (delta > kEpsilon) {
Log::Fatal("label should be int type (met %f) for ranking task, \ Log::Fatal("label should be int type (met %f) for ranking task, \
for the gain of label, please set the label_gain parameter.", label[i]); for the gain of label, please set the label_gain parameter.", label[i]);
......
...@@ -75,7 +75,7 @@ public: ...@@ -75,7 +75,7 @@ public:
return 1.0f; return 1.0f;
} }
void CalMapAtK(std::vector<int> ks, data_size_t npos, const float* label, void CalMapAtK(std::vector<int> ks, data_size_t npos, const label_t* label,
const double* score, data_size_t num_data, std::vector<double>* out) const { const double* score, data_size_t num_data, std::vector<double>* out) const {
// get sorted indices by score // get sorted indices by score
std::vector<data_size_t> sorted_idx; std::vector<data_size_t> sorted_idx;
...@@ -149,13 +149,13 @@ private: ...@@ -149,13 +149,13 @@ private:
/*! \brief Number of data */ /*! \brief Number of data */
data_size_t num_data_; data_size_t num_data_;
/*! \brief Pointer of label */ /*! \brief Pointer of label */
const float* label_; const label_t* label_;
/*! \brief Query boundaries information */ /*! \brief Query boundaries information */
const data_size_t* query_boundaries_; const data_size_t* query_boundaries_;
/*! \brief Number of queries */ /*! \brief Number of queries */
data_size_t num_queries_; data_size_t num_queries_;
/*! \brief Weights of queries */ /*! \brief Weights of queries */
const float* query_weights_; const label_t* query_weights_;
/*! \brief Sum weights of queries */ /*! \brief Sum weights of queries */
double sum_query_weights_; double sum_query_weights_;
/*! \brief Evaluate position of Nmap */ /*! \brief Evaluate position of Nmap */
......
...@@ -118,9 +118,9 @@ private: ...@@ -118,9 +118,9 @@ private:
/*! \brief Number of data */ /*! \brief Number of data */
data_size_t num_data_; data_size_t num_data_;
/*! \brief Pointer of label */ /*! \brief Pointer of label */
const float* label_; const label_t* label_;
/*! \brief Pointer of weights */ /*! \brief Pointer of weights */
const float* weights_; const label_t* weights_;
/*! \brief Sum weights */ /*! \brief Sum weights */
double sum_weights_; double sum_weights_;
/*! \brief Name of this test set */ /*! \brief Name of this test set */
...@@ -133,7 +133,7 @@ class MultiErrorMetric: public MulticlassMetric<MultiErrorMetric> { ...@@ -133,7 +133,7 @@ class MultiErrorMetric: public MulticlassMetric<MultiErrorMetric> {
public: public:
explicit MultiErrorMetric(const MetricConfig& config) :MulticlassMetric<MultiErrorMetric>(config) {} explicit MultiErrorMetric(const MetricConfig& config) :MulticlassMetric<MultiErrorMetric>(config) {}
inline static double LossOnPoint(float label, std::vector<double>& score) { inline static double LossOnPoint(label_t label, std::vector<double>& score) {
size_t k = static_cast<size_t>(label); size_t k = static_cast<size_t>(label);
for (size_t i = 0; i < score.size(); ++i) { for (size_t i = 0; i < score.size(); ++i) {
if (i != k && score[i] >= score[k]) { if (i != k && score[i] >= score[k]) {
...@@ -153,7 +153,7 @@ class MultiSoftmaxLoglossMetric: public MulticlassMetric<MultiSoftmaxLoglossMetr ...@@ -153,7 +153,7 @@ class MultiSoftmaxLoglossMetric: public MulticlassMetric<MultiSoftmaxLoglossMetr
public: public:
explicit MultiSoftmaxLoglossMetric(const MetricConfig& config) :MulticlassMetric<MultiSoftmaxLoglossMetric>(config) {} explicit MultiSoftmaxLoglossMetric(const MetricConfig& config) :MulticlassMetric<MultiSoftmaxLoglossMetric>(config) {}
inline static double LossOnPoint(float label, std::vector<double>& score) { inline static double LossOnPoint(label_t label, std::vector<double>& score) {
size_t k = static_cast<size_t>(label); size_t k = static_cast<size_t>(label);
if (score[k] > kEpsilon) { if (score[k] > kEpsilon) {
return static_cast<double>(-std::log(score[k])); return static_cast<double>(-std::log(score[k]));
......
...@@ -146,7 +146,7 @@ private: ...@@ -146,7 +146,7 @@ private:
/*! \brief Number of data */ /*! \brief Number of data */
data_size_t num_data_; data_size_t num_data_;
/*! \brief Pointer of label */ /*! \brief Pointer of label */
const float* label_; const label_t* label_;
/*! \brief Name of test set */ /*! \brief Name of test set */
std::vector<std::string> name_; std::vector<std::string> name_;
/*! \brief Query boundaries information */ /*! \brief Query boundaries information */
...@@ -154,7 +154,7 @@ private: ...@@ -154,7 +154,7 @@ private:
/*! \brief Number of queries */ /*! \brief Number of queries */
data_size_t num_queries_; data_size_t num_queries_;
/*! \brief Weights of queries */ /*! \brief Weights of queries */
const float* query_weights_; const label_t* query_weights_;
/*! \brief Sum weights of queries */ /*! \brief Sum weights of queries */
double sum_query_weights_; double sum_query_weights_;
/*! \brief Evaluate position of NDCG */ /*! \brief Evaluate position of NDCG */
......
...@@ -94,9 +94,9 @@ private: ...@@ -94,9 +94,9 @@ private:
/*! \brief Number of data */ /*! \brief Number of data */
data_size_t num_data_; data_size_t num_data_;
/*! \brief Pointer of label */ /*! \brief Pointer of label */
const float* label_; const label_t* label_;
/*! \brief Pointer of weights */ /*! \brief Pointer of weights */
const float* weights_; const label_t* weights_;
/*! \brief Sum weights */ /*! \brief Sum weights */
double sum_weights_; double sum_weights_;
/*! \brief Name of this test set */ /*! \brief Name of this test set */
...@@ -109,7 +109,7 @@ class RMSEMetric: public RegressionMetric<RMSEMetric> { ...@@ -109,7 +109,7 @@ class RMSEMetric: public RegressionMetric<RMSEMetric> {
public: public:
explicit RMSEMetric(const MetricConfig& config) :RegressionMetric<RMSEMetric>(config) {} explicit RMSEMetric(const MetricConfig& config) :RegressionMetric<RMSEMetric>(config) {}
/*!
* \brief Loss for one record: the squared difference between score and label
* \param label Label value of this record
* \param score Score value of this record
* \return (score - label) * (score - label); the unnamed MetricConfig argument is not used
*/
inline static double LossOnPoint(float label, double score, const MetricConfig&) { inline static double LossOnPoint(label_t label, double score, const MetricConfig&) {
return (score - label)*(score - label); return (score - label)*(score - label);
} }
...@@ -128,7 +128,7 @@ class L2Metric: public RegressionMetric<L2Metric> { ...@@ -128,7 +128,7 @@ class L2Metric: public RegressionMetric<L2Metric> {
public: public:
explicit L2Metric(const MetricConfig& config) :RegressionMetric<L2Metric>(config) {} explicit L2Metric(const MetricConfig& config) :RegressionMetric<L2Metric>(config) {}
/*!
* \brief Loss for one record: the squared difference between score and label
* \param label Label value of this record
* \param score Score value of this record
* \return (score - label) * (score - label); the unnamed MetricConfig argument is not used
*/
inline static double LossOnPoint(float label, double score, const MetricConfig&) { inline static double LossOnPoint(label_t label, double score, const MetricConfig&) {
return (score - label)*(score - label); return (score - label)*(score - label);
} }
...@@ -143,7 +143,7 @@ public: ...@@ -143,7 +143,7 @@ public:
explicit QuantileMetric(const MetricConfig& config) :RegressionMetric<QuantileMetric>(config) { explicit QuantileMetric(const MetricConfig& config) :RegressionMetric<QuantileMetric>(config) {
} }
inline static double LossOnPoint(float label, double score, const MetricConfig& config) { inline static double LossOnPoint(label_t label, double score, const MetricConfig& config) {
double delta = label - score; double delta = label - score;
if (delta < 0) { if (delta < 0) {
return (config.alpha - 1.0f) * delta; return (config.alpha - 1.0f) * delta;
...@@ -163,7 +163,7 @@ class L1Metric: public RegressionMetric<L1Metric> { ...@@ -163,7 +163,7 @@ class L1Metric: public RegressionMetric<L1Metric> {
public: public:
explicit L1Metric(const MetricConfig& config) :RegressionMetric<L1Metric>(config) {} explicit L1Metric(const MetricConfig& config) :RegressionMetric<L1Metric>(config) {}
/*!
* \brief Loss for one record: the absolute difference between score and label
* \param label Label value of this record
* \param score Score value of this record
* \return std::fabs(score - label); the unnamed MetricConfig argument is not used
*/
inline static double LossOnPoint(float label, double score, const MetricConfig&) { inline static double LossOnPoint(label_t label, double score, const MetricConfig&) {
return std::fabs(score - label); return std::fabs(score - label);
} }
inline static const char* Name() { inline static const char* Name() {
...@@ -177,7 +177,7 @@ public: ...@@ -177,7 +177,7 @@ public:
explicit HuberLossMetric(const MetricConfig& config) :RegressionMetric<HuberLossMetric>(config) { explicit HuberLossMetric(const MetricConfig& config) :RegressionMetric<HuberLossMetric>(config) {
} }
inline static double LossOnPoint(float label, double score, const MetricConfig& config) { inline static double LossOnPoint(label_t label, double score, const MetricConfig& config) {
const double diff = score - label; const double diff = score - label;
if (std::abs(diff) <= config.alpha) { if (std::abs(diff) <= config.alpha) {
return 0.5f * diff * diff; return 0.5f * diff * diff;
...@@ -198,7 +198,7 @@ public: ...@@ -198,7 +198,7 @@ public:
explicit FairLossMetric(const MetricConfig& config) :RegressionMetric<FairLossMetric>(config) { explicit FairLossMetric(const MetricConfig& config) :RegressionMetric<FairLossMetric>(config) {
} }
inline static double LossOnPoint(float label, double score, const MetricConfig& config) { inline static double LossOnPoint(label_t label, double score, const MetricConfig& config) {
const double x = std::fabs(score - label); const double x = std::fabs(score - label);
const double c = config.fair_c; const double c = config.fair_c;
return c * x - c * c * std::log(1.0f + x / c); return c * x - c * c * std::log(1.0f + x / c);
...@@ -215,7 +215,7 @@ public: ...@@ -215,7 +215,7 @@ public:
explicit PoissonMetric(const MetricConfig& config) :RegressionMetric<PoissonMetric>(config) { explicit PoissonMetric(const MetricConfig& config) :RegressionMetric<PoissonMetric>(config) {
} }
inline static double LossOnPoint(float label, double score, const MetricConfig&) { inline static double LossOnPoint(label_t label, double score, const MetricConfig&) {
const double eps = 1e-10f; const double eps = 1e-10f;
if (score < eps) { if (score < eps) {
score = eps; score = eps;
......
...@@ -29,7 +29,7 @@ namespace LightGBM { ...@@ -29,7 +29,7 @@ namespace LightGBM {
// label should be in interval [0, 1]; // label should be in interval [0, 1];
// prob should be in interval (0, 1); prob is clipped if needed // prob should be in interval (0, 1); prob is clipped if needed
inline static double XentLoss(float label, double prob) { inline static double XentLoss(label_t label, double prob) {
const double log_arg_epsilon = 1.0e-12; const double log_arg_epsilon = 1.0e-12;
double a = label; double a = label;
if (prob > log_arg_epsilon) { if (prob > log_arg_epsilon) {
...@@ -47,7 +47,7 @@ namespace LightGBM { ...@@ -47,7 +47,7 @@ namespace LightGBM {
} }
// hhat >(=) 0 assumed; and weight > 0 required; but not checked here // hhat >(=) 0 assumed; and weight > 0 required; but not checked here
/*!
* \brief Evaluates XentLoss with the probability derived from weight and hhat
* \param label Label value of this record
* \param weight Weight value of this record
* \param hhat Value that is multiplied by weight inside the exponential
* \return XentLoss(label, p) where p = 1 - exp(-weight * hhat)
*/
inline static double XentLambdaLoss(float label, float weight, double hhat) { inline static double XentLambdaLoss(label_t label, label_t weight, double hhat) {
return XentLoss(label, 1.0f - std::exp(-weight * hhat)); return XentLoss(label, 1.0f - std::exp(-weight * hhat));
} }
...@@ -79,15 +79,15 @@ public: ...@@ -79,15 +79,15 @@ public:
CHECK_NOTNULL(label_); CHECK_NOTNULL(label_);
// ensure that labels are in interval [0, 1], interval ends included // ensure that labels are in interval [0, 1], interval ends included
Common::CheckElementsIntervalClosed(label_, 0.0f, 1.0f, num_data_, GetName()[0].c_str()); Common::CheckElementsIntervalClosed<label_t>(label_, 0.0f, 1.0f, num_data_, GetName()[0].c_str());
Log::Info("[%s:%s]: (metric) labels passed interval [0, 1] check", GetName()[0].c_str(), __func__); Log::Info("[%s:%s]: (metric) labels passed interval [0, 1] check", GetName()[0].c_str(), __func__);
// check that weights are non-negative and sum is positive // check that weights are non-negative and sum is positive
if (weights_ == nullptr) { if (weights_ == nullptr) {
sum_weights_ = static_cast<double>(num_data_); sum_weights_ = static_cast<double>(num_data_);
} else { } else {
float minw; label_t minw;
Common::ObtainMinMaxSum(weights_, num_data_, &minw, (float*)nullptr, &sum_weights_); Common::ObtainMinMaxSum(weights_, num_data_, &minw, (label_t*)nullptr, &sum_weights_);
if (minw < 0.0f) { if (minw < 0.0f) {
Log::Fatal("[%s:%s]: (metric) weights not allowed to be negative", GetName()[0].c_str(), __func__); Log::Fatal("[%s:%s]: (metric) weights not allowed to be negative", GetName()[0].c_str(), __func__);
} }
...@@ -147,9 +147,9 @@ private: ...@@ -147,9 +147,9 @@ private:
/*! \brief Number of data points */ /*! \brief Number of data points */
data_size_t num_data_; data_size_t num_data_;
/*! \brief Pointer to label */ /*! \brief Pointer to label */
const float* label_; const label_t* label_;
/*! \brief Pointer to weights */ /*! \brief Pointer to weights */
const float* weights_; const label_t* weights_;
/*! \brief Sum of weights */ /*! \brief Sum of weights */
double sum_weights_; double sum_weights_;
/*! \brief Name of this metric */ /*! \brief Name of this metric */
...@@ -172,13 +172,13 @@ public: ...@@ -172,13 +172,13 @@ public:
weights_ = metadata.weights(); weights_ = metadata.weights();
CHECK_NOTNULL(label_); CHECK_NOTNULL(label_);
Common::CheckElementsIntervalClosed(label_, 0.0f, 1.0f, num_data_, GetName()[0].c_str()); Common::CheckElementsIntervalClosed<label_t>(label_, 0.0f, 1.0f, num_data_, GetName()[0].c_str());
Log::Info("[%s:%s]: (metric) labels passed interval [0, 1] check", GetName()[0].c_str(), __func__); Log::Info("[%s:%s]: (metric) labels passed interval [0, 1] check", GetName()[0].c_str(), __func__);
// check all weights are strictly positive; throw error if not // check all weights are strictly positive; throw error if not
if (weights_ != nullptr) { if (weights_ != nullptr) {
float minw; label_t minw;
Common::ObtainMinMaxSum(weights_, num_data_, &minw, (float*)nullptr, (float*)nullptr); Common::ObtainMinMaxSum(weights_, num_data_, &minw, (label_t*)nullptr, (label_t*)nullptr);
if (minw <= 0.0f) { if (minw <= 0.0f) {
Log::Fatal("[%s:%s]: (metric) all weights must be positive", GetName()[0].c_str(), __func__); Log::Fatal("[%s:%s]: (metric) all weights must be positive", GetName()[0].c_str(), __func__);
} }
...@@ -234,9 +234,9 @@ private: ...@@ -234,9 +234,9 @@ private:
/*! \brief Number of data points */ /*! \brief Number of data points */
data_size_t num_data_; data_size_t num_data_;
/*! \brief Pointer to label */ /*! \brief Pointer to label */
const float* label_; const label_t* label_;
/*! \brief Pointer to weights */ /*! \brief Pointer to weights */
const float* weights_; const label_t* weights_;
/*! \brief Name of this metric */ /*! \brief Name of this metric */
std::vector<std::string> name_; std::vector<std::string> name_;
}; };
...@@ -256,14 +256,14 @@ public: ...@@ -256,14 +256,14 @@ public:
weights_ = metadata.weights(); weights_ = metadata.weights();
CHECK_NOTNULL(label_); CHECK_NOTNULL(label_);
Common::CheckElementsIntervalClosed(label_, 0.0f, 1.0f, num_data_, GetName()[0].c_str()); Common::CheckElementsIntervalClosed<label_t>(label_, 0.0f, 1.0f, num_data_, GetName()[0].c_str());
Log::Info("[%s:%s]: (metric) labels passed interval [0, 1] check", GetName()[0].c_str(), __func__); Log::Info("[%s:%s]: (metric) labels passed interval [0, 1] check", GetName()[0].c_str(), __func__);
if (weights_ == nullptr) { if (weights_ == nullptr) {
sum_weights_ = static_cast<double>(num_data_); sum_weights_ = static_cast<double>(num_data_);
} else { } else {
float minw; label_t minw;
Common::ObtainMinMaxSum(weights_, num_data_, &minw, (float*)nullptr, &sum_weights_); Common::ObtainMinMaxSum(weights_, num_data_, &minw, (label_t*)nullptr, &sum_weights_);
if (minw < 0.0f) { if (minw < 0.0f) {
Log::Fatal("[%s:%s]: (metric) at least one weight is negative", GetName()[0].c_str(), __func__); Log::Fatal("[%s:%s]: (metric) at least one weight is negative", GetName()[0].c_str(), __func__);
} }
...@@ -342,9 +342,9 @@ private: ...@@ -342,9 +342,9 @@ private:
/*! \brief Number of data points */ /*! \brief Number of data points */
data_size_t num_data_; data_size_t num_data_;
/*! \brief Pointer to label */ /*! \brief Pointer to label */
const float* label_; const label_t* label_;
/*! \brief Pointer to weights */ /*! \brief Pointer to weights */
const float* weights_; const label_t* weights_;
/*! \brief Sum of weights */ /*! \brief Sum of weights */
double sum_weights_; double sum_weights_;
/*! \brief Offset term to cross-entropy; precomputed during init */ /*! \brief Offset term to cross-entropy; precomputed during init */
......
...@@ -12,7 +12,7 @@ namespace LightGBM { ...@@ -12,7 +12,7 @@ namespace LightGBM {
*/ */
class BinaryLogloss: public ObjectiveFunction { class BinaryLogloss: public ObjectiveFunction {
public: public:
explicit BinaryLogloss(const ObjectiveConfig& config, std::function<bool(float)> is_pos = nullptr) { explicit BinaryLogloss(const ObjectiveConfig& config, std::function<bool(label_t)> is_pos = nullptr) {
sigmoid_ = static_cast<double>(config.sigmoid); sigmoid_ = static_cast<double>(config.sigmoid);
if (sigmoid_ <= 0.0) { if (sigmoid_ <= 0.0) {
Log::Fatal("Sigmoid parameter %f should be greater than zero", sigmoid_); Log::Fatal("Sigmoid parameter %f should be greater than zero", sigmoid_);
...@@ -24,7 +24,7 @@ public: ...@@ -24,7 +24,7 @@ public:
} }
is_pos_ = is_pos; is_pos_ = is_pos;
if (is_pos_ == nullptr) { if (is_pos_ == nullptr) {
is_pos_ = [](float label) {return label > 0; }; is_pos_ = [](label_t label) {return label > 0; };
} }
} }
...@@ -138,7 +138,7 @@ private: ...@@ -138,7 +138,7 @@ private:
/*! \brief Number of data */ /*! \brief Number of data */
data_size_t num_data_; data_size_t num_data_;
/*! \brief Pointer of label */ /*! \brief Pointer of label */
const float* label_; const label_t* label_;
/*! \brief True if using unbalance training */ /*! \brief True if using unbalance training */
bool is_unbalance_; bool is_unbalance_;
/*! \brief Sigmoid parameter */ /*! \brief Sigmoid parameter */
...@@ -148,9 +148,9 @@ private: ...@@ -148,9 +148,9 @@ private:
/*! \brief Weights for positive and negative labels */ /*! \brief Weights for positive and negative labels */
double label_weights_[2]; double label_weights_[2];
/*! \brief Weights for data */ /*! \brief Weights for data */
const float* weights_; const label_t* weights_;
double scale_pos_weight_; double scale_pos_weight_;
std::function<bool(float)> is_pos_; std::function<bool(label_t)> is_pos_;
}; };
} // namespace LightGBM } // namespace LightGBM
......
...@@ -126,11 +126,11 @@ private: ...@@ -126,11 +126,11 @@ private:
/*! \brief Number of classes */ /*! \brief Number of classes */
int num_class_; int num_class_;
/*! \brief Pointer of label */ /*! \brief Pointer of label */
const float* label_; const label_t* label_;
/*! \brief Corresponding integers of label_ */ /*! \brief Corresponding integers of label_ */
std::vector<int> label_int_; std::vector<int> label_int_;
/*! \brief Weights for data */ /*! \brief Weights for data */
const float* weights_; const label_t* weights_;
}; };
/*! /*!
...@@ -142,7 +142,7 @@ public: ...@@ -142,7 +142,7 @@ public:
num_class_ = config.num_class; num_class_ = config.num_class;
for (int i = 0; i < num_class_; ++i) { for (int i = 0; i < num_class_; ++i) {
binary_loss_.emplace_back( binary_loss_.emplace_back(
new BinaryLogloss(config, [i](float label) { return static_cast<int>(label) == i; })); new BinaryLogloss(config, [i](label_t label) { return static_cast<int>(label) == i; }));
} }
sigmoid_ = config.sigmoid; sigmoid_ = config.sigmoid;
} }
......
...@@ -89,7 +89,7 @@ public: ...@@ -89,7 +89,7 @@ public:
// get max DCG on current query // get max DCG on current query
const double inverse_max_dcg = inverse_max_dcgs_[query_id]; const double inverse_max_dcg = inverse_max_dcgs_[query_id];
// add pointers with offset // add pointers with offset
const float* label = label_ + start; const label_t* label = label_ + start;
score += start; score += start;
lambdas += start; lambdas += start;
hessians += start; hessians += start;
...@@ -164,8 +164,8 @@ public: ...@@ -164,8 +164,8 @@ public:
// if need weights // if need weights
if (weights_ != nullptr) { if (weights_ != nullptr) {
for (data_size_t i = 0; i < cnt; ++i) { for (data_size_t i = 0; i < cnt; ++i) {
lambdas[i] *= weights_[start + i]; lambdas[i] = static_cast<score_t>(lambdas[i] * weights_[start + i]);
hessians[i] *= weights_[start + i]; hessians[i] = static_cast<score_t>(hessians[i] * weights_[start + i]);
} }
} }
} }
...@@ -224,9 +224,9 @@ private: ...@@ -224,9 +224,9 @@ private:
/*! \brief Number of data */ /*! \brief Number of data */
data_size_t num_data_; data_size_t num_data_;
/*! \brief Pointer of label */ /*! \brief Pointer of label */
const float* label_; const label_t* label_;
/*! \brief Pointer of weights */ /*! \brief Pointer of weights */
const float* weights_; const label_t* weights_;
/*! \brief Query boundaries */ /*! \brief Query boundaries */
const data_size_t* query_boundaries_; const data_size_t* query_boundaries_;
/*! \brief Cache result for sigmoid transform to speed up */ /*! \brief Cache result for sigmoid transform to speed up */
......
...@@ -53,8 +53,8 @@ public: ...@@ -53,8 +53,8 @@ public:
} else { } else {
#pragma omp parallel for schedule(static) #pragma omp parallel for schedule(static)
for (data_size_t i = 0; i < num_data_; ++i) { for (data_size_t i = 0; i < num_data_; ++i) {
gradients[i] = static_cast<score_t>(score[i] - label_[i]) * weights_[i]; gradients[i] = static_cast<score_t>((score[i] - label_[i]) * weights_[i]);
hessians[i] = weights_[i]; hessians[i] = static_cast<score_t>(weights_[i]);
} }
} }
} }
...@@ -101,10 +101,10 @@ protected: ...@@ -101,10 +101,10 @@ protected:
/*! \brief Number of data */ /*! \brief Number of data */
data_size_t num_data_; data_size_t num_data_;
/*! \brief Pointer of label */ /*! \brief Pointer of label */
const float* label_; const label_t* label_;
/*! \brief Pointer of weights */ /*! \brief Pointer of weights */
const float* weights_; const label_t* weights_;
std::vector<float> trans_label_; std::vector<label_t> trans_label_;
}; };
/*! /*!
...@@ -140,9 +140,9 @@ public: ...@@ -140,9 +140,9 @@ public:
for (data_size_t i = 0; i < num_data_; ++i) { for (data_size_t i = 0; i < num_data_; ++i) {
const double diff = score[i] - label_[i]; const double diff = score[i] - label_[i];
if (diff >= 0.0f) { if (diff >= 0.0f) {
gradients[i] = weights_[i]; gradients[i] = static_cast<score_t>(weights_[i]);
} else { } else {
gradients[i] = -weights_[i]; gradients[i] = static_cast<score_t>(-weights_[i]);
} }
hessians[i] = static_cast<score_t>(Common::ApproximateHessianWithGaussian(score[i], label_[i], gradients[i], eta_, weights_[i])); hessians[i] = static_cast<score_t>(Common::ApproximateHessianWithGaussian(score[i], label_[i], gradients[i], eta_, weights_[i]));
} }
...@@ -204,7 +204,7 @@ public: ...@@ -204,7 +204,7 @@ public:
if (std::abs(diff) <= alpha_) { if (std::abs(diff) <= alpha_) {
gradients[i] = static_cast<score_t>(diff * weights_[i]); gradients[i] = static_cast<score_t>(diff * weights_[i]);
hessians[i] = weights_[i]; hessians[i] = static_cast<score_t>(weights_[i]);
} else { } else {
if (diff >= 0.0f) { if (diff >= 0.0f) {
gradients[i] = static_cast<score_t>(alpha_ * weights_[i]); gradients[i] = static_cast<score_t>(alpha_ * weights_[i]);
...@@ -297,9 +297,9 @@ public: ...@@ -297,9 +297,9 @@ public:
void Init(const Metadata& metadata, data_size_t num_data) override { void Init(const Metadata& metadata, data_size_t num_data) override {
RegressionL2loss::Init(metadata, num_data); RegressionL2loss::Init(metadata, num_data);
// Safety check of labels // Safety check of labels
float miny; label_t miny;
double sumy; double sumy;
Common::ObtainMinMaxSum(label_, num_data_, &miny, (float*)nullptr, &sumy); Common::ObtainMinMaxSum(label_, num_data_, &miny, (label_t*)nullptr, &sumy);
if (miny < 0.0f) { if (miny < 0.0f) {
Log::Fatal("[%s]: at least one target label is negative.", GetName()); Log::Fatal("[%s]: at least one target label is negative.", GetName());
} }
...@@ -405,11 +405,11 @@ public: ...@@ -405,11 +405,11 @@ public:
for (data_size_t i = 0; i < num_data_; ++i) { for (data_size_t i = 0; i < num_data_; ++i) {
score_t delta = static_cast<score_t>(score[i] - label_[i]); score_t delta = static_cast<score_t>(score[i] - label_[i]);
if (delta >= 0) { if (delta >= 0) {
gradients[i] = (1.0f - alpha_) * weights_[i]; gradients[i] = static_cast<score_t>((1.0f - alpha_) * weights_[i]);
} else { } else {
gradients[i] = -alpha_ * weights_[i]; gradients[i] = static_cast<score_t>(-alpha_ * weights_[i]);
} }
hessians[i] = weights_[i]; hessians[i] = static_cast<score_t>(weights_[i]);
} }
} }
} }
...@@ -454,11 +454,11 @@ public: ...@@ -454,11 +454,11 @@ public:
for (data_size_t i = 0; i < num_data_; ++i) { for (data_size_t i = 0; i < num_data_; ++i) {
score_t delta = static_cast<score_t>(score[i] - label_[i]); score_t delta = static_cast<score_t>(score[i] - label_[i]);
if (delta > 0) { if (delta > 0) {
gradients[i] = (1.0f - alpha_) * delta * weights_[i]; gradients[i] = static_cast<score_t>((1.0f - alpha_) * delta * weights_[i]);
hessians[i] = (1.0f - alpha_) * weights_[i]; hessians[i] = static_cast<score_t>((1.0f - alpha_) * weights_[i]);
} else { } else {
gradients[i] = alpha_ * delta * weights_[i]; gradients[i] = static_cast<score_t>(alpha_ * delta * weights_[i]);
hessians[i] = alpha_ * weights_[i]; hessians[i] = static_cast<score_t>(alpha_ * weights_[i]);
} }
} }
} }
......
...@@ -52,13 +52,13 @@ public: ...@@ -52,13 +52,13 @@ public:
weights_ = metadata.weights(); weights_ = metadata.weights();
CHECK_NOTNULL(label_); CHECK_NOTNULL(label_);
Common::CheckElementsIntervalClosed(label_, 0.0f, 1.0f, num_data_, GetName()); Common::CheckElementsIntervalClosed<label_t>(label_, 0.0f, 1.0f, num_data_, GetName());
Log::Info("[%s:%s]: (objective) labels passed interval [0, 1] check", GetName(), __func__); Log::Info("[%s:%s]: (objective) labels passed interval [0, 1] check", GetName(), __func__);
if (weights_ != nullptr) { if (weights_ != nullptr) {
float minw; label_t minw;
double sumw; double sumw;
Common::ObtainMinMaxSum(weights_, num_data_, &minw, (float*)nullptr, &sumw); Common::ObtainMinMaxSum(weights_, num_data_, &minw, (label_t*)nullptr, &sumw);
if (minw < 0.0f) { if (minw < 0.0f) {
Log::Fatal("[%s]: at least one weight is negative.", GetName()); Log::Fatal("[%s]: at least one weight is negative.", GetName());
} }
...@@ -133,9 +133,9 @@ private: ...@@ -133,9 +133,9 @@ private:
/*! \brief Number of data points */ /*! \brief Number of data points */
data_size_t num_data_; data_size_t num_data_;
/*! \brief Pointer for label */ /*! \brief Pointer for label */
const float* label_; const label_t* label_;
/*! \brief Weights for data */ /*! \brief Weights for data */
const float* weights_; const label_t* weights_;
}; };
/*! /*!
...@@ -158,12 +158,12 @@ public: ...@@ -158,12 +158,12 @@ public:
weights_ = metadata.weights(); weights_ = metadata.weights();
CHECK_NOTNULL(label_); CHECK_NOTNULL(label_);
Common::CheckElementsIntervalClosed(label_, 0.0f, 1.0f, num_data_, GetName()); Common::CheckElementsIntervalClosed<label_t>(label_, 0.0f, 1.0f, num_data_, GetName());
Log::Info("[%s:%s]: (objective) labels passed interval [0, 1] check", GetName(), __func__); Log::Info("[%s:%s]: (objective) labels passed interval [0, 1] check", GetName(), __func__);
if (weights_ != nullptr) { if (weights_ != nullptr) {
Common::ObtainMinMaxSum(weights_, num_data_, &min_weight_, &max_weight_, (float*)nullptr); Common::ObtainMinMaxSum(weights_, num_data_, &min_weight_, &max_weight_, (label_t*)nullptr);
if (min_weight_ <= 0.0f) { if (min_weight_ <= 0.0f) {
Log::Fatal("[%s]: at least one weight is non-positive.", GetName()); Log::Fatal("[%s]: at least one weight is non-positive.", GetName());
} }
...@@ -254,13 +254,13 @@ private: ...@@ -254,13 +254,13 @@ private:
/*! \brief Number of data points */ /*! \brief Number of data points */
data_size_t num_data_; data_size_t num_data_;
/*! \brief Pointer for label */ /*! \brief Pointer for label */
const float* label_; const label_t* label_;
/*! \brief Weights for data */ /*! \brief Weights for data */
const float* weights_; const label_t* weights_;
/*! \brief Minimum weight found during init */ /*! \brief Minimum weight found during init */
float min_weight_; label_t min_weight_;
/*! \brief Maximum weight found during init */ /*! \brief Maximum weight found during init */
float max_weight_; label_t max_weight_;
}; };
} // end namespace LightGBM } // end namespace LightGBM
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment