"...git@developer.sourcefind.cn:tianlh/lightgbm-dcu.git" did not exist on "2a8d38c5551c3c911d3d8a488dcc8127b511512e"
Commit 07f709b9 authored by Guolin Ke
Browse files

fix a bug in multi-class

parent fff84463
...@@ -195,13 +195,21 @@ private: ...@@ -195,13 +195,21 @@ private:
inline double Tree::Predict(const double* feature_values) const { inline double Tree::Predict(const double* feature_values) const {
if (num_leaves_ > 1) {
int leaf = GetLeaf(feature_values); int leaf = GetLeaf(feature_values);
return LeafOutput(leaf); return LeafOutput(leaf);
} else {
return 0.0f;
}
} }
inline int Tree::PredictLeafIndex(const double* feature_values) const { inline int Tree::PredictLeafIndex(const double* feature_values) const {
if (num_leaves_ > 1) {
int leaf = GetLeaf(feature_values); int leaf = GetLeaf(feature_values);
return leaf; return leaf;
} else {
return 0;
}
} }
inline int Tree::GetLeaf(const double* feature_values) const { inline int Tree::GetLeaf(const double* feature_values) const {
......
...@@ -273,6 +273,9 @@ inline static std::string ArrayToString(const std::vector<T>& arr, size_t n, cha ...@@ -273,6 +273,9 @@ inline static std::string ArrayToString(const std::vector<T>& arr, size_t n, cha
template<typename T> template<typename T>
inline static std::vector<T> StringToArray(const std::string& str, char delimiter, size_t n) { inline static std::vector<T> StringToArray(const std::string& str, char delimiter, size_t n) {
if (n == 0) {
return std::vector<T>();
}
std::vector<std::string> strs = Split(str.c_str(), delimiter); std::vector<std::string> strs = Split(str.c_str(), delimiter);
if (strs.size() != n) { if (strs.size() != n) {
Log::Fatal("StringToArray error, size doesn't match."); Log::Fatal("StringToArray error, size doesn't match.");
......
...@@ -336,6 +336,7 @@ bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is ...@@ -336,6 +336,7 @@ bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is
sub_gradient_time += std::chrono::steady_clock::now() - start_time; sub_gradient_time += std::chrono::steady_clock::now() - start_time;
#endif #endif
} }
bool shouldContinue = false;
for (int curr_class = 0; curr_class < num_class_; ++curr_class) { for (int curr_class = 0; curr_class < num_class_; ++curr_class) {
#ifdef TIMETAG #ifdef TIMETAG
start_time = std::chrono::steady_clock::now(); start_time = std::chrono::steady_clock::now();
...@@ -345,10 +346,9 @@ bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is ...@@ -345,10 +346,9 @@ bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is
#ifdef TIMETAG #ifdef TIMETAG
tree_time += std::chrono::steady_clock::now() - start_time; tree_time += std::chrono::steady_clock::now() - start_time;
#endif #endif
// if cannot learn a new tree, then stop
if (new_tree->num_leaves() <= 1) { if (new_tree->num_leaves() > 1) {
Log::Info("Stopped training because there are no more leaves that meet the split requirements."); shouldContinue = true;
return true;
} }
// shrinkage by learning rate // shrinkage by learning rate
...@@ -360,6 +360,13 @@ bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is ...@@ -360,6 +360,13 @@ bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is
// add model // add model
models_.push_back(std::move(new_tree)); models_.push_back(std::move(new_tree));
} }
if (!shouldContinue) {
Log::Warning("Stopped training because there are no more leaves that meet the split requirements.");
for (int curr_class = 0; curr_class < num_class_; ++curr_class) {
models_.pop_back();
}
return true;
}
++iter_; ++iter_;
if (is_eval) { if (is_eval) {
return EvalAndCheckEarlyStopping(); return EvalAndCheckEarlyStopping();
......
...@@ -97,6 +97,7 @@ int Tree::Split(int leaf, int feature, BinType bin_type, uint32_t threshold_bin, ...@@ -97,6 +97,7 @@ int Tree::Split(int leaf, int feature, BinType bin_type, uint32_t threshold_bin,
} }
void Tree::AddPredictionToScore(const Dataset* data, data_size_t num_data, double* score) const { void Tree::AddPredictionToScore(const Dataset* data, data_size_t num_data, double* score) const {
if (num_leaves_ <= 1) { return; }
if (has_categorical_) { if (has_categorical_) {
if (data->num_features() > num_leaves_ - 1) { if (data->num_features() > num_leaves_ - 1) {
Threading::For<data_size_t>(0, num_data, Threading::For<data_size_t>(0, num_data,
...@@ -193,6 +194,7 @@ void Tree::AddPredictionToScore(const Dataset* data, data_size_t num_data, doubl ...@@ -193,6 +194,7 @@ void Tree::AddPredictionToScore(const Dataset* data, data_size_t num_data, doubl
void Tree::AddPredictionToScore(const Dataset* data, void Tree::AddPredictionToScore(const Dataset* data,
const data_size_t* used_data_indices, const data_size_t* used_data_indices,
data_size_t num_data, double* score) const { data_size_t num_data, double* score) const {
if (num_leaves_ <= 1) { return; }
if (has_categorical_) { if (has_categorical_) {
if (data->num_features() > num_leaves_ - 1) { if (data->num_features() > num_leaves_ - 1) {
Threading::For<data_size_t>(0, num_data, Threading::For<data_size_t>(0, num_data,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment