"include/git@developer.sourcefind.cn:tianlh/lightgbm-dcu.git" did not exist on "b752170ba950890dd7e211fdecedef4d8deffb8b"
Commit e0ec8661 authored by Guolin Ke's avatar Guolin Ke
Browse files

refactor the counting for classes.

parent ea37b10b
...@@ -166,16 +166,14 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_ ...@@ -166,16 +166,14 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_
if (objective_function_ != nullptr && objective_function_->SkipEmptyClass()) { if (objective_function_ != nullptr && objective_function_->SkipEmptyClass()) {
CHECK(num_tree_per_iteration_ == num_class_); CHECK(num_tree_per_iteration_ == num_class_);
// + 1 here for the binary classification // + 1 here for the binary classification
class_default_output_ = std::vector<double>(num_tree_per_iteration_ + 1, 0.0f); class_default_output_ = std::vector<double>(num_tree_per_iteration_, 0.0f);
std::vector<data_size_t> cnt_per_class(num_tree_per_iteration_ + 1, 0);
auto label = train_data_->metadata().label(); auto label = train_data_->metadata().label();
for (int i = 0; i < num_data_; ++i) {
int index = static_cast<int>(label[i]);
//Check if user gave a multi-class dataset for a binary class problem.
if (index <= num_tree_per_iteration_)
++cnt_per_class[static_cast<int>(label[i])];
}
if (num_tree_per_iteration_ > 1) { if (num_tree_per_iteration_ > 1) {
// multi-class
std::vector<data_size_t> cnt_per_class(num_tree_per_iteration_, 0);
for (data_size_t i = 0; i < num_data_; ++i) {
++cnt_per_class[static_cast<int>(label[i])];
}
for (int i = 0; i < num_tree_per_iteration_; ++i) { for (int i = 0; i < num_tree_per_iteration_; ++i) {
if (cnt_per_class[i] == num_data_) { if (cnt_per_class[i] == num_data_) {
class_need_train_[i] = false; class_need_train_[i] = false;
...@@ -186,11 +184,17 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_ ...@@ -186,11 +184,17 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_
} }
} }
} else { } else {
// binary classification. // binary class
if (cnt_per_class[1] == 0) { data_size_t cnt_pos = 0;
for (data_size_t i = 0; i < num_data_; ++i) {
if (label[i] > 0) {
++cnt_pos;
}
}
if (cnt_pos == 0) {
class_need_train_[0] = false; class_need_train_[0] = false;
class_default_output_[0] = -std::log(1.0f / kEpsilon - 1.0f); class_default_output_[0] = -std::log(1.0f / kEpsilon - 1.0f);
} else if (cnt_per_class[1] == num_data_) { } else if (cnt_pos == num_data_) {
class_need_train_[0] = false; class_need_train_[0] = false;
class_default_output_[0] = -std::log(kEpsilon); class_default_output_[0] = -std::log(kEpsilon);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment