Commit bfb0217a authored by Guolin Ke, committed by GitHub

Move all prediction transform to the objective. (#383)

* many refactors.

* remove multi_loglossova.

* fix tests.

* avoid using lambda function.

* fix some format.

* reduce branching.
parent d4c4d9ae
@@ -36,7 +36,9 @@ The parameter format is ```key1=value1 key2=value2 ...```. And parameters can
 * validation/test data, LightGBM will output metrics for these data
 * support multi validation data, separated by ```,```
 * ```num_iterations```, default=```10```, type=int, alias=```num_iteration```,```num_tree```,```num_trees```,```num_round```,```num_rounds```
-  * number of boosting iterations/trees
+  * number of boosting iterations
+  * note: ```num_tree``` here is equal to ```num_iterations```. For multi-class, it actually learns ```num_class * num_iterations``` trees.
+  * note: for the Python/R packages, this parameter cannot be used to control the number of iterations.
 * ```learning_rate```, default=```0.1```, type=double, alias=```shrinkage_rate```
 * shrinkage rate
 * in ```dart```, it also affects normalization weights of dropped trees
...
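To make the multi-class note concrete: under the ```key=value``` format described above, a hypothetical configuration such as

```
objective=multiclass num_class=3 num_iterations=100
```

runs 100 boosting iterations but learns 3 × 100 = 300 trees, one per class per iteration.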
@@ -26,13 +26,13 @@ public:
   * \brief Initialization logic
   * \param config Configs for boosting
   * \param train_data Training data
-  * \param object_function Training objective function
+  * \param objective_function Training objective function
   * \param training_metrics Training metric
   */
   virtual void Init(
     const BoostingConfig* config,
     const Dataset* train_data,
-    const ObjectiveFunction* object_function,
+    const ObjectiveFunction* objective_function,
     const std::vector<const Metric*>& training_metrics) = 0;
   /*!
@@ -46,10 +46,10 @@ public:
   * \brief Reset training data for current boosting
   * \param config Configs for boosting
   * \param train_data Training data
-  * \param object_function Training objective function
+  * \param objective_function Training objective function
   * \param training_metrics Training metric
   */
-  virtual void ResetTrainingData(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* object_function, const std::vector<const Metric*>& training_metrics) = 0;
+  virtual void ResetTrainingData(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* objective_function, const std::vector<const Metric*>& training_metrics) = 0;
   /*!
   * \brief Add a validation data
...
@@ -4,6 +4,7 @@
 #include <LightGBM/meta.h>
 #include <LightGBM/config.h>
 #include <LightGBM/dataset.h>
+#include <LightGBM/objective_function.h>
 #include <vector>
@@ -33,7 +34,8 @@ public:
   * \brief Calculating and printing metric result
   * \param score Current prediction score
   */
-  virtual std::vector<double> Eval(const double* score) const = 0;
+  virtual std::vector<double> Eval(const double* score, const ObjectiveFunction* objective,
+                                   int num_tree_per_iteration) const = 0;
   Metric() = default;
   /*! \brief Disable copy */
...
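This is the central interface change of the commit: a metric now receives the objective function (possibly ```nullptr```) plus the number of trees per iteration, and delegates the raw-score transform to the objective instead of owning a ```sigmoid_``` of its own. A self-contained sketch of the resulting two code paths, using made-up stand-in types rather than the real LightGBM classes:

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

// Hypothetical stand-ins for the real LightGBM types, only to illustrate the
// new Eval() contract: the metric asks the objective to transform raw scores
// instead of hard-coding a sigmoid itself.
struct FakeBinaryObjective {
  double sigmoid = 1.0;
  double ConvertOutput(double raw) const {
    return 1.0 / (1.0 + std::exp(-sigmoid * raw));
  }
};

double BinaryLogloss(double label, double prob) {
  return label > 0 ? -std::log(prob) : -std::log(1.0 - prob);
}

// Mirrors the branch structure added to BinaryMetric::Eval below:
// objective == nullptr means the scores are already in output space.
double Eval(const std::vector<double>& score, const std::vector<double>& label,
            const FakeBinaryObjective* objective) {
  double sum_loss = 0.0;
  for (size_t i = 0; i < score.size(); ++i) {
    double prob = (objective != nullptr) ? objective->ConvertOutput(score[i]) : score[i];
    sum_loss += BinaryLogloss(label[i], prob);
  }
  return sum_loss / static_cast<double>(score.size());
}

int main() {
  FakeBinaryObjective obj;
  std::vector<double> raw = {2.0, -1.5};
  std::vector<double> label = {1.0, 0.0};
  std::printf("logloss = %f\n", Eval(raw, label, &obj));
}
```

Passing ```nullptr``` signals that the scores are already transformed, so the metric consumes them unchanged — exactly the branching added to the metric implementations below.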
@@ -4,9 +4,9 @@
 #include <LightGBM/meta.h>
 #include <LightGBM/config.h>
 #include <LightGBM/dataset.h>
+#include <functional>

 namespace LightGBM {
 /*!
 * \brief The interface of Objective Function.
 */
@@ -35,6 +35,22 @@ public:
   virtual bool IsConstantHessian() const { return false; }

+  virtual bool BoostFromAverage() const { return false; }
+
+  virtual bool SkipEmptyClass() const { return false; }
+
+  virtual int numTreePerIteration() const { return 1; }
+
+  virtual std::vector<double> ConvertOutput(std::vector<double>& input) const {
+    return input;
+  }
+
+  virtual double ConvertOutput(double input) const {
+    return input;
+  }
+
+  virtual std::string ToString() const = 0;
+
   ObjectiveFunction() = default;
   /*! \brief Disable copy */
   ObjectiveFunction& operator=(const ObjectiveFunction&) = delete;
@@ -48,6 +64,11 @@ public:
   */
   LIGHTGBM_EXPORT static ObjectiveFunction* CreateObjectiveFunction(const std::string& type,
     const ObjectiveConfig& config);
+  /*!
+  * \brief Load objective function from string object
+  */
+  LIGHTGBM_EXPORT static ObjectiveFunction* CreateObjectiveFunction(const std::string& str);
 };
 } // namespace LightGBM
...
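The base class now carries the whole prediction-transform contract with identity defaults, so objectives that need no transform (regression, ranking) inherit pass-through behaviour, while the binary and multiclass objectives override the hooks (see the objective hunks further down). A stripped-down model of that contract — illustrative only, not the real header:

```cpp
#include <string>
#include <vector>

// Stripped-down model of the new ObjectiveFunction contract.
struct MiniObjective {
  virtual ~MiniObjective() = default;
  // Multiclass objectives override this to return num_class; everything
  // else trains one tree per boosting iteration.
  virtual int numTreePerIteration() const { return 1; }
  // Identity by default; binary overrides with a sigmoid, multiclass
  // with a softmax over the whole vector.
  virtual double ConvertOutput(double input) const { return input; }
  virtual std::vector<double> ConvertOutput(std::vector<double>& input) const { return input; }
  // Serialized into the model so prediction can restore the transform.
  virtual std::string ToString() const = 0;
};
```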
@@ -377,6 +377,21 @@ inline void Softmax(std::vector<double>* p_rec) {
   }
 }

+inline void Softmax(double* rec, int len) {
+  double wmax = rec[0];
+  for (int i = 1; i < len; ++i) {
+    wmax = std::max(rec[i], wmax);
+  }
+  double wsum = 0.0f;
+  for (int i = 0; i < len; ++i) {
+    rec[i] = std::exp(rec[i] - wmax);
+    wsum += rec[i];
+  }
+  for (int i = 0; i < len; ++i) {
+    rec[i] /= static_cast<double>(wsum);
+  }
+}
+
 template<typename T>
 std::vector<const T*> ConstPtrInVectorWrapper(const std::vector<std::unique_ptr<T>>& input) {
   std::vector<const T*> ret;
...
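The new overload mirrors the existing ```std::vector``` version: subtract the maximum for numerical stability, exponentiate, then normalize. A minimal standalone check, with the body copied here so the snippet compiles on its own:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>

// Standalone copy of the new overload, for illustration only.
void Softmax(double* rec, int len) {
  double wmax = rec[0];
  for (int i = 1; i < len; ++i) wmax = std::max(rec[i], wmax);
  double wsum = 0.0;
  for (int i = 0; i < len; ++i) { rec[i] = std::exp(rec[i] - wmax); wsum += rec[i]; }
  for (int i = 0; i < len; ++i) rec[i] /= wsum;
}

int main() {
  double raw[3] = {1.0, 2.0, 3.0};
  Softmax(raw, 3);
  // Prints roughly 0.090 0.245 0.665, summing to 1.
  for (double v : raw) std::printf("%.3f ", v);
}
```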
@@ -28,20 +28,20 @@ public:
   * \brief Initialization logic
   * \param config Config for boosting
   * \param train_data Training data
-  * \param object_function Training objective function
+  * \param objective_function Training objective function
   * \param training_metrics Training metrics
   * \param output_model_filename Filename of output model
   */
-  void Init(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* object_function,
+  void Init(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* objective_function,
            const std::vector<const Metric*>& training_metrics) override {
-    GBDT::Init(config, train_data, object_function, training_metrics);
+    GBDT::Init(config, train_data, objective_function, training_metrics);
     random_for_drop_ = Random(gbdt_config_->drop_seed);
     sum_weight_ = 0.0f;
   }

-  void ResetTrainingData(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* object_function,
+  void ResetTrainingData(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* objective_function,
                          const std::vector<const Metric*>& training_metrics) override {
-    GBDT::ResetTrainingData(config, train_data, object_function, training_metrics);
+    GBDT::ResetTrainingData(config, train_data, objective_function, training_metrics);
   }
   /*!
   * \brief one training iteration
@@ -110,10 +110,10 @@ private:
     }
     // drop trees
     for (auto i : drop_index_) {
-      for (int curr_class = 0; curr_class < num_class_; ++curr_class) {
-        auto curr_tree = i * num_class_ + curr_class;
+      for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
+        auto curr_tree = i * num_tree_per_iteration_ + cur_tree_id;
         models_[curr_tree]->Shrinkage(-1.0);
-        train_score_updater_->AddScore(models_[curr_tree].get(), curr_class);
+        train_score_updater_->AddScore(models_[curr_tree].get(), cur_tree_id);
       }
     }
     if (!gbdt_config_->xgboost_dart_mode) {
@@ -140,16 +140,16 @@ private:
     double k = static_cast<double>(drop_index_.size());
     if (!gbdt_config_->xgboost_dart_mode) {
       for (auto i : drop_index_) {
-        for (int curr_class = 0; curr_class < num_class_; ++curr_class) {
-          auto curr_tree = i * num_class_ + curr_class;
+        for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
+          auto curr_tree = i * num_tree_per_iteration_ + cur_tree_id;
           // update validation score
           models_[curr_tree]->Shrinkage(1.0f / (k + 1.0f));
           for (auto& score_updater : valid_score_updater_) {
-            score_updater->AddScore(models_[curr_tree].get(), curr_class);
+            score_updater->AddScore(models_[curr_tree].get(), cur_tree_id);
           }
           // update training score
           models_[curr_tree]->Shrinkage(-k);
-          train_score_updater_->AddScore(models_[curr_tree].get(), curr_class);
+          train_score_updater_->AddScore(models_[curr_tree].get(), cur_tree_id);
         }
         if (!gbdt_config_->uniform_drop) {
           sum_weight_ -= tree_weight_[i] * (1.0f / (k + 1.0f));
@@ -158,16 +158,16 @@ private:
       }
     } else {
       for (auto i : drop_index_) {
-        for (int curr_class = 0; curr_class < num_class_; ++curr_class) {
-          auto curr_tree = i * num_class_ + curr_class;
+        for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
+          auto curr_tree = i * num_tree_per_iteration_ + cur_tree_id;
           // update validation score
           models_[curr_tree]->Shrinkage(shrinkage_rate_);
           for (auto& score_updater : valid_score_updater_) {
-            score_updater->AddScore(models_[curr_tree].get(), curr_class);
+            score_updater->AddScore(models_[curr_tree].get(), cur_tree_id);
           }
           // update training score
           models_[curr_tree]->Shrinkage(-k / gbdt_config_->learning_rate);
-          train_score_updater_->AddScore(models_[curr_tree].get(), curr_class);
+          train_score_updater_->AddScore(models_[curr_tree].get(), cur_tree_id);
         }
         if (!gbdt_config_->uniform_drop) {
          sum_weight_ -= tree_weight_[i] * (1.0f / (k + gbdt_config_->learning_rate));
...
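A small aside on the indexing these hunks rely on: ```models_``` stores trees iteration-major, so the rename from ```curr_class``` to ```cur_tree_id``` changes naming, not layout. A tiny standalone illustration:

```cpp
#include <cstdio>

// The flat tree layout used above: models_ holds
// num_iterations * num_tree_per_iteration trees, iteration-major, so the
// tree for class c of iteration i sits at i * num_tree_per_iteration + c.
int TreeIndex(int iteration, int tree_id, int num_tree_per_iteration) {
  return iteration * num_tree_per_iteration + tree_id;
}

int main() {
  // With 3 trees per iteration (e.g. 3-class softmax), iteration 5's tree
  // for class 2 is models_[17].
  std::printf("%d\n", TreeIndex(5, 2, 3));  // prints 17
}
```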
[One file's diff is collapsed and not shown here.]
@@ -28,10 +28,10 @@ public:
   * \brief Initialization logic
   * \param gbdt_config Config for boosting
   * \param train_data Training data
-  * \param object_function Training objective function
+  * \param objective_function Training objective function
   * \param training_metrics Training metrics
   */
-  void Init(const BoostingConfig* gbdt_config, const Dataset* train_data, const ObjectiveFunction* object_function,
+  void Init(const BoostingConfig* gbdt_config, const Dataset* train_data, const ObjectiveFunction* objective_function,
            const std::vector<const Metric*>& training_metrics)
            override;
@@ -62,10 +62,10 @@ public:
   /*!
   * \brief Reset training data for current boosting
   * \param train_data Training data
-  * \param object_function Training objective function
+  * \param objective_function Training objective function
   * \param training_metrics Training metric
   */
-  void ResetTrainingData(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* object_function, const std::vector<const Metric*>& training_metrics) override;
+  void ResetTrainingData(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* objective_function, const std::vector<const Metric*>& training_metrics) override;
   /*!
   * \brief Adding a validation dataset
@@ -155,14 +155,14 @@ public:
   * \param is_finish Is training finished or not
   * \param filename Filename to save to
   */
-  virtual bool SaveModelToFile(int num_iterations, const char* filename) const override ;
+  virtual bool SaveModelToFile(int num_iterations, const char* filename) const override;
   /*!
   * \brief Save model to string
   * \param num_used_model Number of models to save, -1 means save all
   * \return Non-empty string if succeeded
   */
-  virtual std::string SaveModelToString(int num_iterations) const override ;
+  virtual std::string SaveModelToString(int num_iterations) const override;
   /*!
   * \brief Restore from a serialized string
@@ -245,9 +245,9 @@ protected:
   * \brief updating score for out-of-bag data.
   *        Data should be updated since we may re-bag data during training
   * \param tree Trained tree of this iteration
-  * \param curr_class Current class for multiclass training
+  * \param cur_tree_id Current tree for multiclass training
   */
-  void UpdateScoreOutOfBag(const Tree* tree, const int curr_class);
+  void UpdateScoreOutOfBag(const Tree* tree, const int cur_tree_id);
   /*!
   * \brief calculate the objective function
   */
@@ -255,9 +255,9 @@ protected:
   /*!
   * \brief updating score after tree was trained
   * \param tree Trained tree of this iteration
-  * \param curr_class Current class for multiclass training
+  * \param cur_tree_id Current tree for multiclass training
   */
-  virtual void UpdateScore(const Tree* tree, const int curr_class);
+  virtual void UpdateScore(const Tree* tree, const int cur_tree_id);
   /*!
   * \brief Print metric result of current iteration
   * \param iter Current iteration
@@ -277,7 +277,7 @@ protected:
   /*! \brief Tree learner, will use this class to learn trees */
   std::unique_ptr<TreeLearner> tree_learner_;
   /*! \brief Objective function */
-  const ObjectiveFunction* object_function_;
+  const ObjectiveFunction* objective_function_;
   /*! \brief Store and update training data's score */
   std::unique_ptr<ScoreUpdater> train_score_updater_;
   /*! \brief Metrics for training data */
@@ -310,13 +310,10 @@ protected:
   std::vector<data_size_t> tmp_indices_;
   /*! \brief Number of training data */
   data_size_t num_data_;
-  /*! \brief Number of classes */
+  /*! \brief Number of trees per iteration */
+  int num_tree_per_iteration_;
+  /*! \brief Number of classes */
   int num_class_;
-  /*!
-  * \brief Sigmoid parameter, used for prediction.
-  * if > 0 means output score will transform by sigmoid function
-  */
-  double sigmoid_;
   /*! \brief Index of label column */
   data_size_t label_idx_;
   /*! \brief number of used models */
@@ -346,6 +343,7 @@ protected:
   std::vector<bool> class_need_train_;
   std::vector<double> class_default_output_;
   bool is_constant_hessian_;
+  std::unique_ptr<ObjectiveFunction> loaded_objective_;
 };
 } // namespace LightGBM
...
@@ -32,15 +32,15 @@ public:
   }
   ~GOSS() {
     #ifdef TIMETAG
     Log::Info("GOSS::subset costs %f", subset_time * 1e-3);
     Log::Info("GOSS::re_init_tree costs %f", re_init_tree_time * 1e-3);
     #endif
   }

-  void Init(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* object_function,
+  void Init(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* objective_function,
            const std::vector<const Metric*>& training_metrics) override {
-    GBDT::Init(config, train_data, object_function, training_metrics);
+    GBDT::Init(config, train_data, objective_function, training_metrics);
     CHECK(gbdt_config_->top_rate + gbdt_config_->other_rate <= 1.0f);
     CHECK(gbdt_config_->top_rate > 0.0f && gbdt_config_->other_rate > 0.0f);
     if (gbdt_config_->bagging_freq > 0 && gbdt_config_->bagging_fraction != 1.0f) {
@@ -49,12 +49,12 @@ public:
     Log::Info("using GOSS");
   }

-  void ResetTrainingData(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* object_function,
+  void ResetTrainingData(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* objective_function,
                          const std::vector<const Metric*>& training_metrics) override {
     if (config->bagging_freq > 0 && config->bagging_fraction != 1.0f) {
       Log::Fatal("cannot use bagging in GOSS");
     }
-    GBDT::ResetTrainingData(config, train_data, object_function, training_metrics);
+    GBDT::ResetTrainingData(config, train_data, objective_function, training_metrics);
     if (train_data_ == nullptr) { return; }
     bag_data_indices_.resize(num_data_);
     tmp_indices_.resize(num_data_);
@@ -79,8 +79,8 @@ public:
   data_size_t BaggingHelper(Random& cur_rand, data_size_t start, data_size_t cnt, data_size_t* buffer, data_size_t* buffer_right) {
     std::vector<score_t> tmp_gradients(cnt, 0.0f);
     for (data_size_t i = 0; i < cnt; ++i) {
-      for (int curr_class = 0; curr_class < num_class_; ++curr_class) {
-        int idx = curr_class * num_data_ + start + i;
+      for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
+        int idx = cur_tree_id * num_data_ + start + i;
         tmp_gradients[i] += std::fabs(gradients_[idx] * hessians_[idx]);
       }
     }
@@ -96,8 +96,8 @@ public:
     data_size_t big_weight_cnt = 0;
     for (data_size_t i = 0; i < cnt; ++i) {
       score_t grad = 0.0f;
-      for (int curr_class = 0; curr_class < num_class_; ++curr_class) {
-        int idx = curr_class * num_data_ + start + i;
+      for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
+        int idx = cur_tree_id * num_data_ + start + i;
         grad += std::fabs(gradients_[idx] * hessians_[idx]);
       }
       if (grad >= threshold) {
@@ -110,8 +110,8 @@ public:
        double prob = (rest_need) / static_cast<double>(rest_all);
        if (cur_rand.NextFloat() < prob) {
          buffer[cur_left_cnt++] = start + i;
-          for (int curr_class = 0; curr_class < num_class_; ++curr_class) {
-            int idx = curr_class * num_data_ + start + i;
+          for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
+            int idx = cur_tree_id * num_data_ + start + i;
            gradients_[idx] *= multiply;
            hessians_[idx] *= multiply;
          }
@@ -132,7 +132,7 @@ public:
     data_size_t inner_size = (num_data_ + num_threads_ - 1) / num_threads_;
     if (inner_size < min_inner_size) { inner_size = min_inner_size; }
     OMP_INIT_EX();
     #pragma omp parallel for schedule(static, 1)
     for (int i = 0; i < num_threads_; ++i) {
       OMP_LOOP_EX_BEGIN();
       left_cnts_buf_[i] = 0;
@@ -159,7 +159,7 @@ public:
     }
     left_cnt = left_write_pos_buf_[num_threads_ - 1] + left_cnts_buf_[num_threads_ - 1];
     #pragma omp parallel for schedule(static, 1)
     for (int i = 0; i < num_threads_; ++i) {
       OMP_LOOP_EX_BEGIN();
       if (left_cnts_buf_[i] > 0) {
@@ -179,21 +179,21 @@ public:
       tree_learner_->SetBaggingData(bag_data_indices_.data(), bag_data_cnt_);
     } else {
       // get subset
       #ifdef TIMETAG
       auto start_time = std::chrono::steady_clock::now();
       #endif
       tmp_subset_->ReSize(bag_data_cnt_);
       tmp_subset_->CopySubset(train_data_, bag_data_indices_.data(), bag_data_cnt_, false);
       #ifdef TIMETAG
       subset_time += std::chrono::steady_clock::now() - start_time;
       #endif
       #ifdef TIMETAG
       start_time = std::chrono::steady_clock::now();
       #endif
       tree_learner_->ResetTrainingData(tmp_subset_.get());
       #ifdef TIMETAG
       re_init_tree_time += std::chrono::steady_clock::now() - start_time;
       #endif
     }
   }
...
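For context, ```BaggingHelper``` implements GOSS sampling: rank rows by |gradient × hessian| summed over the per-iteration trees, always keep the large-gradient rows, sample the remainder with probability ```rest_need / rest_all```, and rescale sampled rows by ```multiply``` so the gradient sum stays unbiased in expectation. A toy sketch of that compensation, with made-up rates and assuming the usual GOSS factor ```(1 - top_rate) / other_rate``` for ```multiply```:

```cpp
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <vector>

// Toy GOSS-style sampling (illustrative only): rows whose gradient magnitude
// reaches the threshold are always kept; the rest are kept with probability
// other_rate and, when kept, have their gradients amplified so the expected
// total gradient is preserved.
int main() {
  std::vector<double> grad = {5.0, 0.4, 4.0, 0.2, 3.0, 0.25, 0.1, 0.3};
  const double threshold = 1.0, top_rate = 0.25, other_rate = 0.25;
  const double multiply = (1.0 - top_rate) / other_rate;  // = 3.0
  std::vector<int> kept;
  for (size_t i = 0; i < grad.size(); ++i) {
    if (std::fabs(grad[i]) >= threshold) {
      kept.push_back(static_cast<int>(i));  // large gradient: always keep
    } else if (std::rand() / (RAND_MAX + 1.0) < other_rate) {
      grad[i] *= multiply;                  // sampled small gradient: compensate
      kept.push_back(static_cast<int>(i));
    }
  }
  std::printf("kept %zu of %zu rows\n", kept.size(), grad.size());
}
```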
@@ -20,9 +20,9 @@ public:
   * \brief Constructor, will pass a const pointer of dataset
   * \param data This class will bind with this data set
   */
-  ScoreUpdater(const Dataset* data, int num_class) : data_(data) {
+  ScoreUpdater(const Dataset* data, int num_tree_per_iteration) : data_(data) {
     num_data_ = data->num_data();
-    int64_t total_size = static_cast<int64_t>(num_data_) * num_class;
+    int64_t total_size = static_cast<int64_t>(num_data_) * num_tree_per_iteration;
     score_.resize(total_size);
     // default start score is zero
     #pragma omp parallel for schedule(static)
@@ -34,7 +34,7 @@ public:
     // if an initial score exists, start from it
     if (init_score != nullptr) {
       if ((data->metadata().num_init_score() % num_data_) != 0
-          || (data->metadata().num_init_score() / num_data_) != num_class) {
+          || (data->metadata().num_init_score() / num_data_) != num_tree_per_iteration) {
         Log::Fatal("number of class for initial score error");
       }
       has_init_score_ = true;
@@ -51,8 +51,8 @@ public:
   inline bool has_init_score() const { return has_init_score_; }

-  inline void AddScore(double val, int curr_class) {
-    int64_t offset = curr_class * num_data_;
+  inline void AddScore(double val, int cur_tree_id) {
+    int64_t offset = cur_tree_id * num_data_;
     #pragma omp parallel for schedule(static)
     for (int64_t i = 0; i < num_data_; ++i) {
       score_[offset + i] += val;
@@ -62,20 +62,20 @@ public:
   * \brief Using tree model to get prediction number, then adding to scores for all data
   *        Note: this function generally will be used on validation data too.
   * \param tree Trained tree model
-  * \param curr_class Current class for multiclass training
+  * \param cur_tree_id Current tree for multiclass training
   */
-  inline void AddScore(const Tree* tree, int curr_class) {
-    tree->AddPredictionToScore(data_, num_data_, score_.data() + curr_class * num_data_);
+  inline void AddScore(const Tree* tree, int cur_tree_id) {
+    tree->AddPredictionToScore(data_, num_data_, score_.data() + cur_tree_id * num_data_);
   }
   /*!
   * \brief Adding prediction score, only used for training data.
   *        The training data is partitioned into tree leaves after training;
   *        based on this, we can get predictions quickly.
   * \param tree_learner
-  * \param curr_class Current class for multiclass training
+  * \param cur_tree_id Current tree for multiclass training
   */
-  inline void AddScore(const TreeLearner* tree_learner, const Tree* tree, int curr_class) {
-    tree_learner->AddPredictionToScore(tree, score_.data() + curr_class * num_data_);
+  inline void AddScore(const TreeLearner* tree_learner, const Tree* tree, int cur_tree_id) {
+    tree_learner->AddPredictionToScore(tree, score_.data() + cur_tree_id * num_data_);
   }
   /*!
   * \brief Using tree model to get prediction number, then adding to scores for parts of data
@@ -83,11 +83,11 @@ public:
   * \param tree Trained tree model
   * \param data_indices Indices of data that will be processed
   * \param data_cnt Number of data that will be processed
-  * \param curr_class Current class for multiclass training
+  * \param cur_tree_id Current tree for multiclass training
   */
   inline void AddScore(const Tree* tree, const data_size_t* data_indices,
-                      data_size_t data_cnt, int curr_class) {
-    tree->AddPredictionToScore(data_, data_indices, data_cnt, score_.data() + curr_class * num_data_);
+                      data_size_t data_cnt, int cur_tree_id) {
+    tree->AddPredictionToScore(data_, data_indices, data_cnt, score_.data() + cur_tree_id * num_data_);
   }
   /*! \brief Pointer of score */
   inline const double* score() const { return score_.data(); }
...
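The score buffer shares the same tree-major layout as the gradient buffers above: one contiguous block of ```num_data``` scores per tree of the iteration. A minimal sketch of the offset arithmetic:

```cpp
#include <cstdio>
#include <vector>

int main() {
  // Mimics the ScoreUpdater layout: num_tree_per_iteration contiguous blocks
  // of num_data scores, so tree k's score for row i lives at k * num_data + i
  // -- the same offset arithmetic AddScore uses above.
  const int num_data = 4, num_tree_per_iteration = 3;
  std::vector<double> score(num_data * num_tree_per_iteration, 0.0);
  int cur_tree_id = 2, row = 1;
  score[cur_tree_id * num_data + row] += 0.5;
  std::printf("score of row 1 under tree 2: %f\n", score[2 * 4 + 1]);
}
```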
@@ -153,18 +153,11 @@ void OverallConfig::CheckParamConflict() {
   if (boosting_config.is_provide_training_metric || !io_config.valid_data_filenames.empty()) {
     for (std::string metric_type : metric_types) {
       bool metric_type_multiclass = (metric_type == std::string("multi_logloss")
-        || metric_type == std::string("multi_error")
-        || metric_type == std::string("multi_loglossova"));
+        || metric_type == std::string("multi_error"));
       if ((objective_type_multiclass && !metric_type_multiclass)
         || (!objective_type_multiclass && metric_type_multiclass)) {
         Log::Fatal("Objective and metrics don't match");
       }
-      if (objective_type == std::string("multiclassova") && metric_type == std::string("multi_logloss")) {
-        Log::Fatal("Wrong metric. For Multi-class with OVA, you should use multi_loglossova metric.");
-      }
-      if (objective_type == std::string("multiclass") && metric_type == std::string("multi_loglossova")) {
-        Log::Fatal("Wrong metric. For Multi-class with softmax, you should use multi_logloss metric.");
-      }
     }
   }
...
@@ -18,11 +18,8 @@ namespace LightGBM {
 template<typename PointWiseLossCalculator>
 class BinaryMetric: public Metric {
 public:
-  explicit BinaryMetric(const MetricConfig& config) {
-    sigmoid_ = static_cast<double>(config.sigmoid);
-    if (sigmoid_ <= 0.0f) {
-      Log::Fatal("Sigmoid parameter %f should greater than zero", sigmoid_);
-    }
+  explicit BinaryMetric(const MetricConfig&) {
   }

   virtual ~BinaryMetric() {
@@ -57,13 +54,28 @@ public:
     return -1.0f;
   }

-  std::vector<double> Eval(const double* score) const override {
+  std::vector<double> Eval(const double* score, const ObjectiveFunction* objective,
+                           int) const override {
     double sum_loss = 0.0f;
+    if (objective == nullptr) {
       if (weights_ == nullptr) {
         #pragma omp parallel for schedule(static) reduction(+:sum_loss)
         for (data_size_t i = 0; i < num_data_; ++i) {
-          // sigmoid transform
-          double prob = 1.0f / (1.0f + std::exp(-sigmoid_ * score[i]));
+          // add loss
+          sum_loss += PointWiseLossCalculator::LossOnPoint(label_[i], score[i]);
+        }
+      } else {
+        #pragma omp parallel for schedule(static) reduction(+:sum_loss)
+        for (data_size_t i = 0; i < num_data_; ++i) {
+          // add loss
+          sum_loss += PointWiseLossCalculator::LossOnPoint(label_[i], score[i]) * weights_[i];
+        }
+      }
+    } else {
+      if (weights_ == nullptr) {
+        #pragma omp parallel for schedule(static) reduction(+:sum_loss)
+        for (data_size_t i = 0; i < num_data_; ++i) {
+          double prob = objective->ConvertOutput(score[i]);
           // add loss
           sum_loss += PointWiseLossCalculator::LossOnPoint(label_[i], prob);
         }
@@ -71,11 +83,12 @@ public:
         #pragma omp parallel for schedule(static) reduction(+:sum_loss)
         for (data_size_t i = 0; i < num_data_; ++i) {
           // sigmoid transform
-          double prob = 1.0f / (1.0f + std::exp(-sigmoid_ * score[i]));
+          double prob = objective->ConvertOutput(score[i]);
           // add loss
           sum_loss += PointWiseLossCalculator::LossOnPoint(label_[i], prob) * weights_[i];
         }
       }
+    }
     double loss = sum_loss / sum_weights_;
     return std::vector<double>(1, loss);
   }
@@ -91,8 +104,6 @@ private:
   double sum_weights_;
   /*! \brief Name of test set */
   std::vector<std::string> name_;
-  /*! \brief Sigmoid parameter */
-  double sigmoid_;
 };
 /*!
@@ -178,7 +189,8 @@ public:
     }
   }

-  std::vector<double> Eval(const double* score) const override {
+  std::vector<double> Eval(const double* score, const ObjectiveFunction*,
+                           int) const override {
     // get indices sorted by score, descending order
     std::vector<data_size_t> sorted_idx;
     for (data_size_t i = 0; i < num_data_; ++i) {
...
@@ -21,8 +21,8 @@ public:
     eval_at_.push_back(static_cast<data_size_t>(k));
   }
   // get number of threads
   #pragma omp parallel
   #pragma omp master
   {
     num_threads_ = omp_get_num_threads();
   }
@@ -93,7 +93,8 @@ public:
       cur_left = cur_k;
     }
   }

-  std::vector<double> Eval(const double* score) const override {
+  std::vector<double> Eval(const double* score, const ObjectiveFunction*,
+                           int) const override {
     // some buffers for multi-threading sum up
     std::vector<std::vector<double>> result_buffer_;
     for (int i = 0; i < num_threads_; ++i) {
@@ -101,7 +102,7 @@ public:
     }
     std::vector<double> tmp_map(eval_at_.size(), 0.0f);
     if (query_weights_ == nullptr) {
       #pragma omp parallel for schedule(guided) firstprivate(tmp_map)
       for (data_size_t i = 0; i < num_queries_; ++i) {
         const int tid = omp_get_thread_num();
         CalMapAtK(eval_at_, label_ + query_boundaries_[i],
@@ -111,7 +112,7 @@ public:
         }
       }
     } else {
       #pragma omp parallel for schedule(guided) firstprivate(tmp_map)
       for (data_size_t i = 0; i < num_queries_; ++i) {
         const int tid = omp_get_thread_num();
         CalMapAtK(eval_at_, label_ + query_boundaries_[i],
...
@@ -30,8 +30,6 @@ Metric* Metric::CreateMetric(const std::string& type, const MetricConfig& config)
     return new MapMetric(config);
   } else if (type == std::string("multi_logloss")) {
     return new MultiSoftmaxLoglossMetric(config);
-  } else if (type == std::string("multi_loglossova")) {
-    return new MultiOVALoglossMetric(config);
   } else if (type == std::string("multi_error")) {
     return new MultiErrorMetric(config);
   }
...
@@ -15,8 +15,8 @@ namespace LightGBM {
 template<typename PointWiseLossCalculator>
 class MulticlassMetric: public Metric {
 public:
-  explicit MulticlassMetric(const MetricConfig& config) {
-    num_class_ = config.num_class;
+  explicit MulticlassMetric(const MetricConfig&) {
   }

   virtual ~MulticlassMetric() {
@@ -49,13 +49,41 @@ public:
     return -1.0f;
   }

-  std::vector<double> Eval(const double* score) const override {
+  std::vector<double> Eval(const double* score, const ObjectiveFunction* objective,
+                           int num_tree_per_iteration) const override {
     double sum_loss = 0.0;
+    if (objective != nullptr) {
+      if (weights_ == nullptr) {
+        #pragma omp parallel for schedule(static) reduction(+:sum_loss)
+        for (data_size_t i = 0; i < num_data_; ++i) {
+          std::vector<double> rec(num_tree_per_iteration);
+          for (int k = 0; k < num_tree_per_iteration; ++k) {
+            size_t idx = static_cast<size_t>(num_data_) * k + i;
+            rec[k] = static_cast<double>(score[idx]);
+          }
+          rec = objective->ConvertOutput(rec);
+          // add loss
+          sum_loss += PointWiseLossCalculator::LossOnPoint(label_[i], rec);
+        }
+      } else {
+        #pragma omp parallel for schedule(static) reduction(+:sum_loss)
+        for (data_size_t i = 0; i < num_data_; ++i) {
+          std::vector<double> rec(num_tree_per_iteration);
+          for (int k = 0; k < num_tree_per_iteration; ++k) {
+            size_t idx = static_cast<size_t>(num_data_) * k + i;
+            rec[k] = static_cast<double>(score[idx]);
+          }
+          rec = objective->ConvertOutput(rec);
+          // add loss
+          sum_loss += PointWiseLossCalculator::LossOnPoint(label_[i], rec) * weights_[i];
+        }
+      }
+    } else {
       if (weights_ == nullptr) {
         #pragma omp parallel for schedule(static) reduction(+:sum_loss)
         for (data_size_t i = 0; i < num_data_; ++i) {
-          std::vector<double> rec(num_class_);
-          for (int k = 0; k < num_class_; ++k) {
+          std::vector<double> rec(num_tree_per_iteration);
+          for (int k = 0; k < num_tree_per_iteration; ++k) {
             size_t idx = static_cast<size_t>(num_data_) * k + i;
             rec[k] = static_cast<double>(score[idx]);
           }
@@ -65,8 +93,8 @@ public:
       } else {
         #pragma omp parallel for schedule(static) reduction(+:sum_loss)
         for (data_size_t i = 0; i < num_data_; ++i) {
-          std::vector<double> rec(num_class_);
-          for (int k = 0; k < num_class_; ++k) {
+          std::vector<double> rec(num_tree_per_iteration);
+          for (int k = 0; k < num_tree_per_iteration; ++k) {
             size_t idx = static_cast<size_t>(num_data_) * k + i;
             rec[k] = static_cast<double>(score[idx]);
           }
@@ -74,6 +102,7 @@ public:
           sum_loss += PointWiseLossCalculator::LossOnPoint(label_[i], rec) * weights_[i];
         }
       }
+    }
     double loss = sum_loss / sum_weights_;
     return std::vector<double>(1, loss);
   }
@@ -81,8 +110,6 @@ public:
 private:
   /*! \brief Number of data */
   data_size_t num_data_;
-  /*! \brief Number of classes */
-  int num_class_;
   /*! \brief Pointer of label */
   const float* label_;
   /*! \brief Pointer of weights */
@@ -100,7 +127,7 @@ public:
   inline static double LossOnPoint(float label, std::vector<double>& score) {
     size_t k = static_cast<size_t>(label);
-    for (size_t i = 0; i < score.size(); ++i){
+    for (size_t i = 0; i < score.size(); ++i) {
       if (i != k && score[i] >= score[k]) {
         return 1.0f;
       }
@@ -120,7 +147,6 @@ public:
   inline static double LossOnPoint(float label, std::vector<double>& score) {
     size_t k = static_cast<size_t>(label);
-    Common::Softmax(&score);
     if (score[k] > kEpsilon) {
       return static_cast<double>(-std::log(score[k]));
     } else {
@@ -133,83 +159,5 @@ public:
     }
   };

-class MultiOVALoglossMetric: public Metric {
-public:
-  explicit MultiOVALoglossMetric(const MetricConfig& config) {
-    num_class_ = config.num_class;
-    sigmoid_ = config.sigmoid;
-  }
-  virtual ~MultiOVALoglossMetric() {
-  }
-  void Init(const Metadata& metadata, data_size_t num_data) override {
-    name_.emplace_back("multi_loglossova");
-    num_data_ = num_data;
-    // get label
-    label_ = metadata.label();
-    // get weights
-    weights_ = metadata.weights();
-    if (weights_ == nullptr) {
-      sum_weights_ = static_cast<double>(num_data_);
-    } else {
-      sum_weights_ = 0.0f;
-      for (data_size_t i = 0; i < num_data_; ++i) {
-        sum_weights_ += weights_[i];
-      }
-    }
-  }
-  const std::vector<std::string>& GetName() const override {
-    return name_;
-  }
-  double factor_to_bigger_better() const override {
-    return -1.0f;
-  }
-  std::vector<double> Eval(const double* score) const override {
-    double sum_loss = 0.0;
-    if (weights_ == nullptr) {
-      #pragma omp parallel for schedule(static) reduction(+:sum_loss)
-      for (data_size_t i = 0; i < num_data_; ++i) {
-        size_t idx = static_cast<size_t>(num_data_) * static_cast<int>(label_[i]) + i;
-        double prob = 1.0f / (1.0f + std::exp(-sigmoid_ * score[idx]));
-        if (prob < kEpsilon) { prob = kEpsilon; }
-        // add loss
-        sum_loss += -std::log(prob);
-      }
-    } else {
-      #pragma omp parallel for schedule(static) reduction(+:sum_loss)
-      for (data_size_t i = 0; i < num_data_; ++i) {
-        size_t idx = static_cast<size_t>(num_data_) * static_cast<int>(label_[i]) + i;
-        double prob = 1.0f / (1.0f + std::exp(-sigmoid_ * score[idx]));
-        if (prob < kEpsilon) { prob = kEpsilon; }
-        // add loss
-        sum_loss += -std::log(prob) * weights_[i];
-      }
-    }
-    double loss = sum_loss / sum_weights_;
-    return std::vector<double>(1, loss);
-  }
-
-private:
-  /*! \brief Number of data */
-  data_size_t num_data_;
-  /*! \brief Number of classes */
-  int num_class_;
-  /*! \brief Pointer of label */
-  const float* label_;
-  /*! \brief Pointer of weights */
-  const float* weights_;
-  /*! \brief Sum weights */
-  double sum_weights_;
-  /*! \brief Name of this test set */
-  std::vector<std::string> name_;
-  double sigmoid_;
-};
-
 } // namespace LightGBM
 #endif // LightGBM_METRIC_MULTICLASS_METRIC_HPP_
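Putting the metric and objective halves together: for ```multi_logloss```, the metric gathers one row's raw scores with stride ```num_data```, the objective's ```ConvertOutput``` applies softmax, and the metric merely takes ```-log(p[label])```. A toy end-to-end illustration:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Toy data, illustrating the new division of labour between metric and
// objective for multi_logloss.
int main() {
  const int num_data = 2, num_class = 3;
  // score[k * num_data + i] is the raw score of row i for class k.
  std::vector<double> score = {1.0, 0.2,   // class 0
                               2.0, 0.1,   // class 1
                               0.5, 1.5};  // class 2
  const int i = 0, label = 1;
  std::vector<double> rec(num_class);
  for (int k = 0; k < num_class; ++k) rec[k] = score[k * num_data + i];
  // Softmax, as MulticlassSoftmax::ConvertOutput does via Common::Softmax.
  double wmax = *std::max_element(rec.begin(), rec.end());
  double wsum = 0.0;
  for (double& v : rec) { v = std::exp(v - wmax); wsum += v; }
  for (double& v : rec) v /= wsum;
  std::printf("multi_logloss contribution of row 0: %f\n", -std::log(rec[label]));
}
```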
@@ -24,8 +24,8 @@ public:
   // initialize DCG calculator
   DCGCalculator::Init(config.label_gain);
   // get number of threads
   #pragma omp parallel
   #pragma omp master
   {
     num_threads_ = omp_get_num_threads();
   }
@@ -82,7 +82,7 @@ public:
     return 1.0f;
   }

-  std::vector<double> Eval(const double* score) const override {
+  std::vector<double> Eval(const double* score, const ObjectiveFunction*, int) const override {
     // some buffers for multi-threading sum up
     std::vector<std::vector<double>> result_buffer_;
     for (int i = 0; i < num_threads_; ++i) {
@@ -90,7 +90,7 @@ public:
     }
     std::vector<double> tmp_dcg(eval_at_.size(), 0.0f);
     if (query_weights_ == nullptr) {
       #pragma omp parallel for schedule(static) firstprivate(tmp_dcg)
       for (data_size_t i = 0; i < num_queries_; ++i) {
         const int tid = omp_get_thread_num();
         // if all docs in this query are negative, let its NDCG = 1
@@ -110,7 +110,7 @@ public:
       }
     } else {
       #pragma omp parallel for schedule(static) firstprivate(tmp_dcg)
       for (data_size_t i = 0; i < num_queries_; ++i) {
         const int tid = omp_get_thread_num();
         // if all docs in this query are negative, let its NDCG = 1
...
@@ -48,21 +48,37 @@ public:
     }
   }

-  std::vector<double> Eval(const double* score) const override {
+  std::vector<double> Eval(const double* score, const ObjectiveFunction* objective, int) const override {
     double sum_loss = 0.0f;
+    if (objective == nullptr) {
       if (weights_ == nullptr) {
         #pragma omp parallel for schedule(static) reduction(+:sum_loss)
         for (data_size_t i = 0; i < num_data_; ++i) {
           // add loss
           sum_loss += PointWiseLossCalculator::LossOnPoint(label_[i], score[i], huber_delta_, fair_c_);
         }
       } else {
         #pragma omp parallel for schedule(static) reduction(+:sum_loss)
         for (data_size_t i = 0; i < num_data_; ++i) {
           // add loss
           sum_loss += PointWiseLossCalculator::LossOnPoint(label_[i], score[i], huber_delta_, fair_c_) * weights_[i];
         }
       }
+    } else {
+      if (weights_ == nullptr) {
+        #pragma omp parallel for schedule(static) reduction(+:sum_loss)
+        for (data_size_t i = 0; i < num_data_; ++i) {
+          // add loss
+          sum_loss += PointWiseLossCalculator::LossOnPoint(label_[i], objective->ConvertOutput(score[i]), huber_delta_, fair_c_);
+        }
+      } else {
+        #pragma omp parallel for schedule(static) reduction(+:sum_loss)
+        for (data_size_t i = 0; i < num_data_; ++i) {
+          // add loss
+          sum_loss += PointWiseLossCalculator::LossOnPoint(label_[i], objective->ConvertOutput(score[i]), huber_delta_, fair_c_) * weights_[i];
+        }
+      }
+    }
     double loss = PointWiseLossCalculator::AverageLoss(sum_loss, sum_weights_);
     return std::vector<double>(1, loss);
...
@@ -25,6 +25,21 @@ public:
     }
   }

+  explicit BinaryLogloss(const std::vector<std::string>& strs) {
+    sigmoid_ = -1;
+    for (auto str : strs) {
+      auto tokens = Common::Split(str.c_str(), ":");
+      if (tokens.size() == 2) {
+        if (tokens[0] == std::string("sigmoid")) {
+          Common::Atof(tokens[1].c_str(), &sigmoid_);
+        }
+      }
+    }
+    if (sigmoid_ <= 0.0) {
+      Log::Fatal("Sigmoid parameter %f should be greater than zero", sigmoid_);
+    }
+  }
+
   ~BinaryLogloss() {}

   void Init(const Metadata& metadata, data_size_t num_data) override {
@@ -101,6 +116,24 @@ public:
     return "binary";
   }

+  std::vector<double> ConvertOutput(std::vector<double>& input) const override {
+    input[0] = 1.0f / (1.0f + std::exp(-sigmoid_ * input[0]));
+    return input;
+  }
+
+  double ConvertOutput(double input) const override {
+    return 1.0f / (1.0f + std::exp(-sigmoid_ * input));
+  }
+
+  std::string ToString() const override {
+    std::stringstream str_buf;
+    str_buf << GetName() << " ";
+    str_buf << "sigmoid:" << sigmoid_;
+    return str_buf.str();
+  }
+
+  bool SkipEmptyClass() const override { return true; }
+
 private:
   /*! \brief Number of data */
   data_size_t num_data_;
...
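The ```ToString()``` / string-constructor pair is what lets a saved model carry its objective: GBDT serializes the objective into the model string and later restores it through the new ```CreateObjectiveFunction(const std::string&)``` overload into ```loaded_objective_```. A standalone mock of the ```name key:value``` convention (using the standard library in place of ```Common::Split```/```Common::Atof```):

```cpp
#include <cstdio>
#include <sstream>
#include <string>

// Standalone mock of the "binary sigmoid:1" round trip; the real code uses
// Common::Split and Common::Atof instead of std::istringstream/std::stod.
int main() {
  std::string serialized = "binary sigmoid:1";
  std::istringstream iss(serialized);
  std::string name, token;
  iss >> name;                       // objective name, e.g. "binary"
  double sigmoid = -1;
  while (iss >> token) {
    auto pos = token.find(':');
    if (pos != std::string::npos && token.substr(0, pos) == "sigmoid") {
      sigmoid = std::stod(token.substr(pos + 1));
    }
  }
  std::printf("%s objective, sigmoid=%f\n", name.c_str(), sigmoid);
}
```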
@@ -20,6 +20,21 @@ public:
     softmax_weight_decay_ = 1e-3;
   }

+  explicit MulticlassSoftmax(const std::vector<std::string>& strs) {
+    num_class_ = -1;
+    for (auto str : strs) {
+      auto tokens = Common::Split(str.c_str(), ":");
+      if (tokens.size() == 2) {
+        if (tokens[0] == std::string("num_class")) {
+          Common::Atoi(tokens[1].c_str(), &num_class_);
+        }
+      }
+    }
+    if (num_class_ < 0) {
+      Log::Fatal("Objective should contain the num_class field");
+    }
+  }
+
   ~MulticlassSoftmax() {
   }
@@ -98,10 +113,26 @@ public:
     }
   }

+  std::vector<double> ConvertOutput(std::vector<double>& input) const override {
+    Common::Softmax(input.data(), num_class_);
+    return input;
+  }
+
   const char* GetName() const override {
     return "multiclass";
   }

+  std::string ToString() const override {
+    std::stringstream str_buf;
+    str_buf << GetName() << " ";
+    str_buf << "num_class:" << num_class_;
+    return str_buf.str();
+  }
+
+  bool SkipEmptyClass() const override { return true; }
+
+  int numTreePerIteration() const override { return num_class_; }
+
 private:
   /*! \brief Number of data */
   data_size_t num_data_;
@@ -129,6 +160,28 @@ public:
       binary_loss_.emplace_back(
         new BinaryLogloss(config, [i](float label) { return static_cast<int>(label) == i; }));
     }
+    sigmoid_ = config.sigmoid;
   }

+  explicit MulticlassOVA(const std::vector<std::string>& strs) {
+    num_class_ = -1;
+    sigmoid_ = -1;
+    for (auto str : strs) {
+      auto tokens = Common::Split(str.c_str(), ":");
+      if (tokens.size() == 2) {
+        if (tokens[0] == std::string("num_class")) {
+          Common::Atoi(tokens[1].c_str(), &num_class_);
+        } else if (tokens[0] == std::string("sigmoid")) {
+          Common::Atof(tokens[1].c_str(), &sigmoid_);
+        }
+      }
+    }
+    if (num_class_ < 0) {
+      Log::Fatal("Objective should contain the num_class field");
+    }
+    if (sigmoid_ <= 0.0) {
+      Log::Fatal("Sigmoid parameter %f should be greater than zero", sigmoid_);
+    }
+  }
+
   ~MulticlassOVA() {
@@ -153,12 +206,32 @@ public:
     return "multiclassova";
   }

+  std::vector<double> ConvertOutput(std::vector<double>& input) const override {
+    for (int i = 0; i < num_class_; ++i) {
+      input[i] = 1.0f / (1.0f + std::exp(-sigmoid_ * input[i]));
+    }
+    return input;
+  }
+
+  std::string ToString() const override {
+    std::stringstream str_buf;
+    str_buf << GetName() << " ";
+    str_buf << "num_class:" << num_class_ << " ";
+    str_buf << "sigmoid:" << sigmoid_;
+    return str_buf.str();
+  }
+
+  bool SkipEmptyClass() const override { return true; }
+
+  int numTreePerIteration() const override { return num_class_; }
+
 private:
   /*! \brief Number of data */
   data_size_t num_data_;
   /*! \brief Number of classes */
   int num_class_;
   std::vector<std::unique_ptr<BinaryLogloss>> binary_loss_;
+  double sigmoid_;
 };
 } // namespace LightGBM
...
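Unlike the softmax objective, OVA's ```ConvertOutput``` applies an independent sigmoid per class, so the converted outputs are per-class one-vs-rest probabilities and need not sum to one. A quick standalone check with ```sigmoid_ = 1```:

```cpp
#include <cmath>
#include <cstdio>

int main() {
  // MulticlassOVA-style conversion: one independent sigmoid per class.
  double raw[3] = {2.0, -1.0, 0.5};
  double sum = 0.0;
  for (double& v : raw) { v = 1.0 / (1.0 + std::exp(-v)); sum += v; }
  std::printf("probs: %.3f %.3f %.3f (sum %.3f)\n", raw[0], raw[1], raw[2], sum);
}
```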