Commit 31302bf7 authored by Guolin Ke's avatar Guolin Ke
Browse files

force set to zero to prediction buffer.

parent d0858b36
...@@ -774,6 +774,7 @@ std::string GBDT::ModelToIfElse(int num_iteration) const { ...@@ -774,6 +774,7 @@ std::string GBDT::ModelToIfElse(int num_iteration) const {
std::stringstream pred_str_buf; std::stringstream pred_str_buf;
pred_str_buf << "\t" << "int early_stop_round_counter = 0;" << std::endl; pred_str_buf << "\t" << "int early_stop_round_counter = 0;" << std::endl;
pred_str_buf << "\t" << "std::memset(output, 0, sizeof(double) * num_tree_per_iteration_);" << std::endl;
pred_str_buf << "\t" << "for (int i = 0; i < num_iteration_for_pred_; ++i) {" << std::endl; pred_str_buf << "\t" << "for (int i = 0; i < num_iteration_for_pred_; ++i) {" << std::endl;
pred_str_buf << "\t\t" << "for (int k = 0; k < num_tree_per_iteration_; ++k) {" << std::endl; pred_str_buf << "\t\t" << "for (int k = 0; k < num_tree_per_iteration_; ++k) {" << std::endl;
pred_str_buf << "\t\t\t" << "output[k] += (*PredictTreePtr[i * num_tree_per_iteration_ + k])(features);" << std::endl; pred_str_buf << "\t\t\t" << "output[k] += (*PredictTreePtr[i * num_tree_per_iteration_ + k])(features);" << std::endl;
......
...@@ -8,12 +8,13 @@ namespace LightGBM { ...@@ -8,12 +8,13 @@ namespace LightGBM {
void GBDT::PredictRaw(const double* features, double* output, const PredictionEarlyStopInstance* early_stop) const { void GBDT::PredictRaw(const double* features, double* output, const PredictionEarlyStopInstance* early_stop) const {
int early_stop_round_counter = 0; int early_stop_round_counter = 0;
// set zero
std::memset(output, 0, sizeof(double) * num_tree_per_iteration_);
for (int i = 0; i < num_iteration_for_pred_; ++i) { for (int i = 0; i < num_iteration_for_pred_; ++i) {
// predict all the trees for one iteration // predict all the trees for one iteration
for (int k = 0; k < num_tree_per_iteration_; ++k) { for (int k = 0; k < num_tree_per_iteration_; ++k) {
output[k] += models_[i * num_tree_per_iteration_ + k]->Predict(features); output[k] += models_[i * num_tree_per_iteration_ + k]->Predict(features);
} }
// check early stopping // check early stopping
++early_stop_round_counter; ++early_stop_round_counter;
if (early_stop->round_period == early_stop_round_counter) { if (early_stop->round_period == early_stop_round_counter) {
...@@ -27,7 +28,6 @@ void GBDT::PredictRaw(const double* features, double* output, const PredictionEa ...@@ -27,7 +28,6 @@ void GBDT::PredictRaw(const double* features, double* output, const PredictionEa
void GBDT::Predict(const double* features, double* output, const PredictionEarlyStopInstance* early_stop) const { void GBDT::Predict(const double* features, double* output, const PredictionEarlyStopInstance* early_stop) const {
PredictRaw(features, output, early_stop); PredictRaw(features, output, early_stop);
if (objective_function_ != nullptr) { if (objective_function_ != nullptr) {
objective_function_->ConvertOutput(output, output); objective_function_->ConvertOutput(output, output);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment