Commit a6f47d00 authored by Guolin Ke's avatar Guolin Ke
Browse files

use std::string for tree_learner_type.

parent 9b2558d6
...@@ -187,12 +187,6 @@ public: ...@@ -187,12 +187,6 @@ public:
void Set(const std::unordered_map<std::string, std::string>& params) override; void Set(const std::unordered_map<std::string, std::string>& params) override;
}; };
/*! \brief Types of tree learning algorithms */
enum TreeLearnerType {
kSerialTreeLearner, kFeatureParallelTreelearner,
kDataParallelTreeLearner, KVotingParallelTreeLearner
};
/*! \brief Config for Boosting */ /*! \brief Config for Boosting */
struct BoostingConfig: public ConfigBase { struct BoostingConfig: public ConfigBase {
public: public:
...@@ -213,7 +207,7 @@ public: ...@@ -213,7 +207,7 @@ public:
bool xgboost_dart_mode = false; bool xgboost_dart_mode = false;
bool uniform_drop = false; bool uniform_drop = false;
int drop_seed = 4; int drop_seed = 4;
TreeLearnerType tree_learner_type = TreeLearnerType::kSerialTreeLearner; std::string tree_learner_type = "serial";
TreeConfig tree_config; TreeConfig tree_config;
void Set(const std::unordered_map<std::string, std::string>& params) override; void Set(const std::unordered_map<std::string, std::string>& params) override;
private: private:
......
...@@ -66,7 +66,7 @@ public: ...@@ -66,7 +66,7 @@ public:
* \param type Type of tree learner * \param type Type of tree learner
* \param tree_config config of tree * \param tree_config config of tree
*/ */
static TreeLearner* CreateTreeLearner(TreeLearnerType type, static TreeLearner* CreateTreeLearner(const std::string& type,
const TreeConfig* tree_config); const TreeConfig* tree_config);
}; };
......
...@@ -159,20 +159,22 @@ void OverallConfig::CheckParamConflict() { ...@@ -159,20 +159,22 @@ void OverallConfig::CheckParamConflict() {
is_parallel = true; is_parallel = true;
} else { } else {
is_parallel = false; is_parallel = false;
boosting_config.tree_learner_type = TreeLearnerType::kSerialTreeLearner; boosting_config.tree_learner_type = "serial";
} }
if (boosting_config.tree_learner_type == TreeLearnerType::kSerialTreeLearner) { if (boosting_config.tree_learner_type == std::string("serial")) {
is_parallel = false; is_parallel = false;
network_config.num_machines = 1; network_config.num_machines = 1;
} }
if (boosting_config.tree_learner_type == TreeLearnerType::kSerialTreeLearner || if (boosting_config.tree_learner_type == std::string("serial")
boosting_config.tree_learner_type == TreeLearnerType::kFeatureParallelTreelearner) { || boosting_config.tree_learner_type == std::string("feature")) {
is_parallel_find_bin = false; is_parallel_find_bin = false;
} else if (boosting_config.tree_learner_type == TreeLearnerType::kDataParallelTreeLearner) { } else if (boosting_config.tree_learner_type == std::string("data")
|| boosting_config.tree_learner_type == std::string("voting")) {
is_parallel_find_bin = true; is_parallel_find_bin = true;
if (boosting_config.tree_config.histogram_pool_size >= 0) { if (boosting_config.tree_config.histogram_pool_size >= 0
&& boosting_config.tree_learner_type == std::string("data")) {
Log::Warning("Histogram LRU queue was enabled (histogram_pool_size=%f). Will disable this to reduce communication costs" Log::Warning("Histogram LRU queue was enabled (histogram_pool_size=%f). Will disable this to reduce communication costs"
, boosting_config.tree_config.histogram_pool_size); , boosting_config.tree_config.histogram_pool_size);
// Change pool size to -1 (not limit) when using data parallel to reduce communication costs // Change pool size to -1 (not limit) when using data parallel to reduce communication costs
...@@ -326,13 +328,13 @@ void BoostingConfig::GetTreeLearnerType(const std::unordered_map<std::string, st ...@@ -326,13 +328,13 @@ void BoostingConfig::GetTreeLearnerType(const std::unordered_map<std::string, st
if (GetString(params, "tree_learner", &value)) { if (GetString(params, "tree_learner", &value)) {
std::transform(value.begin(), value.end(), value.begin(), Common::tolower); std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
if (value == std::string("serial")) { if (value == std::string("serial")) {
tree_learner_type = TreeLearnerType::kSerialTreeLearner; tree_learner_type = "serial";
} else if (value == std::string("feature") || value == std::string("feature_parallel")) { } else if (value == std::string("feature") || value == std::string("feature_parallel")) {
tree_learner_type = TreeLearnerType::kFeatureParallelTreelearner; tree_learner_type = "feature";
} else if (value == std::string("data") || value == std::string("data_parallel")) { } else if (value == std::string("data") || value == std::string("data_parallel")) {
tree_learner_type = TreeLearnerType::kDataParallelTreeLearner; tree_learner_type = "data";
} else if (value == std::string("voting") || value == std::string("voting_parallel")) { } else if (value == std::string("voting") || value == std::string("voting_parallel")) {
tree_learner_type = TreeLearnerType::KVotingParallelTreeLearner; tree_learner_type = "voting";
} else { } else {
Log::Fatal("Unknown tree learner type %s", value.c_str()); Log::Fatal("Unknown tree learner type %s", value.c_str());
} }
......
...@@ -5,14 +5,14 @@ ...@@ -5,14 +5,14 @@
namespace LightGBM { namespace LightGBM {
TreeLearner* TreeLearner::CreateTreeLearner(TreeLearnerType type, const TreeConfig* tree_config) { TreeLearner* TreeLearner::CreateTreeLearner(const std::string& type, const TreeConfig* tree_config) {
if (type == TreeLearnerType::kSerialTreeLearner) { if (type == std::string("serial")) {
return new SerialTreeLearner(tree_config); return new SerialTreeLearner(tree_config);
} else if (type == TreeLearnerType::kFeatureParallelTreelearner) { } else if (type == std::string("feature")) {
return new FeatureParallelTreeLearner(tree_config); return new FeatureParallelTreeLearner(tree_config);
} else if (type == TreeLearnerType::kDataParallelTreeLearner) { } else if (type == std::string("data")) {
return new DataParallelTreeLearner(tree_config); return new DataParallelTreeLearner(tree_config);
} else if (type == TreeLearnerType::KVotingParallelTreeLearner) { } else if (type == std::string("voting")) {
return new VotingParallelTreeLearner(tree_config); return new VotingParallelTreeLearner(tree_config);
} }
return nullptr; return nullptr;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment