Commit 07f709b9 authored by Guolin Ke
Browse files

fix a bug in multi-class

parent fff84463
......@@ -195,13 +195,21 @@ private:
/// @brief Predicts this tree's output for a single sample.
///
/// A tree with at most one leaf is not traversed; it contributes 0.0 instead.
/// The guard has to run before any call to GetLeaf(), otherwise it is dead
/// code and such a tree would still be walked.
///
/// @param feature_values Feature values of the sample, forwarded to GetLeaf().
/// @return LeafOutput() of the leaf selected by GetLeaf(), or 0.0 when
///         num_leaves_ <= 1.
inline double Tree::Predict(const double* feature_values) const {
  // Check the leaf count first so a tree with num_leaves_ <= 1 is never traversed.
  if (num_leaves_ <= 1) {
    return 0.0;
  }
  return LeafOutput(GetLeaf(feature_values));
}
/// @brief Returns the index of the leaf that a single sample falls into.
///
/// A tree with at most one leaf is not traversed; index 0 is reported instead.
/// The guard has to run before any call to GetLeaf(), otherwise it is dead
/// code and such a tree would still be walked.
///
/// @param feature_values Feature values of the sample, forwarded to GetLeaf().
/// @return The leaf index from GetLeaf(), or 0 when num_leaves_ <= 1.
inline int Tree::PredictLeafIndex(const double* feature_values) const {
  // Check the leaf count first so a tree with num_leaves_ <= 1 is never traversed.
  if (num_leaves_ <= 1) {
    return 0;
  }
  return GetLeaf(feature_values);
}
inline int Tree::GetLeaf(const double* feature_values) const {
......
......@@ -273,6 +273,9 @@ inline static std::string ArrayToString(const std::vector<T>& arr, size_t n, cha
template<typename T>
inline static std::vector<T> StringToArray(const std::string& str, char delimiter, size_t n) {
if (n == 0) {
return std::vector<T>();
}
std::vector<std::string> strs = Split(str.c_str(), delimiter);
if (strs.size() != n) {
Log::Fatal("StringToArray error, size doesn't match.");
......
......@@ -336,6 +336,7 @@ bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is
sub_gradient_time += std::chrono::steady_clock::now() - start_time;
#endif
}
bool shouldContinue = false;
for (int curr_class = 0; curr_class < num_class_; ++curr_class) {
#ifdef TIMETAG
start_time = std::chrono::steady_clock::now();
......@@ -345,10 +346,9 @@ bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is
#ifdef TIMETAG
tree_time += std::chrono::steady_clock::now() - start_time;
#endif
// if cannot learn a new tree, then stop
if (new_tree->num_leaves() <= 1) {
Log::Info("Stopped training because there are no more leaves that meet the split requirements.");
return true;
if (new_tree->num_leaves() > 1) {
shouldContinue = true;
}
// shrinkage by learning rate
......@@ -360,6 +360,13 @@ bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is
// add model
models_.push_back(std::move(new_tree));
}
if (!shouldContinue) {
Log::Warning("Stopped training because there are no more leaves that meet the split requirements.");
for (int curr_class = 0; curr_class < num_class_; ++curr_class) {
models_.pop_back();
}
return true;
}
++iter_;
if (is_eval) {
return EvalAndCheckEarlyStopping();
......
......@@ -97,6 +97,7 @@ int Tree::Split(int leaf, int feature, BinType bin_type, uint32_t threshold_bin,
}
void Tree::AddPredictionToScore(const Dataset* data, data_size_t num_data, double* score) const {
if (num_leaves_ <= 1) { return; }
if (has_categorical_) {
if (data->num_features() > num_leaves_ - 1) {
Threading::For<data_size_t>(0, num_data,
......@@ -193,6 +194,7 @@ void Tree::AddPredictionToScore(const Dataset* data, data_size_t num_data, doubl
void Tree::AddPredictionToScore(const Dataset* data,
const data_size_t* used_data_indices,
data_size_t num_data, double* score) const {
if (num_leaves_ <= 1) { return; }
if (has_categorical_) {
if (data->num_features() > num_leaves_ - 1) {
Threading::For<data_size_t>(0, num_data,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment