Commit d93eb338 authored by Guolin Ke
Browse files

add un_balance to multi class (#314)

* add un_balance to multi class

* change the calculation of label weights
parent 1bf7bbd0
...@@ -14,6 +14,7 @@ class MulticlassLogloss: public ObjectiveFunction { ...@@ -14,6 +14,7 @@ class MulticlassLogloss: public ObjectiveFunction {
public: public:
explicit MulticlassLogloss(const ObjectiveConfig& config) { explicit MulticlassLogloss(const ObjectiveConfig& config) {
num_class_ = config.num_class; num_class_ = config.num_class;
is_unbalance_ = config.is_unbalance;
} }
~MulticlassLogloss() { ~MulticlassLogloss() {
...@@ -24,12 +25,25 @@ public: ...@@ -24,12 +25,25 @@ public:
label_ = metadata.label(); label_ = metadata.label();
weights_ = metadata.weights(); weights_ = metadata.weights();
label_int_.resize(num_data_); label_int_.resize(num_data_);
for (int i = 0; i < num_data_; ++i){ #pragma omp parallel for schedule(static)
for (int i = 0; i < num_data_; ++i) {
label_int_[i] = static_cast<int>(label_[i]); label_int_[i] = static_cast<int>(label_[i]);
if (label_int_[i] < 0 || label_int_[i] >= num_class_) { if (label_int_[i] < 0 || label_int_[i] >= num_class_) {
Log::Fatal("Label must be in [0, %d), but found %d in label", num_class_, label_int_[i]); Log::Fatal("Label must be in [0, %d), but found %d in label", num_class_, label_int_[i]);
} }
} }
label_pos_weights_ = std::vector<float>(num_class_, 1);
if (is_unbalance_) {
std::vector<int> cnts(num_class_, 0);
for (int i = 0; i < num_data_; ++i) {
++cnts[label_int_[i]];
}
for (int i = 0; i < num_class_; ++i) {
int cnt_cur = cnts[i];
int cnt_other = (num_data_ - cnts[i]);
label_pos_weights_[i] = static_cast<float>(cnt_other) / cnt_cur;
}
}
} }
void GetGradients(const double* score, score_t* gradients, score_t* hessians) const override { void GetGradients(const double* score, score_t* gradients, score_t* hessians) const override {
...@@ -46,13 +60,14 @@ public: ...@@ -46,13 +60,14 @@ public:
auto p = rec[k]; auto p = rec[k];
size_t idx = static_cast<size_t>(num_data_) * k + i; size_t idx = static_cast<size_t>(num_data_) * k + i;
if (label_int_[i] == k) { if (label_int_[i] == k) {
gradients[idx] = static_cast<score_t>(p - 1.0f); gradients[idx] = static_cast<score_t>(p - 1.0f) * label_pos_weights_[k];
hessians[idx] = static_cast<score_t>(2.0f * p * (1.0f - p))* label_pos_weights_[k];
} else { } else {
gradients[idx] = static_cast<score_t>(p); gradients[idx] = static_cast<score_t>(p);
}
hessians[idx] = static_cast<score_t>(2.0f * p * (1.0f - p)); hessians[idx] = static_cast<score_t>(2.0f * p * (1.0f - p));
} }
} }
}
} else { } else {
#pragma omp parallel for schedule(static) #pragma omp parallel for schedule(static)
for (data_size_t i = 0; i < num_data_; ++i) { for (data_size_t i = 0; i < num_data_; ++i) {
...@@ -66,12 +81,14 @@ public: ...@@ -66,12 +81,14 @@ public:
auto p = rec[k]; auto p = rec[k];
size_t idx = static_cast<size_t>(num_data_) * k + i; size_t idx = static_cast<size_t>(num_data_) * k + i;
if (label_int_[i] == k) { if (label_int_[i] == k) {
gradients[idx] = static_cast<score_t>((p - 1.0f) * weights_[i]); gradients[idx] = static_cast<score_t>((p - 1.0f) * weights_[i]) * label_pos_weights_[k];
hessians[idx] = static_cast<score_t>(2.0f * p * (1.0f - p) * weights_[i]) * label_pos_weights_[k];
} else { } else {
gradients[idx] = static_cast<score_t>(p * weights_[i]); gradients[idx] = static_cast<score_t>(p * weights_[i]);
}
hessians[idx] = static_cast<score_t>(2.0f * p * (1.0f - p) * weights_[i]); hessians[idx] = static_cast<score_t>(2.0f * p * (1.0f - p) * weights_[i]);
} }
}
} }
} }
} }
...@@ -91,6 +108,9 @@ private: ...@@ -91,6 +108,9 @@ private:
std::vector<int> label_int_; std::vector<int> label_int_;
/*! \brief Weights for data */ /*! \brief Weights for data */
const float* weights_; const float* weights_;
/*! \brief Weights for label */
std::vector<float> label_pos_weights_;
bool is_unbalance_;
}; };
} // namespace LightGBM } // namespace LightGBM
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment