Commit 381a945d authored by Guolin Ke's avatar Guolin Ke
Browse files

fix bug in GetPredictAt for multi class

parent 8aa15e88
......@@ -388,7 +388,6 @@ const score_t* GBDT::GetTrainingScore(data_size_t* out_len) {
void GBDT::GetPredictAt(int data_idx, score_t* out_result, data_size_t* out_len) {
CHECK(data_idx >= 0 && data_idx <= static_cast<int>(valid_metrics_.size()));
std::vector<double> ret;
const score_t* raw_scores = nullptr;
data_size_t num_data = 0;
......@@ -404,13 +403,13 @@ void GBDT::GetPredictAt(int data_idx, score_t* out_result, data_size_t* out_len)
if (num_class_ > 1) {
#pragma omp parallel for schedule(static)
for (data_size_t i = 0; i < num_data; ++i) {
std::vector<double> tmp_result;
std::vector<double> tmp_result(num_class_);
for (int j = 0; j < num_class_; ++j) {
tmp_result.push_back(raw_scores[j * num_data + i]);
tmp_result[j] = raw_scores[j * num_data + i];
}
Common::Softmax(&tmp_result);
for (int j = 0; j < num_class_; ++j) {
out_result[j * num_data + i] = static_cast<score_t>(tmp_result[i]);
out_result[j * num_data + i] = static_cast<score_t>(tmp_result[j]);
}
}
} else if(sigmoid_ > 0.0f){
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment