Commit cccd0587 authored by Guolin Ke

Skip the training of empty classes in classification.

parent e404d7cf
@@ -162,6 +162,37 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_
   if (train_data_ != nullptr) {
     // reset config for tree learner
     tree_learner_->ResetConfig(&new_config->tree_config);
+    class_need_train_ = std::vector<bool>(num_class_, true);
+    if (num_class_ > 1 || sigmoid_ > 0) {
+      // + 1 here for binary classification, where labels take values 0 and 1
+      class_default_output_ = std::vector<double>(num_class_ + 1, 0.0f);
+      std::vector<data_size_t> cnt_per_class(num_class_ + 1, 0);
+      auto label = train_data_->metadata().label();
+      for (int i = 0; i < num_data_; ++i) {
+        ++cnt_per_class[static_cast<int>(label[i])];
+      }
+      if (num_class_ > 1) {
+        for (int i = 0; i < num_class_; ++i) {
+          if (cnt_per_class[i] == num_data_) {
+            Log::Warning("Only contains one class.");
+            class_need_train_[i] = false;
+            class_default_output_[i] = -std::log(kEpsilon);
+          } else if (cnt_per_class[i] == 0) {
+            class_need_train_[i] = false;
+            class_default_output_[i] = -std::log(1.0f / kEpsilon - 1.0f);
+          }
+        }
+      } else {
+        // binary classification
+        if (cnt_per_class[1] == 0) {
+          class_need_train_[0] = false;
+          class_default_output_[0] = -std::log(1.0f / kEpsilon - 1.0f);
+        } else if (cnt_per_class[1] == num_data_) {
+          class_need_train_[0] = false;
+          class_default_output_[0] = -std::log(kEpsilon);
+        }
+      }
+    }
   }
   gbdt_config_.reset(new_config.release());
 }
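Why these particular constants: `-std::log(1.0f / kEpsilon - 1.0f)` is the raw score whose sigmoid is exactly `kEpsilon`, and `-std::log(kEpsilon)` is a raw score whose sigmoid is roughly 1. A standalone sketch (not part of the patch) that checks this, assuming `kEpsilon` is a small constant on the order of 1e-15 as in LightGBM's headers:

```cpp
#include <cmath>
#include <cstdio>

int main() {
  const double kEpsilon = 1e-15;  // assumed value, for illustration only
  auto sigmoid = [](double x) { return 1.0 / (1.0 + std::exp(-x)); };
  // class absent from the data: sigmoid(raw) == kEpsilon exactly
  double absent = -std::log(1.0 / kEpsilon - 1.0);
  // class covering all of the data: sigmoid(raw) == 1 / (1 + kEpsilon) ~= 1
  double full = -std::log(kEpsilon);
  std::printf("absent: raw=%g -> prob=%g\n", absent, sigmoid(absent));
  std::printf("full:   raw=%g -> prob=%g\n", full, sigmoid(full));
  return 0;
}
```

So a class that never appears in the training data is pinned to near-zero probability, and a class that covers all of it to near-one, without fitting any tree.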
@@ -370,8 +401,11 @@ bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is
 #ifdef TIMETAG
     start_time = std::chrono::steady_clock::now();
 #endif
-    std::unique_ptr<Tree> new_tree(
-      tree_learner_->Train(gradient + curr_class * num_data_, hessian + curr_class * num_data_));
+    std::unique_ptr<Tree> new_tree(new Tree(2));
+    if (class_need_train_[curr_class]) {
+      new_tree.reset(
+        tree_learner_->Train(gradient + curr_class * num_data_, hessian + curr_class * num_data_));
+    }
 #ifdef TIMETAG
     tree_time += std::chrono::steady_clock::now() - start_time;
 #endif
@@ -383,7 +417,18 @@ bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is
       // update score
       UpdateScore(new_tree.get(), curr_class);
       UpdateScoreOutOfBag(new_tree.get(), curr_class);
-    }
+    } else {
+      // only add the default score once
+      if (!class_need_train_[curr_class] && models_.size() < num_class_) {
+        auto output = class_default_output_[curr_class];
+        new_tree->Split(0, 0, BinType::NumericalBin, 0, 0, 0,
+                        output, output, 0, num_data_, 1);
+        train_score_updater_->AddScore(output, curr_class);
+        for (auto& score_updater : valid_score_updater_) {
+          score_updater->AddScore(output, curr_class);
+        }
+      }
+    }
     // add model
     models_.push_back(std::move(new_tree));
   }
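The skipped-class branch still pushes a real `Tree` into `models_`: `new Tree(2)` plus one `Split` whose left and right outputs are both `output` is a stump that returns the same constant for every sample, so per-class model indexing, prediction, and serialization stay unchanged. A hypothetical minimal model of that idea (not LightGBM's actual `Tree` API):

```cpp
// Hypothetical sketch: a depth-1 split whose two leaves share one output
// behaves exactly like adding a constant to every sample's raw score.
struct ConstantStump {
  double leaf_output;  // written into both leaves by Split(..., output, output, ...)
  double Predict(double /*any_feature_value*/) const {
    // whichever side of the threshold the sample falls on, the answer is the same
    return leaf_output;
  }
};
```

Reusing a degenerate tree rather than a special case keeps the invariant of one model per class per iteration intact.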
@@ -343,6 +343,8 @@ protected:
   std::unique_ptr<Dataset> tmp_subset_;
   bool is_use_subset_;
   bool boost_from_average_;
+  std::vector<bool> class_need_train_;
+  std::vector<double> class_default_output_;
 };
 }  // namespace LightGBM
@@ -42,6 +42,11 @@ public:
         ++cnt_negative;
       }
     }
+    if (cnt_negative == 0 || cnt_positive == 0) {
+      Log::Warning("Only contains one class.");
+      // no need to boost
+      num_data_ = 0;
+    }
     Log::Info("Number of positive: %d, number of negative: %d", cnt_positive, cnt_negative);
     // use -1 for negative class, and 1 for positive class
     label_val_[0] = -1;
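Setting `num_data_ = 0` here turns the objective's gradient pass into a no-op, so the tree learner has nothing to fit and the constant default score set up in `ResetTrainingData` is the whole model for the class. A simplified stand-in, loosely patterned after `GetGradients` (the signature here is illustrative, not the real one):

```cpp
#include <vector>

typedef int data_size_t;  // assumed alias, matching LightGBM's typedef

// Illustrative stand-in for BinaryLogloss::GetGradients: once num_data is
// zeroed, the loop body never executes and no gradients are produced.
void GetGradientsSketch(data_size_t num_data,
                        std::vector<double>* gradients,
                        std::vector<double>* hessians) {
  for (data_size_t i = 0; i < num_data; ++i) {
    (*gradients)[i] = 0.0;  // real code: first derivative of the log loss
    (*hessians)[i] = 0.0;   // real code: second derivative
  }
}
```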