Commit 7890d81f authored by Guolin Ke's avatar Guolin Ke
Browse files

fix goss in multi-classification

parent 83dc54e3
......@@ -70,9 +70,12 @@ public:
}
data_size_t BaggingHelper(Random& cur_rand, data_size_t start, data_size_t cnt, data_size_t* buffer, data_size_t* buffer_right) {
std::vector<score_t> tmp_gradients(cnt);
std::vector<score_t> tmp_gradients(cnt, 0.0f);
for (data_size_t i = 0; i < cnt; ++i) {
tmp_gradients[i] = std::fabs(gradients_[start + i] * hessians_[start + i]);
for (int curr_class = 0; curr_class < num_class_; ++curr_class) {
int idx = curr_class * num_data_ + start + i;
tmp_gradients[i] += std::fabs(gradients_[idx] * hessians_[idx]);
}
}
data_size_t top_k = static_cast<data_size_t>(cnt * gbdt_config_->top_rate);
data_size_t other_k = static_cast<data_size_t>(cnt * gbdt_config_->other_rate);
......@@ -85,7 +88,12 @@ public:
data_size_t cur_right_cnt = 0;
data_size_t big_weight_cnt = 0;
for (data_size_t i = 0; i < cnt; ++i) {
if (std::fabs(gradients_[start + i] * hessians_[start + i]) >= threshold) {
score_t grad = 0.0f;
for (int curr_class = 0; curr_class < num_class_; ++curr_class) {
int idx = curr_class * num_data_ + start + i;
grad += std::fabs(gradients_[idx] * hessians_[idx]);
}
if (grad >= threshold) {
buffer[cur_left_cnt++] = start + i;
++big_weight_cnt;
} else {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment