Commit c05cfa89 authored by Guolin Ke's avatar Guolin Ke
Browse files

Clean up the config code.

parent 43d50370
......@@ -59,7 +59,7 @@ if [[ ${TASK} == "gpu" ]]; then
export PATH="$AMDAPPSDK/include/:$PATH"
export BOOST_ROOT="$HOME/miniconda/"
LGB_VER=$(head -n 1 VERSION.txt)
sed -i 's/std::string device_type = "cpu";/std::string device_type = "gpu";/' ../include/LightGBM/config.h
sed -i 's/const std::string kDefaultDevice = "cpu";/const std::string kDefaultDevice = "gpu";/' ../include/LightGBM/config.h
cd $TRAVIS_BUILD_DIR/python-package && python setup.py sdist || exit -1
cd $TRAVIS_BUILD_DIR/python-package/dist && pip install lightgbm-$LGB_VER.tar.gz -v --install-option=--gpu || exit -1
cd $TRAVIS_BUILD_DIR && pytest tests/python_package_test || exit -1
......@@ -73,7 +73,7 @@ if [[ ${TASK} == "mpi" ]]; then
cmake -DUSE_MPI=ON ..
elif [[ ${TASK} == "gpu" ]]; then
cmake -DUSE_GPU=ON -DBOOST_ROOT="$HOME/miniconda/" -DOpenCL_INCLUDE_DIR=$AMDAPPSDK/include/ ..
sed -i 's/std::string device_type = "cpu";/std::string device_type = "gpu";/' ../include/LightGBM/config.h
sed -i 's/const std::string kDefaultDevice = "cpu";/const std::string kDefaultDevice = "gpu";/' ../include/LightGBM/config.h
else
cmake ..
fi
......
......@@ -16,6 +16,11 @@
namespace LightGBM {
/*! \brief Default value for the tree_learner_type config field */
const std::string kDefaultTreeLearnerType = "serial";
/*!
* \brief Default value for the device_type config fields
* Note: the CI script for the GPU task rewrites this exact line with sed
*       (replacing "cpu" with "gpu"), so keep its text in sync with that script.
*/
const std::string kDefaultDevice = "cpu";
/*! \brief Default value for the boosting_type config field */
const std::string kDefaultBoostingType = "gbdt";
/*! \brief Default value for the objective_type config field */
const std::string kDefaultObjectiveType = "regression";
/*!
* \brief The interface for Config
*/
......@@ -38,7 +43,7 @@ public:
* \param out Value will assign to out if key exists
* \return True if key exists
*/
inline bool GetString(
inline static bool GetString(
const std::unordered_map<std::string, std::string>& params,
const std::string& name, std::string* out);
......@@ -49,7 +54,7 @@ public:
* \param out Value will assign to out if key exists
* \return True if key exists
*/
inline bool GetInt(
inline static bool GetInt(
const std::unordered_map<std::string, std::string>& params,
const std::string& name, int* out);
......@@ -60,7 +65,7 @@ public:
* \param out Value will assign to out if key exists
* \return True if key exists
*/
inline bool GetDouble(
inline static bool GetDouble(
const std::unordered_map<std::string, std::string>& params,
const std::string& name, double* out);
......@@ -71,7 +76,7 @@ public:
* \param out Value will assign to out if key exists
* \return True if key exists
*/
inline bool GetBool(
inline static bool GetBool(
const std::unordered_map<std::string, std::string>& params,
const std::string& name, bool* out);
......@@ -135,7 +140,7 @@ public:
* And add an prefix "name:" while using column name
* Note: when using Index, it doesn't count the label index */
std::string categorical_column = "";
std::string device_type = "cpu";
std::string device_type = kDefaultDevice;
/*! \brief Set to true if want to use early stop for the prediction */
bool pred_early_stop = false;
......@@ -145,9 +150,6 @@ public:
double pred_early_stop_margin = 10.0f;
LIGHTGBM_EXPORT void Set(const std::unordered_map<std::string, std::string>& params) override;
private:
void GetDeviceType(const std::unordered_map<std::string,
std::string>& params);
};
/*! \brief Config for objective function */
......@@ -246,15 +248,10 @@ public:
double other_rate = 0.1f;
// only used for the regression. Will boost from the average labels.
bool boost_from_average = true;
std::string tree_learner_type = "serial";
std::string device_type = "cpu";
std::string tree_learner_type = kDefaultTreeLearnerType;
std::string device_type = kDefaultDevice;
TreeConfig tree_config;
LIGHTGBM_EXPORT void Set(const std::unordered_map<std::string, std::string>& params) override;
private:
void GetTreeLearnerType(const std::unordered_map<std::string,
std::string>& params);
void GetDeviceType(const std::unordered_map<std::string,
std::string>& params);
};
/*! \brief Config for Network */
......@@ -278,25 +275,16 @@ public:
bool is_parallel = false;
bool is_parallel_find_bin = false;
IOConfig io_config;
std::string boosting_type = "gbdt";
std::string boosting_type = kDefaultBoostingType;
BoostingConfig boosting_config;
std::string objective_type = "regression";
std::string objective_type = kDefaultObjectiveType;
ObjectiveConfig objective_config;
std::vector<std::string> metric_types;
MetricConfig metric_config;
std::string convert_model_language = "";
LIGHTGBM_EXPORT void Set(const std::unordered_map<std::string, std::string>& params) override;
private:
void GetBoostingType(const std::unordered_map<std::string, std::string>& params);
void GetObjectiveType(const std::unordered_map<std::string, std::string>& params);
void GetMetricType(const std::unordered_map<std::string, std::string>& params);
void GetTaskType(const std::unordered_map<std::string, std::string>& params);
void CheckParamConflict();
};
......
......@@ -32,49 +32,10 @@ std::unordered_map<std::string, std::string> ConfigBase::Str2Map(const char* par
return params;
}
/*!
* \brief Fill the overall config and all of its sub-configs from a parameter map
* \param params Key/value parameter map
*/
void OverallConfig::Set(const std::unordered_map<std::string, std::string>& params) {
// load main config types
GetInt(params, "num_threads", &num_threads);
GetString(params, "convert_model_language", &convert_model_language);
// generate seeds by seed.
// Only done when "seed" is present in params; the four component seeds are
// drawn in a fixed order from one Random instance.
if (GetInt(params, "seed", &seed)) {
Random rand(seed);
// upper bound handed to NextShort is the maximum value of short
int int_max = std::numeric_limits<short>::max();
io_config.data_random_seed = static_cast<int>(rand.NextShort(0, int_max));
boosting_config.bagging_seed = static_cast<int>(rand.NextShort(0, int_max));
boosting_config.drop_seed = static_cast<int>(rand.NextShort(0, int_max));
boosting_config.tree_config.feature_fraction_seed = static_cast<int>(rand.NextShort(0, int_max));
}
// member helpers: each one reads its own key from params
GetTaskType(params);
GetBoostingType(params);
GetObjectiveType(params);
GetMetricType(params);
// sub-config setup
network_config.Set(params);
io_config.Set(params);
boosting_config.Set(params);
objective_config.Set(params);
metric_config.Set(params);
// check for conflicts
CheckParamConflict();
// map verbosity to a log level: 1 -> Info, 0 -> Warning, >= 2 -> Debug,
// any other value -> Fatal
if (io_config.verbosity == 1) {
LightGBM::Log::ResetLogLevel(LightGBM::LogLevel::Info);
} else if (io_config.verbosity == 0) {
LightGBM::Log::ResetLogLevel(LightGBM::LogLevel::Warning);
} else if (io_config.verbosity >= 2) {
LightGBM::Log::ResetLogLevel(LightGBM::LogLevel::Debug);
} else {
LightGBM::Log::ResetLogLevel(LightGBM::LogLevel::Fatal);
}
}
void OverallConfig::GetBoostingType(const std::unordered_map<std::string, std::string>& params) {
std::string GetBoostingType(const std::unordered_map<std::string, std::string>& params) {
std::string value;
if (GetString(params, "boosting_type", &value)) {
std::string boosting_type = kDefaultBoostingType;
if (ConfigBase::GetString(params, "boosting_type", &value)) {
std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
if (value == std::string("gbdt") || value == std::string("gbrt")) {
boosting_type = "gbdt";
......@@ -86,19 +47,23 @@ void OverallConfig::GetBoostingType(const std::unordered_map<std::string, std::s
Log::Fatal("Unknown boosting type %s", value.c_str());
}
}
return boosting_type;
}
void OverallConfig::GetObjectiveType(const std::unordered_map<std::string, std::string>& params) {
std::string GetObjectiveType(const std::unordered_map<std::string, std::string>& params) {
std::string value;
if (GetString(params, "objective", &value)) {
std::string objective_type = kDefaultObjectiveType;
if (ConfigBase::GetString(params, "objective", &value)) {
std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
objective_type = value;
}
return objective_type;
}
void OverallConfig::GetMetricType(const std::unordered_map<std::string, std::string>& params) {
std::vector<std::string> GetMetricType(const std::unordered_map<std::string, std::string>& params) {
std::string value;
if (GetString(params, "metric", &value)) {
std::vector<std::string> metric_types;
if (ConfigBase::GetString(params, "metric", &value)) {
// clear old metrics
metric_types.clear();
// to lower
......@@ -118,12 +83,13 @@ void OverallConfig::GetMetricType(const std::unordered_map<std::string, std::str
}
metric_types.shrink_to_fit();
}
return metric_types;
}
void OverallConfig::GetTaskType(const std::unordered_map<std::string, std::string>& params) {
TaskType GetTaskType(const std::unordered_map<std::string, std::string>& params) {
std::string value;
if (GetString(params, "task", &value)) {
TaskType task_type = TaskType::kTrain;
if (ConfigBase::GetString(params, "task", &value)) {
std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
if (value == std::string("train") || value == std::string("training")) {
task_type = TaskType::kTrain;
......@@ -136,10 +102,88 @@ void OverallConfig::GetTaskType(const std::unordered_map<std::string, std::strin
Log::Fatal("Unknown task type %s", value.c_str());
}
}
return task_type;
}
void OverallConfig::CheckParamConflict() {
/*!
* \brief Resolve the "device" parameter to a canonical device name
* \param params Key/value parameter map
* \return "cpu" or "gpu"; kDefaultDevice when the key is absent
*/
std::string GetDeviceType(const std::unordered_map<std::string, std::string>& params) {
  std::string raw;
  if (!ConfigBase::GetString(params, "device", &raw)) {
    // key not supplied: fall back to the default
    return kDefaultDevice;
  }
  // matching is case-insensitive
  std::transform(raw.begin(), raw.end(), raw.begin(), Common::tolower);
  if (raw == "cpu") {
    return "cpu";
  }
  if (raw == "gpu") {
    return "gpu";
  }
  Log::Fatal("Unknown device type %s", raw.c_str());
  // only reached if Log::Fatal returns
  return kDefaultDevice;
}
/*!
* \brief Resolve the "tree_learner" parameter to a canonical learner name
* \param params Key/value parameter map
* \return "serial", "feature", "data" or "voting";
*         kDefaultTreeLearnerType when the key is absent
*/
std::string GetTreeLearnerType(const std::unordered_map<std::string, std::string>& params) {
  std::string raw;
  if (!ConfigBase::GetString(params, "tree_learner", &raw)) {
    // key not supplied: fall back to the default
    return kDefaultTreeLearnerType;
  }
  // matching is case-insensitive
  std::transform(raw.begin(), raw.end(), raw.begin(), Common::tolower);
  if (raw == "serial") {
    return "serial";
  }
  // the "*_parallel" spellings are aliases of the short names
  if (raw == "feature" || raw == "feature_parallel") {
    return "feature";
  }
  if (raw == "data" || raw == "data_parallel") {
    return "data";
  }
  if (raw == "voting" || raw == "voting_parallel") {
    return "voting";
  }
  Log::Fatal("Unknown tree learner type %s", raw.c_str());
  // only reached if Log::Fatal returns
  return kDefaultTreeLearnerType;
}
void OverallConfig::Set(const std::unordered_map<std::string, std::string>& params) {
// load main config types
GetInt(params, "num_threads", &num_threads);
GetString(params, "convert_model_language", &convert_model_language);
// generate seeds by seed.
if (GetInt(params, "seed", &seed)) {
Random rand(seed);
int int_max = std::numeric_limits<short>::max();
io_config.data_random_seed = static_cast<int>(rand.NextShort(0, int_max));
boosting_config.bagging_seed = static_cast<int>(rand.NextShort(0, int_max));
boosting_config.drop_seed = static_cast<int>(rand.NextShort(0, int_max));
boosting_config.tree_config.feature_fraction_seed = static_cast<int>(rand.NextShort(0, int_max));
}
task_type = GetTaskType(params);
boosting_type = GetBoostingType(params);
metric_types = GetMetricType(params);
// sub-config setup
network_config.Set(params);
io_config.Set(params);
boosting_config.Set(params);
objective_type = GetObjectiveType(params);
objective_config.Set(params);
metric_config.Set(params);
// check for conflicts
CheckParamConflict();
if (io_config.verbosity == 1) {
LightGBM::Log::ResetLogLevel(LightGBM::LogLevel::Info);
} else if (io_config.verbosity == 0) {
LightGBM::Log::ResetLogLevel(LightGBM::LogLevel::Warning);
} else if (io_config.verbosity >= 2) {
LightGBM::Log::ResetLogLevel(LightGBM::LogLevel::Debug);
} else {
LightGBM::Log::ResetLogLevel(LightGBM::LogLevel::Fatal);
}
}
void OverallConfig::CheckParamConflict() {
// check if objective_type, metric_type, and num_class match
bool objective_type_multiclass = (objective_type == std::string("multiclass")
|| objective_type == std::string("multiclassova"));
......@@ -171,13 +215,14 @@ void OverallConfig::CheckParamConflict() {
boosting_config.tree_learner_type = "serial";
}
if (boosting_config.tree_learner_type == std::string("serial")) {
bool is_single_tree_learner = boosting_config.tree_learner_type == std::string("serial");
if (is_single_tree_learner) {
is_parallel = false;
network_config.num_machines = 1;
}
if (boosting_config.tree_learner_type == std::string("serial")
|| boosting_config.tree_learner_type == std::string("feature")) {
if (is_single_tree_learner || boosting_config.tree_learner_type == std::string("feature")) {
is_parallel_find_bin = false;
} else if (boosting_config.tree_learner_type == std::string("data")
|| boosting_config.tree_learner_type == std::string("voting")) {
......@@ -189,7 +234,6 @@ void OverallConfig::CheckParamConflict() {
// Change pool size to -1 (no limit) when using data parallel to reduce communication costs
boosting_config.tree_config.histogram_pool_size = -1;
}
}
}
......@@ -235,21 +279,7 @@ void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) {
GetInt(params, "pred_early_stop_freq", &pred_early_stop_freq);
GetDouble(params, "pred_early_stop_margin", &pred_early_stop_margin);
GetDeviceType(params);
}
void IOConfig::GetDeviceType(const std::unordered_map<std::string, std::string>& params) {
std::string value;
if (GetString(params, "device", &value)) {
std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
if (value == std::string("cpu")) {
device_type = "cpu";
} else if (value == std::string("gpu")) {
device_type = "gpu";
} else {
Log::Fatal("Unknown device type %s", value.c_str());
}
}
device_type = GetDeviceType(params);
}
void ObjectiveConfig::Set(const std::unordered_map<std::string, std::string>& params) {
......@@ -336,7 +366,6 @@ void TreeConfig::Set(const std::unordered_map<std::string, std::string>& params)
GetBool(params, "use_missing", &use_missing);
}
void BoostingConfig::Set(const std::unordered_map<std::string, std::string>& params) {
GetInt(params, "num_iterations", &num_iterations);
GetDouble(params, "sigmoid", &sigmoid);
......@@ -365,42 +394,12 @@ void BoostingConfig::Set(const std::unordered_map<std::string, std::string>& par
GetBool(params, "boost_from_average", &boost_from_average);
CHECK(drop_rate <= 1.0 && drop_rate >= 0.0);
CHECK(skip_drop <= 1.0 && skip_drop >= 0.0);
GetDeviceType(params);
GetTreeLearnerType(params);
device_type = GetDeviceType(params);
tree_learner_type = GetTreeLearnerType(params);
tree_config.Set(params);
}
/*!
* \brief Read "tree_learner" from params and store its canonical name in tree_learner_type
* \param params Key/value parameter map; tree_learner_type is left untouched when the key is absent
*/
void BoostingConfig::GetTreeLearnerType(const std::unordered_map<std::string, std::string>& params) {
  std::string raw;
  if (!GetString(params, "tree_learner", &raw)) {
    return;
  }
  // matching is case-insensitive
  std::transform(raw.begin(), raw.end(), raw.begin(), Common::tolower);
  // the "*_parallel" spellings are aliases of the short names
  if (raw == "serial") {
    tree_learner_type = "serial";
  } else if (raw == "feature" || raw == "feature_parallel") {
    tree_learner_type = "feature";
  } else if (raw == "data" || raw == "data_parallel") {
    tree_learner_type = "data";
  } else if (raw == "voting" || raw == "voting_parallel") {
    tree_learner_type = "voting";
  } else {
    Log::Fatal("Unknown tree learner type %s", raw.c_str());
  }
}
/*!
* \brief Read "device" from params and store its canonical name in device_type
* \param params Key/value parameter map; device_type is left untouched when the key is absent
*/
void BoostingConfig::GetDeviceType(const std::unordered_map<std::string, std::string>& params) {
  std::string raw;
  if (!GetString(params, "device", &raw)) {
    return;
  }
  // matching is case-insensitive
  std::transform(raw.begin(), raw.end(), raw.begin(), Common::tolower);
  if (raw == "cpu" || raw == "gpu") {
    // the lower-cased input is already the canonical name
    device_type = raw;
  } else {
    Log::Fatal("Unknown device type %s", raw.c_str());
  }
}
void NetworkConfig::Set(const std::unordered_map<std::string, std::string>& params) {
GetInt(params, "num_machines", &num_machines);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment