Commit 7890d81f authored by Guolin Ke's avatar Guolin Ke
Browse files

fix goss in multi-classification

parent 83dc54e3
...@@ -70,9 +70,12 @@ public: ...@@ -70,9 +70,12 @@ public:
} }
data_size_t BaggingHelper(Random& cur_rand, data_size_t start, data_size_t cnt, data_size_t* buffer, data_size_t* buffer_right) { data_size_t BaggingHelper(Random& cur_rand, data_size_t start, data_size_t cnt, data_size_t* buffer, data_size_t* buffer_right) {
std::vector<score_t> tmp_gradients(cnt); std::vector<score_t> tmp_gradients(cnt, 0.0f);
for (data_size_t i = 0; i < cnt; ++i) { for (data_size_t i = 0; i < cnt; ++i) {
tmp_gradients[i] = std::fabs(gradients_[start + i] * hessians_[start + i]); for (int curr_class = 0; curr_class < num_class_; ++curr_class) {
int idx = curr_class * num_data_ + start + i;
tmp_gradients[i] += std::fabs(gradients_[idx] * hessians_[idx]);
}
} }
data_size_t top_k = static_cast<data_size_t>(cnt * gbdt_config_->top_rate); data_size_t top_k = static_cast<data_size_t>(cnt * gbdt_config_->top_rate);
data_size_t other_k = static_cast<data_size_t>(cnt * gbdt_config_->other_rate); data_size_t other_k = static_cast<data_size_t>(cnt * gbdt_config_->other_rate);
...@@ -85,7 +88,12 @@ public: ...@@ -85,7 +88,12 @@ public:
data_size_t cur_right_cnt = 0; data_size_t cur_right_cnt = 0;
data_size_t big_weight_cnt = 0; data_size_t big_weight_cnt = 0;
for (data_size_t i = 0; i < cnt; ++i) { for (data_size_t i = 0; i < cnt; ++i) {
if (std::fabs(gradients_[start + i] * hessians_[start + i]) >= threshold) { score_t grad = 0.0f;
for (int curr_class = 0; curr_class < num_class_; ++curr_class) {
int idx = curr_class * num_data_ + start + i;
grad += std::fabs(gradients_[idx] * hessians_[idx]);
}
if (grad >= threshold) {
buffer[cur_left_cnt++] = start + i; buffer[cur_left_cnt++] = start + i;
++big_weight_cnt; ++big_weight_cnt;
} else { } else {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment