Commit 945bc150 authored by Guolin Ke
Browse files

skip the boosting for the empty class.

parent 82c27d42
...@@ -179,7 +179,11 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_ ...@@ -179,7 +179,11 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_
class_default_output_[i] = -std::log(kEpsilon); class_default_output_[i] = -std::log(kEpsilon);
} else if (cnt_per_class[i] == 0) { } else if (cnt_per_class[i] == 0) {
class_need_train_[i] = false; class_need_train_[i] = false;
if (sigmoid_ > 0.0f) {
class_default_output_[i] = -std::log(1.0f / kEpsilon - 1.0f); class_default_output_[i] = -std::log(1.0f / kEpsilon - 1.0f);
} else {
class_default_output_[i] = std::log(kEpsilon);
}
} }
} }
} else { } else {
......
...@@ -29,14 +29,25 @@ public: ...@@ -29,14 +29,25 @@ public:
label_ = metadata.label(); label_ = metadata.label();
weights_ = metadata.weights(); weights_ = metadata.weights();
label_int_.resize(num_data_); label_int_.resize(num_data_);
#pragma omp parallel for schedule(static) std::vector<data_size_t> cnt_per_class(num_class_, 0);
for (int i = 0; i < num_data_; ++i) { for (int i = 0; i < num_data_; ++i) {
label_int_[i] = static_cast<int>(label_[i]); label_int_[i] = static_cast<int>(label_[i]);
if (label_int_[i] < 0 || label_int_[i] >= num_class_) { if (label_int_[i] < 0 || label_int_[i] >= num_class_) {
Log::Fatal("Label must be in [0, %d), but found %d in label", num_class_, label_int_[i]); Log::Fatal("Label must be in [0, %d), but found %d in label", num_class_, label_int_[i]);
} }
++cnt_per_class[label_int_[i]];
} }
hessian_nor_ = static_cast<score_t>(num_class_) / (num_class_ - 1); int non_empty_class = 0;
is_empty_class_ = std::vector<bool>(num_class_, false);
for (int i = 0; i < num_class_; ++i) {
if (cnt_per_class[i] > 0) {
++non_empty_class;
} else {
is_empty_class_[i] = true;
}
}
if (non_empty_class < 2) { non_empty_class = 2; }
hessian_nor_ = static_cast<score_t>(non_empty_class) / (non_empty_class - 1);
} }
void GetGradients(const double* score, score_t* gradients, score_t* hessians) const override { void GetGradients(const double* score, score_t* gradients, score_t* hessians) const override {
...@@ -51,6 +62,7 @@ public: ...@@ -51,6 +62,7 @@ public:
} }
Common::Softmax(&rec); Common::Softmax(&rec);
for (int k = 0; k < num_class_; ++k) { for (int k = 0; k < num_class_; ++k) {
if (is_empty_class_[k]) { continue; }
auto p = rec[k]; auto p = rec[k];
size_t idx = static_cast<size_t>(num_data_) * k + i; size_t idx = static_cast<size_t>(num_data_) * k + i;
if (label_int_[i] == k) { if (label_int_[i] == k) {
...@@ -72,6 +84,7 @@ public: ...@@ -72,6 +84,7 @@ public:
} }
Common::Softmax(&rec); Common::Softmax(&rec);
for (int k = 0; k < num_class_; ++k) { for (int k = 0; k < num_class_; ++k) {
if (is_empty_class_[k]) { continue; }
auto p = rec[k]; auto p = rec[k];
size_t idx = static_cast<size_t>(num_data_) * k + i; size_t idx = static_cast<size_t>(num_data_) * k + i;
if (label_int_[i] == k) { if (label_int_[i] == k) {
...@@ -100,6 +113,7 @@ private: ...@@ -100,6 +113,7 @@ private:
std::vector<int> label_int_; std::vector<int> label_int_;
/*! \brief Weights for data */ /*! \brief Weights for data */
const float* weights_; const float* weights_;
std::vector<bool> is_empty_class_;
double softmax_weight_decay_; double softmax_weight_decay_;
score_t hessian_nor_; score_t hessian_nor_;
}; };
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment