Commit 82c27d42 authored by Guolin Ke's avatar Guolin Ke
Browse files

Fix prediction for one-vs-all (OVA) multi-class classification.

parent 512e0c34
...@@ -40,15 +40,15 @@ GBDT::GBDT() ...@@ -40,15 +40,15 @@ GBDT::GBDT()
shrinkage_rate_(0.1f), shrinkage_rate_(0.1f),
num_init_iteration_(0), num_init_iteration_(0),
boost_from_average_(false) { boost_from_average_(false) {
#pragma omp parallel #pragma omp parallel
#pragma omp master #pragma omp master
{ {
num_threads_ = omp_get_num_threads(); num_threads_ = omp_get_num_threads();
} }
} }
GBDT::~GBDT() { GBDT::~GBDT() {
#ifdef TIMETAG #ifdef TIMETAG
Log::Info("GBDT::boosting costs %f", boosting_time * 1e-3); Log::Info("GBDT::boosting costs %f", boosting_time * 1e-3);
Log::Info("GBDT::train_score costs %f", train_score_time * 1e-3); Log::Info("GBDT::train_score costs %f", train_score_time * 1e-3);
Log::Info("GBDT::out_of_bag_score costs %f", out_of_bag_score_time * 1e-3); Log::Info("GBDT::out_of_bag_score costs %f", out_of_bag_score_time * 1e-3);
...@@ -57,7 +57,7 @@ GBDT::~GBDT() { ...@@ -57,7 +57,7 @@ GBDT::~GBDT() {
Log::Info("GBDT::bagging costs %f", bagging_time * 1e-3); Log::Info("GBDT::bagging costs %f", bagging_time * 1e-3);
Log::Info("GBDT::sub_gradient costs %f", sub_gradient_time * 1e-3); Log::Info("GBDT::sub_gradient costs %f", sub_gradient_time * 1e-3);
Log::Info("GBDT::tree costs %f", tree_time * 1e-3); Log::Info("GBDT::tree costs %f", tree_time * 1e-3);
#endif #endif
} }
void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* object_function, void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* object_function,
...@@ -85,7 +85,8 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_ ...@@ -85,7 +85,8 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_
sigmoid_ = -1.0f; sigmoid_ = -1.0f;
if (object_function_ != nullptr if (object_function_ != nullptr
&& std::string(object_function_->GetName()) == std::string("binary")) { && (std::string(object_function_->GetName()) == std::string("binary")
|| std::string(object_function_->GetName()) == std::string("multiclassova"))) {
// only binary classification need sigmoid transform // only binary classification need sigmoid transform
sigmoid_ = new_config->sigmoid; sigmoid_ = new_config->sigmoid;
} }
...@@ -174,7 +175,6 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_ ...@@ -174,7 +175,6 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_
if (num_class_ > 1) { if (num_class_ > 1) {
for (int i = 0; i < num_class_; ++i) { for (int i = 0; i < num_class_; ++i) {
if (cnt_per_class[i] == num_data_) { if (cnt_per_class[i] == num_data_) {
Log::Warning("Only contain one class.");
class_need_train_[i] = false; class_need_train_[i] = false;
class_default_output_[i] = -std::log(kEpsilon); class_default_output_[i] = -std::log(kEpsilon);
} else if (cnt_per_class[i] == 0) { } else if (cnt_per_class[i] == 0) {
...@@ -317,16 +317,16 @@ void GBDT::Bagging(int iter) { ...@@ -317,16 +317,16 @@ void GBDT::Bagging(int iter) {
} }
void GBDT::UpdateScoreOutOfBag(const Tree* tree, const int curr_class) { void GBDT::UpdateScoreOutOfBag(const Tree* tree, const int curr_class) {
#ifdef TIMETAG #ifdef TIMETAG
auto start_time = std::chrono::steady_clock::now(); auto start_time = std::chrono::steady_clock::now();
#endif #endif
// we need to predict out-of-bag scores of data for boosting // we need to predict out-of-bag scores of data for boosting
if (num_data_ - bag_data_cnt_ > 0 && !is_use_subset_) { if (num_data_ - bag_data_cnt_ > 0 && !is_use_subset_) {
train_score_updater_->AddScore(tree, bag_data_indices_.data() + bag_data_cnt_, num_data_ - bag_data_cnt_, curr_class); train_score_updater_->AddScore(tree, bag_data_indices_.data() + bag_data_cnt_, num_data_ - bag_data_cnt_, curr_class);
} }
#ifdef TIMETAG #ifdef TIMETAG
out_of_bag_score_time += std::chrono::steady_clock::now() - start_time; out_of_bag_score_time += std::chrono::steady_clock::now() - start_time;
#endif #endif
} }
bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is_eval) { bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is_eval) {
...@@ -364,14 +364,14 @@ bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is ...@@ -364,14 +364,14 @@ bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is
boosting_time += std::chrono::steady_clock::now() - start_time; boosting_time += std::chrono::steady_clock::now() - start_time;
#endif #endif
} }
#ifdef TIMETAG #ifdef TIMETAG
auto start_time = std::chrono::steady_clock::now(); auto start_time = std::chrono::steady_clock::now();
#endif #endif
// bagging logic // bagging logic
Bagging(iter_); Bagging(iter_);
#ifdef TIMETAG #ifdef TIMETAG
bagging_time += std::chrono::steady_clock::now() - start_time; bagging_time += std::chrono::steady_clock::now() - start_time;
#endif #endif
if (is_use_subset_ && bag_data_cnt_ < num_data_) { if (is_use_subset_ && bag_data_cnt_ < num_data_) {
#ifdef TIMETAG #ifdef TIMETAG
start_time = std::chrono::steady_clock::now(); start_time = std::chrono::steady_clock::now();
...@@ -469,14 +469,14 @@ void GBDT::RollbackOneIter() { ...@@ -469,14 +469,14 @@ void GBDT::RollbackOneIter() {
bool GBDT::EvalAndCheckEarlyStopping() { bool GBDT::EvalAndCheckEarlyStopping() {
bool is_met_early_stopping = false; bool is_met_early_stopping = false;
#ifdef TIMETAG #ifdef TIMETAG
auto start_time = std::chrono::steady_clock::now(); auto start_time = std::chrono::steady_clock::now();
#endif #endif
// print message for metric // print message for metric
auto best_msg = OutputMetric(iter_); auto best_msg = OutputMetric(iter_);
#ifdef TIMETAG #ifdef TIMETAG
metric_time += std::chrono::steady_clock::now() - start_time; metric_time += std::chrono::steady_clock::now() - start_time;
#endif #endif
is_met_early_stopping = !best_msg.empty(); is_met_early_stopping = !best_msg.empty();
if (is_met_early_stopping) { if (is_met_early_stopping) {
Log::Info("Early stopping at iteration %d, the best iteration round is %d", Log::Info("Early stopping at iteration %d, the best iteration round is %d",
...@@ -491,28 +491,28 @@ bool GBDT::EvalAndCheckEarlyStopping() { ...@@ -491,28 +491,28 @@ bool GBDT::EvalAndCheckEarlyStopping() {
} }
void GBDT::UpdateScore(const Tree* tree, const int curr_class) { void GBDT::UpdateScore(const Tree* tree, const int curr_class) {
#ifdef TIMETAG #ifdef TIMETAG
auto start_time = std::chrono::steady_clock::now(); auto start_time = std::chrono::steady_clock::now();
#endif #endif
// update training score // update training score
if (!is_use_subset_) { if (!is_use_subset_) {
train_score_updater_->AddScore(tree_learner_.get(), tree, curr_class); train_score_updater_->AddScore(tree_learner_.get(), tree, curr_class);
} else { } else {
train_score_updater_->AddScore(tree, curr_class); train_score_updater_->AddScore(tree, curr_class);
} }
#ifdef TIMETAG #ifdef TIMETAG
train_score_time += std::chrono::steady_clock::now() - start_time; train_score_time += std::chrono::steady_clock::now() - start_time;
#endif #endif
#ifdef TIMETAG #ifdef TIMETAG
start_time = std::chrono::steady_clock::now(); start_time = std::chrono::steady_clock::now();
#endif #endif
// update validation score // update validation score
for (auto& score_updater : valid_score_updater_) { for (auto& score_updater : valid_score_updater_) {
score_updater->AddScore(tree, curr_class); score_updater->AddScore(tree, curr_class);
} }
#ifdef TIMETAG #ifdef TIMETAG
valid_score_time += std::chrono::steady_clock::now() - start_time; valid_score_time += std::chrono::steady_clock::now() - start_time;
#endif #endif
} }
std::string GBDT::OutputMetric(int iter) { std::string GBDT::OutputMetric(int iter) {
...@@ -617,7 +617,15 @@ void GBDT::GetPredictAt(int data_idx, double* out_result, int64_t* out_len) { ...@@ -617,7 +617,15 @@ void GBDT::GetPredictAt(int data_idx, double* out_result, int64_t* out_len) {
num_data = valid_score_updater_[used_idx]->num_data(); num_data = valid_score_updater_[used_idx]->num_data();
*out_len = static_cast<int64_t>(num_data) * num_class_; *out_len = static_cast<int64_t>(num_data) * num_class_;
} }
if (num_class_ > 1) { if (sigmoid_ > 0.0f) {
#pragma omp parallel for schedule(static)
for (data_size_t i = 0; i < num_data; ++i) {
for (int j = 0; j < num_class_; ++j) {
out_result[i + j * num_data_] = static_cast<double>(
1.0f / (1.0f + std::exp(-sigmoid_ * raw_scores[i + j * num_data_])));
}
}
} else if (num_class_ > 1) {
#pragma omp parallel for schedule(static) #pragma omp parallel for schedule(static)
for (data_size_t i = 0; i < num_data; ++i) { for (data_size_t i = 0; i < num_data; ++i) {
std::vector<double> tmp_result(num_class_); std::vector<double> tmp_result(num_class_);
...@@ -629,11 +637,6 @@ void GBDT::GetPredictAt(int data_idx, double* out_result, int64_t* out_len) { ...@@ -629,11 +637,6 @@ void GBDT::GetPredictAt(int data_idx, double* out_result, int64_t* out_len) {
out_result[j * num_data + i] = static_cast<double>(tmp_result[j]); out_result[j * num_data + i] = static_cast<double>(tmp_result[j]);
} }
} }
} else if (sigmoid_ > 0.0f) {
#pragma omp parallel for schedule(static)
for (data_size_t i = 0; i < num_data; ++i) {
out_result[i] = static_cast<double>(1.0f / (1.0f + std::exp(-sigmoid_ * raw_scores[i])));
}
} else { } else {
#pragma omp parallel for schedule(static) #pragma omp parallel for schedule(static)
for (data_size_t i = 0; i < num_data; ++i) { for (data_size_t i = 0; i < num_data; ++i) {
...@@ -855,7 +858,7 @@ std::vector<std::pair<size_t, std::string>> GBDT::FeatureImportance() const { ...@@ -855,7 +858,7 @@ std::vector<std::pair<size_t, std::string>> GBDT::FeatureImportance() const {
} }
// sort the importance // sort the importance
std::sort(pairs.begin(), pairs.end(), std::sort(pairs.begin(), pairs.end(),
[] (const std::pair<size_t, std::string>& lhs, [](const std::pair<size_t, std::string>& lhs,
const std::pair<size_t, std::string>& rhs) { const std::pair<size_t, std::string>& rhs) {
return lhs.first > rhs.first; return lhs.first > rhs.first;
}); });
...@@ -880,8 +883,10 @@ std::vector<double> GBDT::Predict(const double* value) const { ...@@ -880,8 +883,10 @@ std::vector<double> GBDT::Predict(const double* value) const {
} }
} }
// if need sigmoid transform // if need sigmoid transform
if (sigmoid_ > 0 && num_class_ == 1) { if (sigmoid_ > 0) {
ret[0] = 1.0f / (1.0f + std::exp(-sigmoid_ * ret[0])); for (int j = 0; j < num_class_; ++j) {
ret[j] = 1.0f / (1.0f + std::exp(-sigmoid_ * ret[j]));
}
} else if (num_class_ > 1) { } else if (num_class_ > 1) {
Common::Softmax(&ret); Common::Softmax(&ret);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment