Commit c05cfa89 authored by Guolin Ke's avatar Guolin Ke
Browse files

clean code for config.

parent 43d50370
...@@ -59,7 +59,7 @@ if [[ ${TASK} == "gpu" ]]; then ...@@ -59,7 +59,7 @@ if [[ ${TASK} == "gpu" ]]; then
export PATH="$AMDAPPSDK/include/:$PATH" export PATH="$AMDAPPSDK/include/:$PATH"
export BOOST_ROOT="$HOME/miniconda/" export BOOST_ROOT="$HOME/miniconda/"
LGB_VER=$(head -n 1 VERSION.txt) LGB_VER=$(head -n 1 VERSION.txt)
sed -i 's/std::string device_type = "cpu";/std::string device_type = "gpu";/' ../include/LightGBM/config.h sed -i 's/const std::string kDefaultDevice = "cpu";/const std::string kDefaultDevice = "gpu";/' ../include/LightGBM/config.h
cd $TRAVIS_BUILD_DIR/python-package && python setup.py sdist || exit -1 cd $TRAVIS_BUILD_DIR/python-package && python setup.py sdist || exit -1
cd $TRAVIS_BUILD_DIR/python-package/dist && pip install lightgbm-$LGB_VER.tar.gz -v --install-option=--gpu || exit -1 cd $TRAVIS_BUILD_DIR/python-package/dist && pip install lightgbm-$LGB_VER.tar.gz -v --install-option=--gpu || exit -1
cd $TRAVIS_BUILD_DIR && pytest tests/python_package_test || exit -1 cd $TRAVIS_BUILD_DIR && pytest tests/python_package_test || exit -1
...@@ -73,7 +73,7 @@ if [[ ${TASK} == "mpi" ]]; then ...@@ -73,7 +73,7 @@ if [[ ${TASK} == "mpi" ]]; then
cmake -DUSE_MPI=ON .. cmake -DUSE_MPI=ON ..
elif [[ ${TASK} == "gpu" ]]; then elif [[ ${TASK} == "gpu" ]]; then
cmake -DUSE_GPU=ON -DBOOST_ROOT="$HOME/miniconda/" -DOpenCL_INCLUDE_DIR=$AMDAPPSDK/include/ .. cmake -DUSE_GPU=ON -DBOOST_ROOT="$HOME/miniconda/" -DOpenCL_INCLUDE_DIR=$AMDAPPSDK/include/ ..
sed -i 's/std::string device_type = "cpu";/std::string device_type = "gpu";/' ../include/LightGBM/config.h sed -i 's/const std::string kDefaultDevice = "cpu";/const std::string kDefaultDevice = "gpu";/' ../include/LightGBM/config.h
else else
cmake .. cmake ..
fi fi
......
...@@ -16,6 +16,11 @@ ...@@ -16,6 +16,11 @@
namespace LightGBM { namespace LightGBM {
const std::string kDefaultTreeLearnerType = "serial";
const std::string kDefaultDevice = "cpu";
const std::string kDefaultBoostingType = "gbdt";
const std::string kDefaultObjectiveType = "regression";
/*! /*!
* \brief The interface for Config * \brief The interface for Config
*/ */
...@@ -38,7 +43,7 @@ public: ...@@ -38,7 +43,7 @@ public:
* \param out Value will assign to out if key exists * \param out Value will assign to out if key exists
* \return True if key exists * \return True if key exists
*/ */
inline bool GetString( inline static bool GetString(
const std::unordered_map<std::string, std::string>& params, const std::unordered_map<std::string, std::string>& params,
const std::string& name, std::string* out); const std::string& name, std::string* out);
...@@ -49,7 +54,7 @@ public: ...@@ -49,7 +54,7 @@ public:
* \param out Value will assign to out if key exists * \param out Value will assign to out if key exists
* \return True if key exists * \return True if key exists
*/ */
inline bool GetInt( inline static bool GetInt(
const std::unordered_map<std::string, std::string>& params, const std::unordered_map<std::string, std::string>& params,
const std::string& name, int* out); const std::string& name, int* out);
...@@ -60,7 +65,7 @@ public: ...@@ -60,7 +65,7 @@ public:
* \param out Value will assign to out if key exists * \param out Value will assign to out if key exists
* \return True if key exists * \return True if key exists
*/ */
inline bool GetDouble( inline static bool GetDouble(
const std::unordered_map<std::string, std::string>& params, const std::unordered_map<std::string, std::string>& params,
const std::string& name, double* out); const std::string& name, double* out);
...@@ -71,7 +76,7 @@ public: ...@@ -71,7 +76,7 @@ public:
* \param out Value will assign to out if key exists * \param out Value will assign to out if key exists
* \return True if key exists * \return True if key exists
*/ */
inline bool GetBool( inline static bool GetBool(
const std::unordered_map<std::string, std::string>& params, const std::unordered_map<std::string, std::string>& params,
const std::string& name, bool* out); const std::string& name, bool* out);
...@@ -135,7 +140,7 @@ public: ...@@ -135,7 +140,7 @@ public:
* And add an prefix "name:" while using column name * And add an prefix "name:" while using column name
* Note: when using Index, it doesn't count the label index */ * Note: when using Index, it doesn't count the label index */
std::string categorical_column = ""; std::string categorical_column = "";
std::string device_type = "cpu"; std::string device_type = kDefaultDevice;
/*! \brief Set to true if want to use early stop for the prediction */ /*! \brief Set to true if want to use early stop for the prediction */
bool pred_early_stop = false; bool pred_early_stop = false;
...@@ -145,9 +150,6 @@ public: ...@@ -145,9 +150,6 @@ public:
double pred_early_stop_margin = 10.0f; double pred_early_stop_margin = 10.0f;
LIGHTGBM_EXPORT void Set(const std::unordered_map<std::string, std::string>& params) override; LIGHTGBM_EXPORT void Set(const std::unordered_map<std::string, std::string>& params) override;
private:
void GetDeviceType(const std::unordered_map<std::string,
std::string>& params);
}; };
/*! \brief Config for objective function */ /*! \brief Config for objective function */
...@@ -246,15 +248,10 @@ public: ...@@ -246,15 +248,10 @@ public:
double other_rate = 0.1f; double other_rate = 0.1f;
// only used for the regression. Will boost from the average labels. // only used for the regression. Will boost from the average labels.
bool boost_from_average = true; bool boost_from_average = true;
std::string tree_learner_type = "serial"; std::string tree_learner_type = kDefaultTreeLearnerType;
std::string device_type = "cpu"; std::string device_type = kDefaultDevice;
TreeConfig tree_config; TreeConfig tree_config;
LIGHTGBM_EXPORT void Set(const std::unordered_map<std::string, std::string>& params) override; LIGHTGBM_EXPORT void Set(const std::unordered_map<std::string, std::string>& params) override;
private:
void GetTreeLearnerType(const std::unordered_map<std::string,
std::string>& params);
void GetDeviceType(const std::unordered_map<std::string,
std::string>& params);
}; };
/*! \brief Config for Network */ /*! \brief Config for Network */
...@@ -278,25 +275,16 @@ public: ...@@ -278,25 +275,16 @@ public:
bool is_parallel = false; bool is_parallel = false;
bool is_parallel_find_bin = false; bool is_parallel_find_bin = false;
IOConfig io_config; IOConfig io_config;
std::string boosting_type = "gbdt"; std::string boosting_type = kDefaultBoostingType;
BoostingConfig boosting_config; BoostingConfig boosting_config;
std::string objective_type = "regression"; std::string objective_type = kDefaultObjectiveType;
ObjectiveConfig objective_config; ObjectiveConfig objective_config;
std::vector<std::string> metric_types; std::vector<std::string> metric_types;
MetricConfig metric_config; MetricConfig metric_config;
std::string convert_model_language = ""; std::string convert_model_language = "";
LIGHTGBM_EXPORT void Set(const std::unordered_map<std::string, std::string>& params) override; LIGHTGBM_EXPORT void Set(const std::unordered_map<std::string, std::string>& params) override;
private: private:
void GetBoostingType(const std::unordered_map<std::string, std::string>& params);
void GetObjectiveType(const std::unordered_map<std::string, std::string>& params);
void GetMetricType(const std::unordered_map<std::string, std::string>& params);
void GetTaskType(const std::unordered_map<std::string, std::string>& params);
void CheckParamConflict(); void CheckParamConflict();
}; };
......
...@@ -32,49 +32,10 @@ std::unordered_map<std::string, std::string> ConfigBase::Str2Map(const char* par ...@@ -32,49 +32,10 @@ std::unordered_map<std::string, std::string> ConfigBase::Str2Map(const char* par
return params; return params;
} }
void OverallConfig::Set(const std::unordered_map<std::string, std::string>& params) { std::string GetBoostingType(const std::unordered_map<std::string, std::string>& params) {
// load main config types
GetInt(params, "num_threads", &num_threads);
GetString(params, "convert_model_language", &convert_model_language);
// generate seeds by seed.
if (GetInt(params, "seed", &seed)) {
Random rand(seed);
int int_max = std::numeric_limits<short>::max();
io_config.data_random_seed = static_cast<int>(rand.NextShort(0, int_max));
boosting_config.bagging_seed = static_cast<int>(rand.NextShort(0, int_max));
boosting_config.drop_seed = static_cast<int>(rand.NextShort(0, int_max));
boosting_config.tree_config.feature_fraction_seed = static_cast<int>(rand.NextShort(0, int_max));
}
GetTaskType(params);
GetBoostingType(params);
GetObjectiveType(params);
GetMetricType(params);
// sub-config setup
network_config.Set(params);
io_config.Set(params);
boosting_config.Set(params);
objective_config.Set(params);
metric_config.Set(params);
// check for conflicts
CheckParamConflict();
if (io_config.verbosity == 1) {
LightGBM::Log::ResetLogLevel(LightGBM::LogLevel::Info);
} else if (io_config.verbosity == 0) {
LightGBM::Log::ResetLogLevel(LightGBM::LogLevel::Warning);
} else if (io_config.verbosity >= 2) {
LightGBM::Log::ResetLogLevel(LightGBM::LogLevel::Debug);
} else {
LightGBM::Log::ResetLogLevel(LightGBM::LogLevel::Fatal);
}
}
void OverallConfig::GetBoostingType(const std::unordered_map<std::string, std::string>& params) {
std::string value; std::string value;
if (GetString(params, "boosting_type", &value)) { std::string boosting_type = kDefaultBoostingType;
if (ConfigBase::GetString(params, "boosting_type", &value)) {
std::transform(value.begin(), value.end(), value.begin(), Common::tolower); std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
if (value == std::string("gbdt") || value == std::string("gbrt")) { if (value == std::string("gbdt") || value == std::string("gbrt")) {
boosting_type = "gbdt"; boosting_type = "gbdt";
...@@ -86,19 +47,23 @@ void OverallConfig::GetBoostingType(const std::unordered_map<std::string, std::s ...@@ -86,19 +47,23 @@ void OverallConfig::GetBoostingType(const std::unordered_map<std::string, std::s
Log::Fatal("Unknown boosting type %s", value.c_str()); Log::Fatal("Unknown boosting type %s", value.c_str());
} }
} }
return boosting_type;
} }
void OverallConfig::GetObjectiveType(const std::unordered_map<std::string, std::string>& params) { std::string GetObjectiveType(const std::unordered_map<std::string, std::string>& params) {
std::string value; std::string value;
if (GetString(params, "objective", &value)) { std::string objective_type = kDefaultObjectiveType;
if (ConfigBase::GetString(params, "objective", &value)) {
std::transform(value.begin(), value.end(), value.begin(), Common::tolower); std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
objective_type = value; objective_type = value;
} }
return objective_type;
} }
void OverallConfig::GetMetricType(const std::unordered_map<std::string, std::string>& params) { std::vector<std::string> GetMetricType(const std::unordered_map<std::string, std::string>& params) {
std::string value; std::string value;
if (GetString(params, "metric", &value)) { std::vector<std::string> metric_types;
if (ConfigBase::GetString(params, "metric", &value)) {
// clear old metrics // clear old metrics
metric_types.clear(); metric_types.clear();
// to lower // to lower
...@@ -118,17 +83,18 @@ void OverallConfig::GetMetricType(const std::unordered_map<std::string, std::str ...@@ -118,17 +83,18 @@ void OverallConfig::GetMetricType(const std::unordered_map<std::string, std::str
} }
metric_types.shrink_to_fit(); metric_types.shrink_to_fit();
} }
return metric_types;
} }
TaskType GetTaskType(const std::unordered_map<std::string, std::string>& params) {
void OverallConfig::GetTaskType(const std::unordered_map<std::string, std::string>& params) {
std::string value; std::string value;
if (GetString(params, "task", &value)) { TaskType task_type = TaskType::kTrain;
if (ConfigBase::GetString(params, "task", &value)) {
std::transform(value.begin(), value.end(), value.begin(), Common::tolower); std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
if (value == std::string("train") || value == std::string("training")) { if (value == std::string("train") || value == std::string("training")) {
task_type = TaskType::kTrain; task_type = TaskType::kTrain;
} else if (value == std::string("predict") || value == std::string("prediction") } else if (value == std::string("predict") || value == std::string("prediction")
|| value == std::string("test")) { || value == std::string("test")) {
task_type = TaskType::kPredict; task_type = TaskType::kPredict;
} else if (value == std::string("convert_model")) { } else if (value == std::string("convert_model")) {
task_type = TaskType::kConvertModel; task_type = TaskType::kConvertModel;
...@@ -136,10 +102,88 @@ void OverallConfig::GetTaskType(const std::unordered_map<std::string, std::strin ...@@ -136,10 +102,88 @@ void OverallConfig::GetTaskType(const std::unordered_map<std::string, std::strin
Log::Fatal("Unknown task type %s", value.c_str()); Log::Fatal("Unknown task type %s", value.c_str());
} }
} }
return task_type;
} }
void OverallConfig::CheckParamConflict() { std::string GetDeviceType(const std::unordered_map<std::string, std::string>& params) {
std::string value;
std::string device_type = kDefaultDevice;
if (ConfigBase::GetString(params, "device", &value)) {
std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
if (value == std::string("cpu")) {
device_type = "cpu";
} else if (value == std::string("gpu")) {
device_type = "gpu";
} else {
Log::Fatal("Unknown device type %s", value.c_str());
}
}
return device_type;
}
std::string GetTreeLearnerType(const std::unordered_map<std::string, std::string>& params) {
std::string value;
std::string tree_learner_type = kDefaultTreeLearnerType;
if (ConfigBase::GetString(params, "tree_learner", &value)) {
std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
if (value == std::string("serial")) {
tree_learner_type = "serial";
} else if (value == std::string("feature") || value == std::string("feature_parallel")) {
tree_learner_type = "feature";
} else if (value == std::string("data") || value == std::string("data_parallel")) {
tree_learner_type = "data";
} else if (value == std::string("voting") || value == std::string("voting_parallel")) {
tree_learner_type = "voting";
} else {
Log::Fatal("Unknown tree learner type %s", value.c_str());
}
}
return tree_learner_type;
}
void OverallConfig::Set(const std::unordered_map<std::string, std::string>& params) {
// load main config types
GetInt(params, "num_threads", &num_threads);
GetString(params, "convert_model_language", &convert_model_language);
// generate seeds by seed.
if (GetInt(params, "seed", &seed)) {
Random rand(seed);
int int_max = std::numeric_limits<short>::max();
io_config.data_random_seed = static_cast<int>(rand.NextShort(0, int_max));
boosting_config.bagging_seed = static_cast<int>(rand.NextShort(0, int_max));
boosting_config.drop_seed = static_cast<int>(rand.NextShort(0, int_max));
boosting_config.tree_config.feature_fraction_seed = static_cast<int>(rand.NextShort(0, int_max));
}
task_type = GetTaskType(params);
boosting_type = GetBoostingType(params);
metric_types = GetMetricType(params);
// sub-config setup
network_config.Set(params);
io_config.Set(params);
boosting_config.Set(params);
objective_type = GetObjectiveType(params);
objective_config.Set(params);
metric_config.Set(params);
// check for conflicts
CheckParamConflict();
if (io_config.verbosity == 1) {
LightGBM::Log::ResetLogLevel(LightGBM::LogLevel::Info);
} else if (io_config.verbosity == 0) {
LightGBM::Log::ResetLogLevel(LightGBM::LogLevel::Warning);
} else if (io_config.verbosity >= 2) {
LightGBM::Log::ResetLogLevel(LightGBM::LogLevel::Debug);
} else {
LightGBM::Log::ResetLogLevel(LightGBM::LogLevel::Fatal);
}
}
void OverallConfig::CheckParamConflict() {
// check if objective_type, metric_type, and num_class match // check if objective_type, metric_type, and num_class match
bool objective_type_multiclass = (objective_type == std::string("multiclass") bool objective_type_multiclass = (objective_type == std::string("multiclass")
|| objective_type == std::string("multiclassova")); || objective_type == std::string("multiclassova"));
...@@ -163,7 +207,7 @@ void OverallConfig::CheckParamConflict() { ...@@ -163,7 +207,7 @@ void OverallConfig::CheckParamConflict() {
} }
} }
} }
if (network_config.num_machines > 1) { if (network_config.num_machines > 1) {
is_parallel = true; is_parallel = true;
} else { } else {
...@@ -171,13 +215,14 @@ void OverallConfig::CheckParamConflict() { ...@@ -171,13 +215,14 @@ void OverallConfig::CheckParamConflict() {
boosting_config.tree_learner_type = "serial"; boosting_config.tree_learner_type = "serial";
} }
if (boosting_config.tree_learner_type == std::string("serial")) { bool is_single_tree_learner = boosting_config.tree_learner_type == std::string("serial");
if (is_single_tree_learner) {
is_parallel = false; is_parallel = false;
network_config.num_machines = 1; network_config.num_machines = 1;
} }
if (boosting_config.tree_learner_type == std::string("serial") if (is_single_tree_learner || boosting_config.tree_learner_type == std::string("feature")) {
|| boosting_config.tree_learner_type == std::string("feature")) {
is_parallel_find_bin = false; is_parallel_find_bin = false;
} else if (boosting_config.tree_learner_type == std::string("data") } else if (boosting_config.tree_learner_type == std::string("data")
|| boosting_config.tree_learner_type == std::string("voting")) { || boosting_config.tree_learner_type == std::string("voting")) {
...@@ -189,7 +234,6 @@ void OverallConfig::CheckParamConflict() { ...@@ -189,7 +234,6 @@ void OverallConfig::CheckParamConflict() {
// Change pool size to -1 (no limit) when using data parallel to reduce communication costs // Change pool size to -1 (no limit) when using data parallel to reduce communication costs
boosting_config.tree_config.histogram_pool_size = -1; boosting_config.tree_config.histogram_pool_size = -1;
} }
} }
} }
...@@ -235,21 +279,7 @@ void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) { ...@@ -235,21 +279,7 @@ void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) {
GetInt(params, "pred_early_stop_freq", &pred_early_stop_freq); GetInt(params, "pred_early_stop_freq", &pred_early_stop_freq);
GetDouble(params, "pred_early_stop_margin", &pred_early_stop_margin); GetDouble(params, "pred_early_stop_margin", &pred_early_stop_margin);
GetDeviceType(params); device_type = GetDeviceType(params);
}
void IOConfig::GetDeviceType(const std::unordered_map<std::string, std::string>& params) {
std::string value;
if (GetString(params, "device", &value)) {
std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
if (value == std::string("cpu")) {
device_type = "cpu";
} else if (value == std::string("gpu")) {
device_type = "gpu";
} else {
Log::Fatal("Unknown device type %s", value.c_str());
}
}
} }
void ObjectiveConfig::Set(const std::unordered_map<std::string, std::string>& params) { void ObjectiveConfig::Set(const std::unordered_map<std::string, std::string>& params) {
...@@ -336,7 +366,6 @@ void TreeConfig::Set(const std::unordered_map<std::string, std::string>& params) ...@@ -336,7 +366,6 @@ void TreeConfig::Set(const std::unordered_map<std::string, std::string>& params)
GetBool(params, "use_missing", &use_missing); GetBool(params, "use_missing", &use_missing);
} }
void BoostingConfig::Set(const std::unordered_map<std::string, std::string>& params) { void BoostingConfig::Set(const std::unordered_map<std::string, std::string>& params) {
GetInt(params, "num_iterations", &num_iterations); GetInt(params, "num_iterations", &num_iterations);
GetDouble(params, "sigmoid", &sigmoid); GetDouble(params, "sigmoid", &sigmoid);
...@@ -365,42 +394,12 @@ void BoostingConfig::Set(const std::unordered_map<std::string, std::string>& par ...@@ -365,42 +394,12 @@ void BoostingConfig::Set(const std::unordered_map<std::string, std::string>& par
GetBool(params, "boost_from_average", &boost_from_average); GetBool(params, "boost_from_average", &boost_from_average);
CHECK(drop_rate <= 1.0 && drop_rate >= 0.0); CHECK(drop_rate <= 1.0 && drop_rate >= 0.0);
CHECK(skip_drop <= 1.0 && skip_drop >= 0.0); CHECK(skip_drop <= 1.0 && skip_drop >= 0.0);
GetDeviceType(params); device_type = GetDeviceType(params);
GetTreeLearnerType(params); tree_learner_type = GetTreeLearnerType(params);
tree_config.Set(params); tree_config.Set(params);
} }
void BoostingConfig::GetTreeLearnerType(const std::unordered_map<std::string, std::string>& params) {
std::string value;
if (GetString(params, "tree_learner", &value)) {
std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
if (value == std::string("serial")) {
tree_learner_type = "serial";
} else if (value == std::string("feature") || value == std::string("feature_parallel")) {
tree_learner_type = "feature";
} else if (value == std::string("data") || value == std::string("data_parallel")) {
tree_learner_type = "data";
} else if (value == std::string("voting") || value == std::string("voting_parallel")) {
tree_learner_type = "voting";
} else {
Log::Fatal("Unknown tree learner type %s", value.c_str());
}
}
}
void BoostingConfig::GetDeviceType(const std::unordered_map<std::string, std::string>& params) {
std::string value;
if (GetString(params, "device", &value)) {
std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
if (value == std::string("cpu")) {
device_type = "cpu";
} else if (value == std::string("gpu")) {
device_type = "gpu";
} else {
Log::Fatal("Unknown device type %s", value.c_str());
}
}
}
void NetworkConfig::Set(const std::unordered_map<std::string, std::string>& params) { void NetworkConfig::Set(const std::unordered_map<std::string, std::string>& params) {
GetInt(params, "num_machines", &num_machines); GetInt(params, "num_machines", &num_machines);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment