Unverified Commit 0c4bb89d authored by Guolin Ke's avatar Guolin Ke Committed by GitHub
Browse files

fix bug for one-class binary (#1877)

parent e55c8158
...@@ -65,10 +65,11 @@ public: ...@@ -65,10 +65,11 @@ public:
++cnt_negative; ++cnt_negative;
} }
} }
need_train_ = true;
if (cnt_negative == 0 || cnt_positive == 0) { if (cnt_negative == 0 || cnt_positive == 0) {
Log::Warning("Contains only one class"); Log::Warning("Contains only one class");
// not need to boost. // not need to boost.
num_data_ = 0; need_train_ = false;
} }
Log::Info("Number of positive: %d, number of negative: %d", cnt_positive, cnt_negative); Log::Info("Number of positive: %d, number of negative: %d", cnt_positive, cnt_negative);
// use -1 for negative class, and 1 for positive class // use -1 for negative class, and 1 for positive class
...@@ -91,6 +92,9 @@ public: ...@@ -91,6 +92,9 @@ public:
} }
void GetGradients(const double* score, score_t* gradients, score_t* hessians) const override { void GetGradients(const double* score, score_t* gradients, score_t* hessians) const override {
if (!need_train_) {
return;
}
if (weights_ == nullptr) { if (weights_ == nullptr) {
#pragma omp parallel for schedule(static) #pragma omp parallel for schedule(static)
for (data_size_t i = 0; i < num_data_; ++i) { for (data_size_t i = 0; i < num_data_; ++i) {
...@@ -146,7 +150,7 @@ public: ...@@ -146,7 +150,7 @@ public:
} }
bool ClassNeedTrain(int /*class_id*/) const override { bool ClassNeedTrain(int /*class_id*/) const override {
return num_data_ > 0; return need_train_;
} }
const char* GetName() const override { const char* GetName() const override {
...@@ -185,6 +189,7 @@ private: ...@@ -185,6 +189,7 @@ private:
const label_t* weights_; const label_t* weights_;
double scale_pos_weight_; double scale_pos_weight_;
std::function<bool(label_t)> is_pos_; std::function<bool(label_t)> is_pos_;
bool need_train_;
}; };
} // namespace LightGBM } // namespace LightGBM
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment