Commit bbb528d7 authored by Guolin Ke's avatar Guolin Ke
Browse files

change default prediction type

parent d1e0bab5
...@@ -103,7 +103,8 @@ public: ...@@ -103,7 +103,8 @@ public:
bool is_save_binary_file = false; bool is_save_binary_file = false;
bool enable_load_from_binary_file = true; bool enable_load_from_binary_file = true;
int bin_construct_sample_cnt = 50000; int bin_construct_sample_cnt = 50000;
bool is_raw_score = true; bool is_predict_leaf_index = false;
bool is_predict_raw_score = false;
bool has_header = false; bool has_header = false;
/*! \brief Index or column name of label, default is the first column /*! \brief Index or column name of label, default is the first column
...@@ -224,7 +225,6 @@ public: ...@@ -224,7 +225,6 @@ public:
int num_threads = 0; int num_threads = 0;
bool is_parallel = false; bool is_parallel = false;
bool is_parallel_find_bin = false; bool is_parallel_find_bin = false;
bool predict_leaf_index = false;
IOConfig io_config; IOConfig io_config;
BoostingType boosting_type = BoostingType::kGBDT; BoostingType boosting_type = BoostingType::kGBDT;
BoostingConfig* boosting_config = nullptr; BoostingConfig* boosting_config = nullptr;
...@@ -365,7 +365,9 @@ struct ParameterAlias { ...@@ -365,7 +365,9 @@ struct ParameterAlias {
{ "query", "group_column" }, { "query", "group_column" },
{ "query_column", "group_column" }, { "query_column", "group_column" },
{ "ignore_feature", "ignore_column" }, { "ignore_feature", "ignore_column" },
{ "blacklist", "ignore_column" } { "blacklist", "ignore_column" },
{ "predict_raw_score", "is_predict_raw_score" },
{ "predict_leaf_index", "is_predict_leaf_index" }
}); });
std::unordered_map<std::string, std::string> tmp_map; std::unordered_map<std::string, std::string> tmp_map;
for (const auto& pair : *params) { for (const auto& pair : *params) {
......
...@@ -256,8 +256,8 @@ void Application::Train() { ...@@ -256,8 +256,8 @@ void Application::Train() {
void Application::Predict() { void Application::Predict() {
boosting_->SetNumUsedModel(config_.io_config.num_model_predict); boosting_->SetNumUsedModel(config_.io_config.num_model_predict);
// create predictor // create predictor
Predictor predictor(boosting_, config_.io_config.is_raw_score, Predictor predictor(boosting_, config_.io_config.is_predict_raw_score,
config_.predict_leaf_index); config_.io_config.is_predict_leaf_index);
predictor.Predict(config_.io_config.data_filename.c_str(), predictor.Predict(config_.io_config.data_filename.c_str(),
config_.io_config.output_result.c_str(), config_.io_config.has_header); config_.io_config.output_result.c_str(), config_.io_config.has_header);
Log::Info("Finished prediction"); Log::Info("Finished prediction");
......
...@@ -34,9 +34,6 @@ void OverallConfig::Set(const std::unordered_map<std::string, std::string>& para ...@@ -34,9 +34,6 @@ void OverallConfig::Set(const std::unordered_map<std::string, std::string>& para
// load main config types // load main config types
GetInt(params, "num_threads", &num_threads); GetInt(params, "num_threads", &num_threads);
GetTaskType(params); GetTaskType(params);
GetBool(params, "predict_leaf_index", &predict_leaf_index);
GetBoostingType(params); GetBoostingType(params);
GetObjectiveType(params); GetObjectiveType(params);
GetMetricType(params); GetMetricType(params);
...@@ -195,7 +192,8 @@ void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) { ...@@ -195,7 +192,8 @@ void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) {
GetBool(params, "use_two_round_loading", &use_two_round_loading); GetBool(params, "use_two_round_loading", &use_two_round_loading);
GetBool(params, "is_save_binary_file", &is_save_binary_file); GetBool(params, "is_save_binary_file", &is_save_binary_file);
GetBool(params, "enable_load_from_binary_file", &enable_load_from_binary_file); GetBool(params, "enable_load_from_binary_file", &enable_load_from_binary_file);
GetBool(params, "is_raw_score", &is_raw_score); GetBool(params, "is_predict_raw_score", &is_predict_raw_score);
GetBool(params, "is_predict_leaf_index", &is_predict_leaf_index);
GetString(params, "output_model", &output_model); GetString(params, "output_model", &output_model);
GetString(params, "input_model", &input_model); GetString(params, "input_model", &input_model);
GetString(params, "output_result", &output_result); GetString(params, "output_result", &output_result);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment