Commit 8a19834a authored by wxchan's avatar wxchan Committed by Guolin Ke
Browse files

[MRG] translate model to if-else (#469)

* translate model to if-else

* support multiclass and predictleaf

* remove java option for now

* support multi-thread

* add task:convert_model
parent 7f94fd9c
......@@ -59,6 +59,9 @@ private:
/*! \brief Main predicting logic */
void Predict();
/*! \brief Main Convert model logic */
void ConvertModel();
/*! \brief All configs */
OverallConfig config_;
/*! \brief Training data */
......@@ -80,6 +83,8 @@ inline void Application::Run() {
if (config_.task_type == TaskType::kPredict) {
InitPredict();
Predict();
} else if (config_.task_type == TaskType::kConvertModel) {
ConvertModel();
} else {
InitTrain();
Train();
......
......@@ -136,10 +136,26 @@ public:
/*!
* \brief Dump model to json format string
* \param num_iteration Number of iterations that want to dump, -1 means dump all
* \return Json format string of model
*/
virtual std::string DumpModel(int num_iteration) const = 0;
/*!
* \brief Translate model to if-else statement
* \param num_iteration Number of iterations that want to translate, -1 means translate all
* \return if-else format codes of model
*/
virtual std::string ModelToIfElse(int num_iteration) const = 0;
/*!
* \brief Translate model to if-else statement
* \param num_iteration Number of iterations that want to translate, -1 means translate all
* \param filename Filename that want to save to
* \return true if the model was saved to the file successfully
*/
virtual bool SaveModelToIfElse(int num_iteration, const char* filename) const = 0;
/*!
* \brief Save model to file
* \param num_used_model Number of model that want to save, -1 means save all
......
......@@ -79,7 +79,7 @@ public:
/*! \brief Types of tasks */
enum TaskType {
kTrain, kPredict
kTrain, kPredict, kConvertModel
};
/*! \brief Config for input and output files */
......@@ -93,6 +93,7 @@ public:
int snapshot_freq = 100;
std::string output_model = "LightGBM_model.txt";
std::string output_result = "LightGBM_predict_result.txt";
std::string convert_model = "LightGBM_convert_model.cpp";
std::string input_model = "";
int verbosity = 1;
int num_iteration_predict = -1;
......@@ -269,6 +270,7 @@ public:
ObjectiveConfig objective_config;
std::vector<std::string> metric_types;
MetricConfig metric_config;
std::string convert_model_language = "";
LIGHTGBM_EXPORT void Set(const std::unordered_map<std::string, std::string>& params) override;
......
......@@ -119,6 +119,9 @@ public:
/*! \brief Serialize this object to json*/
std::string ToJSON();
/*! \brief Serialize this object to if-else statement*/
std::string ToIfElse(int index, bool is_predict_leaf_index);
template<typename T>
static bool CategoricalDecision(T fval, T threshold) {
if (static_cast<int>(fval) == static_cast<int>(threshold)) {
......@@ -160,6 +163,9 @@ private:
/*! \brief Serialize one node to json*/
inline std::string NodeToJSON(int index);
/*! \brief Serialize one node to if-else statement*/
inline std::string NodeToIfElse(int index, bool is_predict_leaf_index);
/*! \brief Number of max leaves*/
int max_leaves_;
/*! \brief Number of current leaves*/
......
......@@ -32,7 +32,7 @@ Application::Application(int argc, char** argv) {
if (config_.num_threads > 0) {
omp_set_num_threads(config_.num_threads);
}
if (config_.io_config.data_filename.size() == 0) {
if (config_.io_config.data_filename.size() == 0 && config_.task_type != TaskType::kConvertModel) {
Log::Fatal("No training/prediction data, application quit");
}
}
......@@ -239,10 +239,13 @@ void Application::Train() {
}
// save model to file
boosting_->SaveModelToFile(-1, config_.io_config.output_model.c_str());
// convert model to if-else statement code
if (config_.convert_model_language == std::string("cpp")) {
boosting_->SaveModelToIfElse(-1, config_.io_config.convert_model.c_str());
}
Log::Info("Finished training");
}
void Application::Predict() {
// create predictor
Predictor predictor(boosting_.get(), config_.io_config.num_iteration_predict, config_.io_config.is_predict_raw_score,
......@@ -258,6 +261,13 @@ void Application::InitPredict() {
Log::Info("Finished initializing prediction");
}
/*!
* \brief Load a previously saved model and convert it to if-else source code.
*
* Reads the model named by `input_model`, then writes every iteration of it
* (-1 = all) as if-else statements to the `convert_model` output path.
*/
void Application::ConvertModel() {
  // Re-create the boosting object from the saved model file on disk.
  const char* input_model_filename = config_.io_config.input_model.c_str();
  boosting_.reset(Boosting::CreateBoosting(config_.boosting_type, input_model_filename));
  // Emit the full model (-1 means all iterations) as if-else code.
  const char* output_code_filename = config_.io_config.convert_model.c_str();
  boosting_->SaveModelToIfElse(-1, output_code_filename);
}
template<typename T>
T Application::GlobalSyncUpByMin(T& local) {
T global = local;
......
......@@ -700,6 +700,99 @@ std::string GBDT::DumpModel(int num_iteration) const {
return str_buf.str();
}
/*!
* \brief Translate the ensemble into compilable C++ if-else code.
*
* Emits, as one string: a `PredictTreeN(const double*)` function per tree,
* function-pointer dispatch tables, and `GBDT::PredictRaw`, `GBDT::Predict`
* and `GBDT::PredictLeafIndex` bodies that call through those tables.
* \param num_iteration Number of iterations to translate, -1 means translate all
* \return if-else format code for the model
*/
std::string GBDT::ModelToIfElse(int num_iteration) const {
std::stringstream str_buf;
// Default to translating every stored tree.
int num_used_model = static_cast<int>(models_.size());
if (num_iteration > 0) {
// The initial "boost from average" tree occupies one extra iteration slot.
num_iteration += boost_from_average_ ? 1 : 0;
// Cap at the requested iteration count (trees per iteration may be > 1 for multiclass).
num_used_model = std::min(num_iteration * num_tree_per_iteration_, num_used_model);
}
// PredictRaw
// One standalone PredictTree<i> function per tree (raw score version).
for (int i = 0; i < num_used_model; ++i) {
str_buf << models_[i]->ToIfElse(i, false) << std::endl;
}
// Dispatch table so the generated Predict* bodies can index trees by number.
str_buf << "double (*PredictTreePtr[])(const double*) = { ";
for (int i = 0; i < num_used_model; ++i) {
if (i > 0) {
str_buf << " , ";
}
str_buf << "PredictTree" << i;
}
str_buf << " };" << std::endl << std::endl;
// Shared loop body reused by both generated PredictRaw and Predict:
// parallelize over classes when threads are scarce, otherwise over trees
// with a reduction into a per-class accumulator.
std::stringstream pred_str_buf;
pred_str_buf << "\t" << "if (num_threads_ <= num_tree_per_iteration_) {" << std::endl;
pred_str_buf << "\t\t" << "#pragma omp parallel for schedule(static)" << std::endl;
pred_str_buf << "\t\t" << "for (int k = 0; k < num_tree_per_iteration_; ++k) {" << std::endl;
pred_str_buf << "\t\t\t" << "for (int i = 0; i < num_iteration_for_pred_; ++i) {" << std::endl;
pred_str_buf << "\t\t\t\t" << "output[k] += (*PredictTreePtr[i * num_tree_per_iteration_ + k])(features);" << std::endl;
pred_str_buf << "\t\t\t" << "}" << std::endl;
pred_str_buf << "\t\t" << "}" << std::endl;
pred_str_buf << "\t" << "} else {" << std::endl;
pred_str_buf << "\t\t" << "for (int k = 0; k < num_tree_per_iteration_; ++k) {" << std::endl;
pred_str_buf << "\t\t\t" << "double t = 0.0f;" << std::endl;
pred_str_buf << "\t\t\t" << "#pragma omp parallel for schedule(static) reduction(+:t)" << std::endl;
pred_str_buf << "\t\t\t" << "for (int i = 0; i < num_iteration_for_pred_; ++i) {" << std::endl;
pred_str_buf << "\t\t\t\t" << "t += (*PredictTreePtr[i * num_tree_per_iteration_ + k])(features);" << std::endl;
pred_str_buf << "\t\t\t" << "}" << std::endl;
pred_str_buf << "\t\t\t" << "output[k] = t;" << std::endl;
pred_str_buf << "\t\t" << "}" << std::endl;
pred_str_buf << "\t" << "}" << std::endl;
// Generated PredictRaw: sum of raw tree outputs.
str_buf << "void GBDT::PredictRaw(const double* features, double *output) const {" << std::endl;
str_buf << pred_str_buf.str();
str_buf << "}" << std::endl;
str_buf << std::endl;
// Predict
// Generated Predict: same accumulation, then objective transform (e.g. sigmoid/softmax).
str_buf << "void GBDT::Predict(const double* features, double *output) const {" << std::endl;
str_buf << pred_str_buf.str();
str_buf << "\t" << "if (objective_function_ != nullptr) {" << std::endl;
str_buf << "\t\t" << "objective_function_->ConvertOutput(output, output);" << std::endl;
str_buf << "\t" << "}" << std::endl;
str_buf << "}" << std::endl;
str_buf << std::endl;
// PredictLeafIndex
// Second set of per-tree functions returning the matched leaf index instead of the value.
for (int i = 0; i < num_used_model; ++i) {
str_buf << models_[i]->ToIfElse(i, true) << std::endl;
}
str_buf << "double (*PredictTreeLeafPtr[])(const double*) = { ";
for (int i = 0; i < num_used_model; ++i) {
if (i > 0) {
str_buf << " , ";
}
str_buf << "PredictTree" << i << "Leaf";
}
str_buf << " };" << std::endl << std::endl;
// Generated PredictLeafIndex: one leaf index per tree, filled in parallel.
str_buf << "void GBDT::PredictLeafIndex(const double* features, double *output) const {" << std::endl;
str_buf << "\t" << "int total_tree = num_iteration_for_pred_ * num_tree_per_iteration_;" << std::endl;
str_buf << "\t" << "#pragma omp parallel for schedule(static)" << std::endl;
str_buf << "\t" << "for (int i = 0; i < total_tree; ++i) {" << std::endl;
str_buf << "\t\t" << "output[i] = (*PredictTreeLeafPtr[i])(features);" << std::endl;
str_buf << "\t" << "}" << std::endl;
str_buf << "}" << std::endl;
return str_buf.str();
}
/*!
* \brief Translate the model to if-else code and save it to a file.
* \param num_iteration Number of iterations to translate, -1 means translate all
* \param filename Filename to save to
* \return true if the file was opened and written successfully
*/
bool GBDT::SaveModelToIfElse(int num_iteration, const char* filename) const {
  /*! \brief File to write models */
  std::ofstream output_file(filename);
  // Fail fast when the destination cannot be opened, instead of silently
  // streaming into a failed stream.
  if (!output_file.is_open()) {
    return false;
  }
  output_file << ModelToIfElse(num_iteration);
  output_file.close();
  // close() sets failbit on error, so the stream state reflects both the
  // writes and the final flush. static_cast replaces the C-style cast.
  return static_cast<bool>(output_file);
}
std::string GBDT::SaveModelToString(int num_iteration) const {
std::stringstream ss;
......
......@@ -144,15 +144,31 @@ public:
/*!
* \brief Dump model to json format string
* \param num_iteration Number of iterations that want to dump, -1 means dump all
* \return Json format string of model
*/
std::string DumpModel(int num_iteration) const override;
/*!
* \brief Translate model to if-else statement
* \param num_iteration Number of iterations that want to translate, -1 means translate all
* \return if-else format codes of model
*/
std::string ModelToIfElse(int num_iteration) const override;
/*!
* \brief Translate model to if-else statement
* \param num_iteration Number of iterations that want to translate, -1 means translate all
* \param filename Filename that want to save to
* \return true if the model was saved to the file successfully
*/
bool SaveModelToIfElse(int num_iteration, const char* filename) const override;
/*!
* \brief Save model to file
* \param num_used_model Number of model that want to save, -1 means save all
* \param is_finish Is training finished or not
* \param filename Filename that want to save to
* \return is_finish Is training finished or not
*/
virtual bool SaveModelToFile(int num_iterations, const char* filename) const override;
......
......@@ -35,6 +35,7 @@ std::unordered_map<std::string, std::string> ConfigBase::Str2Map(const char* par
void OverallConfig::Set(const std::unordered_map<std::string, std::string>& params) {
// load main config types
GetInt(params, "num_threads", &num_threads);
GetString(params, "convert_model_language", &convert_model_language);
// generate seeds by seed.
if (GetInt(params, "seed", &seed)) {
......@@ -129,6 +130,8 @@ void OverallConfig::GetTaskType(const std::unordered_map<std::string, std::strin
} else if (value == std::string("predict") || value == std::string("prediction")
|| value == std::string("test")) {
task_type = TaskType::kPredict;
} else if (value == std::string("convert_model")) {
task_type = TaskType::kConvertModel;
} else {
Log::Fatal("Unknown task type %s", value.c_str());
}
......@@ -210,6 +213,7 @@ void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) {
GetInt(params, "snapshot_freq", &snapshot_freq);
GetString(params, "output_model", &output_model);
GetString(params, "input_model", &input_model);
GetString(params, "convert_model", &convert_model);
GetString(params, "output_result", &output_result);
std::string tmp_str = "";
if (GetString(params, "valid_data", &tmp_str)) {
......
......@@ -368,6 +368,54 @@ std::string Tree::NodeToJSON(int index) {
return str_buf.str();
}
/*!
* \brief Serialize this tree as a standalone C++ prediction function.
* \param index Tree number; becomes part of the generated function name
* \param is_predict_leaf_index true to emit the leaf-index variant (adds a
*        "Leaf" suffix to the name and returns leaf indices, not values)
* \return Generated function definition as a string
*/
std::string Tree::ToIfElse(int index, bool is_predict_leaf_index) {
  std::stringstream str_buf;
  str_buf << "double PredictTree" << index;
  if (is_predict_leaf_index) {
    str_buf << "Leaf";
  }
  str_buf << "(const double* arr) { ";
  if (num_leaves_ == 1) {
    // Degenerate tree with a single leaf: constant output.
    // BUG FIX: the statement needs a terminating semicolon — the previous
    // "return 0" produced "{ return 0 }", which does not compile.
    // NOTE(review): emitting 0 assumes a single-leaf tree always scores 0;
    // confirm leaf_value_[0] is not expected here instead.
    str_buf << "return 0;";
  } else {
    // Recursively emit the nested if-else body starting at the root.
    str_buf << NodeToIfElse(0, is_predict_leaf_index);
  }
  str_buf << " }" << std::endl;
  return str_buf.str();
}
/*!
* \brief Recursively serialize one node of the tree as if-else code.
* \param index Node index; non-negative for internal nodes, the bitwise
*        complement (~leaf) for leaves
* \param is_predict_leaf_index true to return leaf indices instead of values
* \return Generated code fragment for this subtree
*/
std::string Tree::NodeToIfElse(int index, bool is_predict_leaf_index) {
  std::stringstream out;
  // Enough digits to round-trip every double threshold/leaf value exactly.
  out << std::setprecision(std::numeric_limits<double>::digits10 + 2);
  if (index < 0) {
    // Leaf: emit the return statement and stop recursing.
    out << "return ";
    if (is_predict_leaf_index) {
      out << ~index;
    } else {
      out << leaf_value_[~index];
    }
    out << ";";
    return out.str();
  }
  // Internal node: "<=" for numerical splits (decision type 0), "==" for categorical.
  const char* comparison = (decision_type_[index] == 0) ? "<" : "=";
  out << "if ( arr[" << split_feature_[index] << "] " << comparison << "= "
      << threshold_[index] << " ) { "
      << NodeToIfElse(left_child_[index], is_predict_leaf_index)
      << " } else { "
      << NodeToIfElse(right_child_[index], is_predict_leaf_index)
      << " }";
  return out.str();
}
Tree::Tree(const std::string& str) {
std::vector<std::string> lines = Common::Split(str.c_str(), '\n');
std::unordered_map<std::string, std::string> key_vals;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment