"...git@developer.sourcefind.cn:tianlh/lightgbm-dcu.git" did not exist on "c8142e3037adc742d182b453b71a078e6906bb2f"
Commit cccd0587 authored by Guolin Ke's avatar Guolin Ke
Browse files

Skip the training of empty classes in classification.

parent e404d7cf
...@@ -162,6 +162,37 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_ ...@@ -162,6 +162,37 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_
if (train_data_ != nullptr) { if (train_data_ != nullptr) {
// reset config for tree learner // reset config for tree learner
tree_learner_->ResetConfig(&new_config->tree_config); tree_learner_->ResetConfig(&new_config->tree_config);
class_need_train_ = std::vector<bool>(num_class_, true);
if (num_class_ > 1 || sigmoid_ > 0) {
// + 1 here for the binary classification
class_default_output_ = std::vector<double>(num_class_ + 1, 0.0f);
std::vector<data_size_t> cnt_per_class(num_class_, 0);
auto label = train_data_->metadata().label();
for (int i = 0; i < num_data_; ++i) {
++cnt_per_class[static_cast<int>(label[i])];
}
if (num_class_ > 1) {
for (int i = 0; i < num_class_; ++i) {
if (cnt_per_class[i] == num_data_) {
Log::Warning("Only contain one class.");
class_need_train_[i] = false;
class_default_output_[i] = -std::log(kEpsilon);
} else if (cnt_per_class[i] == 0) {
class_need_train_[i] = false;
class_default_output_[i] = -std::log(1.0f / kEpsilon - 1.0f);
}
}
} else {
// binary classification.
if (cnt_per_class[1] == 0) {
class_need_train_[0] = false;
class_default_output_[0] = -std::log(1.0f / kEpsilon - 1.0f);
} else if (cnt_per_class[1] == num_data_) {
class_need_train_[0] = false;
class_default_output_[0] = -std::log(kEpsilon);
}
}
}
} }
gbdt_config_.reset(new_config.release()); gbdt_config_.reset(new_config.release());
} }
...@@ -370,8 +401,11 @@ bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is ...@@ -370,8 +401,11 @@ bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is
#ifdef TIMETAG #ifdef TIMETAG
start_time = std::chrono::steady_clock::now(); start_time = std::chrono::steady_clock::now();
#endif #endif
std::unique_ptr<Tree> new_tree( std::unique_ptr<Tree> new_tree(new Tree(2));
tree_learner_->Train(gradient + curr_class * num_data_, hessian + curr_class * num_data_)); if (class_need_train_[curr_class]) {
new_tree.reset(
tree_learner_->Train(gradient + curr_class * num_data_, hessian + curr_class * num_data_));
}
#ifdef TIMETAG #ifdef TIMETAG
tree_time += std::chrono::steady_clock::now() - start_time; tree_time += std::chrono::steady_clock::now() - start_time;
#endif #endif
...@@ -383,7 +417,18 @@ bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is ...@@ -383,7 +417,18 @@ bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is
// update score // update score
UpdateScore(new_tree.get(), curr_class); UpdateScore(new_tree.get(), curr_class);
UpdateScoreOutOfBag(new_tree.get(), curr_class); UpdateScoreOutOfBag(new_tree.get(), curr_class);
} } else {
// only add default score one-time
if (!class_need_train_[curr_class] && models_.size() < num_class_) {
auto output = class_default_output_[curr_class];
new_tree->Split(0, 0, BinType::NumericalBin, 0, 0, 0,
output, output, 0, num_data_, 1);
train_score_updater_->AddScore(output, curr_class);
for (auto& score_updater : valid_score_updater_) {
score_updater->AddScore(output, curr_class);
}
}
}
// add model // add model
models_.push_back(std::move(new_tree)); models_.push_back(std::move(new_tree));
} }
......
...@@ -343,6 +343,8 @@ protected: ...@@ -343,6 +343,8 @@ protected:
std::unique_ptr<Dataset> tmp_subset_; std::unique_ptr<Dataset> tmp_subset_;
bool is_use_subset_; bool is_use_subset_;
bool boost_from_average_; bool boost_from_average_;
std::vector<bool> class_need_train_;
std::vector<double> class_default_output_;
}; };
} // namespace LightGBM } // namespace LightGBM
......
...@@ -42,6 +42,11 @@ public: ...@@ -42,6 +42,11 @@ public:
++cnt_negative; ++cnt_negative;
} }
} }
if (cnt_negative == 0 || cnt_positive == 0) {
Log::Warning("Only contain one class.");
// not need to boost.
num_data_ = 0;
}
Log::Info("Number of positive: %d, number of negative: %d", cnt_positive, cnt_negative); Log::Info("Number of positive: %d, number of negative: %d", cnt_positive, cnt_negative);
// use -1 for negative class, and 1 for positive class // use -1 for negative class, and 1 for positive class
label_val_[0] = -1; label_val_[0] = -1;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment