Commit 381a945d authored by Guolin Ke's avatar Guolin Ke
Browse files

fix bug in GetPredictAt for multi class

parent 8aa15e88
...@@ -388,7 +388,6 @@ const score_t* GBDT::GetTrainingScore(data_size_t* out_len) { ...@@ -388,7 +388,6 @@ const score_t* GBDT::GetTrainingScore(data_size_t* out_len) {
void GBDT::GetPredictAt(int data_idx, score_t* out_result, data_size_t* out_len) { void GBDT::GetPredictAt(int data_idx, score_t* out_result, data_size_t* out_len) {
CHECK(data_idx >= 0 && data_idx <= static_cast<int>(valid_metrics_.size())); CHECK(data_idx >= 0 && data_idx <= static_cast<int>(valid_metrics_.size()));
std::vector<double> ret;
const score_t* raw_scores = nullptr; const score_t* raw_scores = nullptr;
data_size_t num_data = 0; data_size_t num_data = 0;
...@@ -404,13 +403,13 @@ void GBDT::GetPredictAt(int data_idx, score_t* out_result, data_size_t* out_len) ...@@ -404,13 +403,13 @@ void GBDT::GetPredictAt(int data_idx, score_t* out_result, data_size_t* out_len)
if (num_class_ > 1) { if (num_class_ > 1) {
#pragma omp parallel for schedule(static) #pragma omp parallel for schedule(static)
for (data_size_t i = 0; i < num_data; ++i) { for (data_size_t i = 0; i < num_data; ++i) {
std::vector<double> tmp_result; std::vector<double> tmp_result(num_class_);
for (int j = 0; j < num_class_; ++j) { for (int j = 0; j < num_class_; ++j) {
tmp_result.push_back(raw_scores[j * num_data + i]); tmp_result[j] = raw_scores[j * num_data + i];
} }
Common::Softmax(&tmp_result); Common::Softmax(&tmp_result);
for (int j = 0; j < num_class_; ++j) { for (int j = 0; j < num_class_; ++j) {
out_result[j * num_data + i] = static_cast<score_t>(tmp_result[i]); out_result[j * num_data + i] = static_cast<score_t>(tmp_result[j]);
} }
} }
} else if(sigmoid_ > 0.0f){ } else if(sigmoid_ > 0.0f){
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment