Commit 8a19834a authored by wxchan's avatar wxchan Committed by Guolin Ke
Browse files

[MRG] translate model to if-else (#469)

* translate model to if-else

* support multiclass and predictleaf

* remove java option for now

* support multi-thread

* add task:convert_model
parent 7f94fd9c
...@@ -59,6 +59,9 @@ private: ...@@ -59,6 +59,9 @@ private:
/*! \brief Main predicting logic */ /*! \brief Main predicting logic */
void Predict(); void Predict();
/*! \brief Main Convert model logic */
void ConvertModel();
/*! \brief All configs */ /*! \brief All configs */
OverallConfig config_; OverallConfig config_;
/*! \brief Training data */ /*! \brief Training data */
...@@ -80,6 +83,8 @@ inline void Application::Run() { ...@@ -80,6 +83,8 @@ inline void Application::Run() {
if (config_.task_type == TaskType::kPredict) { if (config_.task_type == TaskType::kPredict) {
InitPredict(); InitPredict();
Predict(); Predict();
} else if (config_.task_type == TaskType::kConvertModel) {
ConvertModel();
} else { } else {
InitTrain(); InitTrain();
Train(); Train();
......
...@@ -136,10 +136,26 @@ public: ...@@ -136,10 +136,26 @@ public:
/*! /*!
* \brief Dump model to json format string * \brief Dump model to json format string
* \param num_iteration Number of iterations that want to dump, -1 means dump all
* \return Json format string of model * \return Json format string of model
*/ */
virtual std::string DumpModel(int num_iteration) const = 0; virtual std::string DumpModel(int num_iteration) const = 0;
/*!
* \brief Translate model to if-else statement
* \param num_iteration Number of iterations that want to translate, -1 means translate all
* \return if-else format codes of model
*/
virtual std::string ModelToIfElse(int num_iteration) const = 0;
/*!
* \brief Translate model to if-else statement
* \param num_iteration Number of iterations that want to translate, -1 means translate all
* \param filename Filename that want to save to
* \return true if the model was saved successfully, false otherwise
*/
virtual bool SaveModelToIfElse(int num_iteration, const char* filename) const = 0;
/*! /*!
* \brief Save model to file * \brief Save model to file
* \param num_used_model Number of model that want to save, -1 means save all * \param num_used_model Number of model that want to save, -1 means save all
......
...@@ -79,7 +79,7 @@ public: ...@@ -79,7 +79,7 @@ public:
/*! \brief Types of tasks */ /*! \brief Types of tasks */
enum TaskType { enum TaskType {
kTrain, kPredict kTrain, kPredict, kConvertModel
}; };
/*! \brief Config for input and output files */ /*! \brief Config for input and output files */
...@@ -93,6 +93,7 @@ public: ...@@ -93,6 +93,7 @@ public:
int snapshot_freq = 100; int snapshot_freq = 100;
std::string output_model = "LightGBM_model.txt"; std::string output_model = "LightGBM_model.txt";
std::string output_result = "LightGBM_predict_result.txt"; std::string output_result = "LightGBM_predict_result.txt";
std::string convert_model = "LightGBM_convert_model.cpp";
std::string input_model = ""; std::string input_model = "";
int verbosity = 1; int verbosity = 1;
int num_iteration_predict = -1; int num_iteration_predict = -1;
...@@ -269,6 +270,7 @@ public: ...@@ -269,6 +270,7 @@ public:
ObjectiveConfig objective_config; ObjectiveConfig objective_config;
std::vector<std::string> metric_types; std::vector<std::string> metric_types;
MetricConfig metric_config; MetricConfig metric_config;
std::string convert_model_language = "";
LIGHTGBM_EXPORT void Set(const std::unordered_map<std::string, std::string>& params) override; LIGHTGBM_EXPORT void Set(const std::unordered_map<std::string, std::string>& params) override;
......
...@@ -119,6 +119,9 @@ public: ...@@ -119,6 +119,9 @@ public:
/*! \brief Serialize this object to json*/ /*! \brief Serialize this object to json*/
std::string ToJSON(); std::string ToJSON();
/*! \brief Serialize this object to if-else statement*/
std::string ToIfElse(int index, bool is_predict_leaf_index);
template<typename T> template<typename T>
static bool CategoricalDecision(T fval, T threshold) { static bool CategoricalDecision(T fval, T threshold) {
if (static_cast<int>(fval) == static_cast<int>(threshold)) { if (static_cast<int>(fval) == static_cast<int>(threshold)) {
...@@ -160,6 +163,9 @@ private: ...@@ -160,6 +163,9 @@ private:
/*! \brief Serialize one node to json*/ /*! \brief Serialize one node to json*/
inline std::string NodeToJSON(int index); inline std::string NodeToJSON(int index);
/*! \brief Serialize one node to if-else statement*/
inline std::string NodeToIfElse(int index, bool is_predict_leaf_index);
/*! \brief Number of max leaves*/ /*! \brief Number of max leaves*/
int max_leaves_; int max_leaves_;
/*! \brief Number of current leaves*/ /*! \brief Number of current leaves*/
......
...@@ -32,7 +32,7 @@ Application::Application(int argc, char** argv) { ...@@ -32,7 +32,7 @@ Application::Application(int argc, char** argv) {
if (config_.num_threads > 0) { if (config_.num_threads > 0) {
omp_set_num_threads(config_.num_threads); omp_set_num_threads(config_.num_threads);
} }
if (config_.io_config.data_filename.size() == 0) { if (config_.io_config.data_filename.size() == 0 && config_.task_type != TaskType::kConvertModel) {
Log::Fatal("No training/prediction data, application quit"); Log::Fatal("No training/prediction data, application quit");
} }
} }
...@@ -239,10 +239,13 @@ void Application::Train() { ...@@ -239,10 +239,13 @@ void Application::Train() {
} }
// save model to file // save model to file
boosting_->SaveModelToFile(-1, config_.io_config.output_model.c_str()); boosting_->SaveModelToFile(-1, config_.io_config.output_model.c_str());
// convert model to if-else statement code
if (config_.convert_model_language == std::string("cpp")) {
boosting_->SaveModelToIfElse(-1, config_.io_config.convert_model.c_str());
}
Log::Info("Finished training"); Log::Info("Finished training");
} }
void Application::Predict() { void Application::Predict() {
// create predictor // create predictor
Predictor predictor(boosting_.get(), config_.io_config.num_iteration_predict, config_.io_config.is_predict_raw_score, Predictor predictor(boosting_.get(), config_.io_config.num_iteration_predict, config_.io_config.is_predict_raw_score,
...@@ -258,6 +261,13 @@ void Application::InitPredict() { ...@@ -258,6 +261,13 @@ void Application::InitPredict() {
Log::Info("Finished initializing prediction"); Log::Info("Finished initializing prediction");
} }
void Application::ConvertModel() {
boosting_.reset(
Boosting::CreateBoosting(config_.boosting_type,
config_.io_config.input_model.c_str()));
boosting_->SaveModelToIfElse(-1, config_.io_config.convert_model.c_str());
}
template<typename T> template<typename T>
T Application::GlobalSyncUpByMin(T& local) { T Application::GlobalSyncUpByMin(T& local) {
T global = local; T global = local;
......
...@@ -700,6 +700,99 @@ std::string GBDT::DumpModel(int num_iteration) const { ...@@ -700,6 +700,99 @@ std::string GBDT::DumpModel(int num_iteration) const {
return str_buf.str(); return str_buf.str();
} }
// Translate the whole ensemble into a self-contained block of C++ source:
// one "double PredictTreeK(const double*)" function per tree, plus
// GBDT::PredictRaw / GBDT::Predict / GBDT::PredictLeafIndex drivers that
// dispatch through function-pointer tables. The returned string is meant to
// be written to a .cpp file and compiled (see SaveModelToIfElse).
std::string GBDT::ModelToIfElse(int num_iteration) const {
std::stringstream str_buf;
// Number of trees to translate; num_iteration <= 0 means all of them.
int num_used_model = static_cast<int>(models_.size());
if (num_iteration > 0) {
// The implicit "boost from average" tree counts as one extra iteration.
num_iteration += boost_from_average_ ? 1 : 0;
num_used_model = std::min(num_iteration * num_tree_per_iteration_, num_used_model);
}
// PredictRaw: emit one per-tree scoring function ...
for (int i = 0; i < num_used_model; ++i) {
str_buf << models_[i]->ToIfElse(i, false) << std::endl;
}
// ... and a function-pointer table indexed as [iteration * trees_per_iter + class].
str_buf << "double (*PredictTreePtr[])(const double*) = { ";
for (int i = 0; i < num_used_model; ++i) {
if (i > 0) {
str_buf << " , ";
}
str_buf << "PredictTree" << i;
}
str_buf << " };" << std::endl << std::endl;
// Shared body for PredictRaw/Predict: sum every tree's output per class.
// Two parallelization strategies are generated: parallelize over classes
// when there are enough of them, otherwise over trees with a reduction.
std::stringstream pred_str_buf;
pred_str_buf << "\t" << "if (num_threads_ <= num_tree_per_iteration_) {" << std::endl;
pred_str_buf << "\t\t" << "#pragma omp parallel for schedule(static)" << std::endl;
pred_str_buf << "\t\t" << "for (int k = 0; k < num_tree_per_iteration_; ++k) {" << std::endl;
pred_str_buf << "\t\t\t" << "for (int i = 0; i < num_iteration_for_pred_; ++i) {" << std::endl;
pred_str_buf << "\t\t\t\t" << "output[k] += (*PredictTreePtr[i * num_tree_per_iteration_ + k])(features);" << std::endl;
pred_str_buf << "\t\t\t" << "}" << std::endl;
pred_str_buf << "\t\t" << "}" << std::endl;
pred_str_buf << "\t" << "} else {" << std::endl;
pred_str_buf << "\t\t" << "for (int k = 0; k < num_tree_per_iteration_; ++k) {" << std::endl;
pred_str_buf << "\t\t\t" << "double t = 0.0f;" << std::endl;
pred_str_buf << "\t\t\t" << "#pragma omp parallel for schedule(static) reduction(+:t)" << std::endl;
pred_str_buf << "\t\t\t" << "for (int i = 0; i < num_iteration_for_pred_; ++i) {" << std::endl;
pred_str_buf << "\t\t\t\t" << "t += (*PredictTreePtr[i * num_tree_per_iteration_ + k])(features);" << std::endl;
pred_str_buf << "\t\t\t" << "}" << std::endl;
pred_str_buf << "\t\t\t" << "output[k] = t;" << std::endl;
pred_str_buf << "\t\t" << "}" << std::endl;
pred_str_buf << "\t" << "}" << std::endl;
// PredictRaw: raw scores only.
str_buf << "void GBDT::PredictRaw(const double* features, double *output) const {" << std::endl;
str_buf << pred_str_buf.str();
str_buf << "}" << std::endl;
str_buf << std::endl;
// Predict: same accumulation, then the objective's output transform
// (e.g. sigmoid/softmax) when one is present.
str_buf << "void GBDT::Predict(const double* features, double *output) const {" << std::endl;
str_buf << pred_str_buf.str();
str_buf << "\t" << "if (objective_function_ != nullptr) {" << std::endl;
str_buf << "\t\t" << "objective_function_->ConvertOutput(output, output);" << std::endl;
str_buf << "\t" << "}" << std::endl;
str_buf << "}" << std::endl;
str_buf << std::endl;
// PredictLeafIndex: per-tree functions that return leaf indices instead of values.
for (int i = 0; i < num_used_model; ++i) {
str_buf << models_[i]->ToIfElse(i, true) << std::endl;
}
str_buf << "double (*PredictTreeLeafPtr[])(const double*) = { ";
for (int i = 0; i < num_used_model; ++i) {
if (i > 0) {
str_buf << " , ";
}
str_buf << "PredictTree" << i << "Leaf";
}
str_buf << " };" << std::endl << std::endl;
str_buf << "void GBDT::PredictLeafIndex(const double* features, double *output) const {" << std::endl;
str_buf << "\t" << "int total_tree = num_iteration_for_pred_ * num_tree_per_iteration_;" << std::endl;
str_buf << "\t" << "#pragma omp parallel for schedule(static)" << std::endl;
str_buf << "\t" << "for (int i = 0; i < total_tree; ++i) {" << std::endl;
str_buf << "\t\t" << "output[i] = (*PredictTreeLeafPtr[i])(features);" << std::endl;
str_buf << "\t" << "}" << std::endl;
str_buf << "}" << std::endl;
return str_buf.str();
}
/*!
* \brief Write the if-else translation of the model (see ModelToIfElse)
*        to a file.
* \param num_iteration Number of iterations to translate, -1 means all
* \param filename File to write the generated code to
* \return true if the file was opened and fully written, false otherwise
*/
bool GBDT::SaveModelToIfElse(int num_iteration, const char* filename) const {
  std::ofstream output_file(filename);
  // Bail out early if the file cannot be created (bad path, permissions, ...).
  if (!output_file.is_open()) {
    return false;
  }
  output_file << ModelToIfElse(num_iteration);
  output_file.close();
  // Stream state persists after close(); a failed write (e.g. disk full)
  // leaves failbit set, which this conversion reports as false.
  return static_cast<bool>(output_file);
}
std::string GBDT::SaveModelToString(int num_iteration) const { std::string GBDT::SaveModelToString(int num_iteration) const {
std::stringstream ss; std::stringstream ss;
......
...@@ -144,15 +144,31 @@ public: ...@@ -144,15 +144,31 @@ public:
/*! /*!
* \brief Dump model to json format string * \brief Dump model to json format string
* \param num_iteration Number of iterations that want to dump, -1 means dump all
* \return Json format string of model * \return Json format string of model
*/ */
std::string DumpModel(int num_iteration) const override; std::string DumpModel(int num_iteration) const override;
/*!
* \brief Translate model to if-else statement
* \param num_iteration Number of iterations that want to translate, -1 means translate all
* \return if-else format codes of model
*/
std::string ModelToIfElse(int num_iteration) const override;
/*!
* \brief Translate model to if-else statement
* \param num_iteration Number of iterations that want to translate, -1 means translate all
* \param filename Filename that want to save to
* \return true if the model was saved successfully, false otherwise
*/
bool SaveModelToIfElse(int num_iteration, const char* filename) const override;
/*! /*!
* \brief Save model to file * \brief Save model to file
* \param num_used_model Number of model that want to save, -1 means save all * \param num_used_model Number of model that want to save, -1 means save all
* \param is_finish Is training finished or not
* \param filename Filename that want to save to * \param filename Filename that want to save to
* \return is_finish Is training finished or not
*/ */
virtual bool SaveModelToFile(int num_iterations, const char* filename) const override; virtual bool SaveModelToFile(int num_iterations, const char* filename) const override;
......
...@@ -35,6 +35,7 @@ std::unordered_map<std::string, std::string> ConfigBase::Str2Map(const char* par ...@@ -35,6 +35,7 @@ std::unordered_map<std::string, std::string> ConfigBase::Str2Map(const char* par
void OverallConfig::Set(const std::unordered_map<std::string, std::string>& params) { void OverallConfig::Set(const std::unordered_map<std::string, std::string>& params) {
// load main config types // load main config types
GetInt(params, "num_threads", &num_threads); GetInt(params, "num_threads", &num_threads);
GetString(params, "convert_model_language", &convert_model_language);
// generate seeds by seed. // generate seeds by seed.
if (GetInt(params, "seed", &seed)) { if (GetInt(params, "seed", &seed)) {
...@@ -129,6 +130,8 @@ void OverallConfig::GetTaskType(const std::unordered_map<std::string, std::strin ...@@ -129,6 +130,8 @@ void OverallConfig::GetTaskType(const std::unordered_map<std::string, std::strin
} else if (value == std::string("predict") || value == std::string("prediction") } else if (value == std::string("predict") || value == std::string("prediction")
|| value == std::string("test")) { || value == std::string("test")) {
task_type = TaskType::kPredict; task_type = TaskType::kPredict;
} else if (value == std::string("convert_model")) {
task_type = TaskType::kConvertModel;
} else { } else {
Log::Fatal("Unknown task type %s", value.c_str()); Log::Fatal("Unknown task type %s", value.c_str());
} }
...@@ -210,6 +213,7 @@ void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) { ...@@ -210,6 +213,7 @@ void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) {
GetInt(params, "snapshot_freq", &snapshot_freq); GetInt(params, "snapshot_freq", &snapshot_freq);
GetString(params, "output_model", &output_model); GetString(params, "output_model", &output_model);
GetString(params, "input_model", &input_model); GetString(params, "input_model", &input_model);
GetString(params, "convert_model", &convert_model);
GetString(params, "output_result", &output_result); GetString(params, "output_result", &output_result);
std::string tmp_str = ""; std::string tmp_str = "";
if (GetString(params, "valid_data", &tmp_str)) { if (GetString(params, "valid_data", &tmp_str)) {
......
...@@ -368,6 +368,54 @@ std::string Tree::NodeToJSON(int index) { ...@@ -368,6 +368,54 @@ std::string Tree::NodeToJSON(int index) {
return str_buf.str(); return str_buf.str();
} }
/*!
* \brief Serialize this tree as a C++ function "double PredictTree<index>[Leaf](const double* arr)".
* \param index Position of this tree in the ensemble; used to name the function
* \param is_predict_leaf_index true to generate the leaf-index variant, false for leaf values
* \return Generated C++ source for the function
*/
std::string Tree::ToIfElse(int index, bool is_predict_leaf_index) {
  std::stringstream str_buf;
  // Match NodeToIfElse: enough digits to round-trip a double exactly.
  str_buf << std::setprecision(std::numeric_limits<double>::digits10 + 2);
  str_buf << "double PredictTree" << index;
  if (is_predict_leaf_index) {
    str_buf << "Leaf";
  }
  str_buf << "(const double* arr) { ";
  if (num_leaves_ == 1) {
    // Degenerate single-leaf tree: the only leaf has index 0, and the
    // predicted value is that leaf's value (not a hard-coded 0, matching
    // normal prediction). The trailing ';' is required for the generated
    // code to compile — the original emitted "return 0" without one.
    if (is_predict_leaf_index) {
      str_buf << "return 0;";
    } else {
      str_buf << "return " << leaf_value_[0] << ";";
    }
  } else {
    str_buf << NodeToIfElse(0, is_predict_leaf_index);
  }
  str_buf << " }" << std::endl;
  return str_buf.str();
}
/*!
* \brief Recursively serialize the subtree rooted at node `index` into
*        nested if/else statements over the feature array `arr`.
* \param index Node id; non-negative for internal nodes, ~leaf_index for leaves
* \param is_predict_leaf_index true to emit leaf indices, false to emit leaf values
* \return Generated C++ source for this subtree
*/
std::string Tree::NodeToIfElse(int index, bool is_predict_leaf_index) {
  std::stringstream out;
  // Enough digits to round-trip a double exactly through text.
  out << std::setprecision(std::numeric_limits<double>::digits10 + 2);
  if (index < 0) {
    // Leaf: node ids of leaves are stored as ~leaf_index.
    const int leaf_idx = ~index;
    out << "return ";
    if (is_predict_leaf_index) {
      out << leaf_idx;
    } else {
      out << leaf_value_[leaf_idx];
    }
    out << ";";
  } else {
    // Internal node: numerical splits (decision_type_ == 0) compare with
    // "<=", categorical splits compare with "==".
    const char* cmp = (decision_type_[index] == 0) ? "<=" : "==";
    out << "if ( arr[" << split_feature_[index] << "] "
        << cmp << " " << threshold_[index] << " ) { ";
    out << NodeToIfElse(left_child_[index], is_predict_leaf_index);
    out << " } else { ";
    out << NodeToIfElse(right_child_[index], is_predict_leaf_index);
    out << " }";
  }
  return out.str();
}
Tree::Tree(const std::string& str) { Tree::Tree(const std::string& str) {
std::vector<std::string> lines = Common::Split(str.c_str(), '\n'); std::vector<std::string> lines = Common::Split(str.c_str(), '\n');
std::unordered_map<std::string, std::string> key_vals; std::unordered_map<std::string, std::string> key_vals;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment