"...git@developer.sourcefind.cn:tianlh/lightgbm-dcu.git" did not exist on "2dfb9a40478b965db8325baa21a63d9281f96b7c"
Commit 9c57793e authored by wxchan's avatar wxchan Committed by Guolin Ke
Browse files

refine set params (#933)

parent e66a8a3c
...@@ -32,42 +32,37 @@ std::unordered_map<std::string, std::string> ConfigBase::Str2Map(const char* par ...@@ -32,42 +32,37 @@ std::unordered_map<std::string, std::string> ConfigBase::Str2Map(const char* par
return params; return params;
} }
std::string GetBoostingType(const std::unordered_map<std::string, std::string>& params) { void GetBoostingType(const std::unordered_map<std::string, std::string>& params, std::string* boosting_type) {
std::string value; std::string value;
std::string boosting_type = kDefaultBoostingType;
if (ConfigBase::GetString(params, "boosting_type", &value)) { if (ConfigBase::GetString(params, "boosting_type", &value)) {
std::transform(value.begin(), value.end(), value.begin(), Common::tolower); std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
if (value == std::string("gbdt") || value == std::string("gbrt")) { if (value == std::string("gbdt") || value == std::string("gbrt")) {
boosting_type = "gbdt"; *boosting_type = "gbdt";
} else if (value == std::string("dart")) { } else if (value == std::string("dart")) {
boosting_type = "dart"; *boosting_type = "dart";
} else if (value == std::string("goss")) { } else if (value == std::string("goss")) {
boosting_type = "goss"; *boosting_type = "goss";
} else if (value == std::string("rf") || value == std::string("randomforest")) { } else if (value == std::string("rf") || value == std::string("randomforest")) {
boosting_type = "rf"; *boosting_type = "rf";
} else { } else {
Log::Fatal("Unknown boosting type %s", value.c_str()); Log::Fatal("Unknown boosting type %s", value.c_str());
} }
} }
return boosting_type;
} }
std::string GetObjectiveType(const std::unordered_map<std::string, std::string>& params) { void GetObjectiveType(const std::unordered_map<std::string, std::string>& params, std::string* objective_type) {
std::string value; std::string value;
std::string objective_type = kDefaultObjectiveType;
if (ConfigBase::GetString(params, "objective", &value)) { if (ConfigBase::GetString(params, "objective", &value)) {
std::transform(value.begin(), value.end(), value.begin(), Common::tolower); std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
objective_type = value; *objective_type = value;
} }
return objective_type;
} }
std::vector<std::string> GetMetricType(const std::unordered_map<std::string, std::string>& params) { void GetMetricType(const std::unordered_map<std::string, std::string>& params, std::vector<std::string>* metric_types) {
std::string value; std::string value;
std::vector<std::string> metric_types;
if (ConfigBase::GetString(params, "metric", &value)) { if (ConfigBase::GetString(params, "metric", &value)) {
// clear old metrics // clear old metrics
metric_types.clear(); metric_types->clear();
// to lower // to lower
std::transform(value.begin(), value.end(), value.begin(), Common::tolower); std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
// split // split
...@@ -81,66 +76,59 @@ std::vector<std::string> GetMetricType(const std::unordered_map<std::string, std ...@@ -81,66 +76,59 @@ std::vector<std::string> GetMetricType(const std::unordered_map<std::string, std
} }
} }
for (auto& metric : metric_sets) { for (auto& metric : metric_sets) {
metric_types.push_back(metric); metric_types->push_back(metric);
} }
metric_types.shrink_to_fit(); metric_types->shrink_to_fit();
} }
return metric_types;
} }
TaskType GetTaskType(const std::unordered_map<std::string, std::string>& params) { void GetTaskType(const std::unordered_map<std::string, std::string>& params, TaskType* task_type) {
std::string value; std::string value;
TaskType task_type = TaskType::kTrain;
if (ConfigBase::GetString(params, "task", &value)) { if (ConfigBase::GetString(params, "task", &value)) {
std::transform(value.begin(), value.end(), value.begin(), Common::tolower); std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
if (value == std::string("train") || value == std::string("training")) { if (value == std::string("train") || value == std::string("training")) {
task_type = TaskType::kTrain; *task_type = TaskType::kTrain;
} else if (value == std::string("predict") || value == std::string("prediction") } else if (value == std::string("predict") || value == std::string("prediction")
|| value == std::string("test")) { || value == std::string("test")) {
task_type = TaskType::kPredict; *task_type = TaskType::kPredict;
} else if (value == std::string("convert_model")) { } else if (value == std::string("convert_model")) {
task_type = TaskType::kConvertModel; *task_type = TaskType::kConvertModel;
} else { } else {
Log::Fatal("Unknown task type %s", value.c_str()); Log::Fatal("Unknown task type %s", value.c_str());
} }
} }
return task_type;
} }
std::string GetDeviceType(const std::unordered_map<std::string, std::string>& params) { void GetDeviceType(const std::unordered_map<std::string, std::string>& params, std::string* device_type) {
std::string value; std::string value;
std::string device_type = kDefaultDevice;
if (ConfigBase::GetString(params, "device", &value)) { if (ConfigBase::GetString(params, "device", &value)) {
std::transform(value.begin(), value.end(), value.begin(), Common::tolower); std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
if (value == std::string("cpu")) { if (value == std::string("cpu")) {
device_type = "cpu"; *device_type = "cpu";
} else if (value == std::string("gpu")) { } else if (value == std::string("gpu")) {
device_type = "gpu"; *device_type = "gpu";
} else { } else {
Log::Fatal("Unknown device type %s", value.c_str()); Log::Fatal("Unknown device type %s", value.c_str());
} }
} }
return device_type;
} }
std::string GetTreeLearnerType(const std::unordered_map<std::string, std::string>& params) { void GetTreeLearnerType(const std::unordered_map<std::string, std::string>& params, std::string* tree_learner_type) {
std::string value; std::string value;
std::string tree_learner_type = kDefaultTreeLearnerType;
if (ConfigBase::GetString(params, "tree_learner", &value)) { if (ConfigBase::GetString(params, "tree_learner", &value)) {
std::transform(value.begin(), value.end(), value.begin(), Common::tolower); std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
if (value == std::string("serial")) { if (value == std::string("serial")) {
tree_learner_type = "serial"; *tree_learner_type = "serial";
} else if (value == std::string("feature") || value == std::string("feature_parallel")) { } else if (value == std::string("feature") || value == std::string("feature_parallel")) {
tree_learner_type = "feature"; *tree_learner_type = "feature";
} else if (value == std::string("data") || value == std::string("data_parallel")) { } else if (value == std::string("data") || value == std::string("data_parallel")) {
tree_learner_type = "data"; *tree_learner_type = "data";
} else if (value == std::string("voting") || value == std::string("voting_parallel")) { } else if (value == std::string("voting") || value == std::string("voting_parallel")) {
tree_learner_type = "voting"; *tree_learner_type = "voting";
} else { } else {
Log::Fatal("Unknown tree learner type %s", value.c_str()); Log::Fatal("Unknown tree learner type %s", value.c_str());
} }
} }
return tree_learner_type;
} }
void OverallConfig::Set(const std::unordered_map<std::string, std::string>& params) { void OverallConfig::Set(const std::unordered_map<std::string, std::string>& params) {
...@@ -157,17 +145,17 @@ void OverallConfig::Set(const std::unordered_map<std::string, std::string>& para ...@@ -157,17 +145,17 @@ void OverallConfig::Set(const std::unordered_map<std::string, std::string>& para
boosting_config.drop_seed = static_cast<int>(rand.NextShort(0, int_max)); boosting_config.drop_seed = static_cast<int>(rand.NextShort(0, int_max));
boosting_config.tree_config.feature_fraction_seed = static_cast<int>(rand.NextShort(0, int_max)); boosting_config.tree_config.feature_fraction_seed = static_cast<int>(rand.NextShort(0, int_max));
} }
task_type = GetTaskType(params); GetTaskType(params, &task_type);
boosting_type = GetBoostingType(params); GetBoostingType(params, &boosting_type);
metric_types = GetMetricType(params); GetMetricType(params, &metric_types);
// sub-config setup // sub-config setup
network_config.Set(params); network_config.Set(params);
io_config.Set(params); io_config.Set(params);
boosting_config.Set(params); boosting_config.Set(params);
objective_type = GetObjectiveType(params); GetObjectiveType(params, &objective_type);
objective_config.Set(params); objective_config.Set(params);
metric_config.Set(params); metric_config.Set(params);
...@@ -298,7 +286,7 @@ void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) { ...@@ -298,7 +286,7 @@ void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) {
GetDouble(params, "pred_early_stop_margin", &pred_early_stop_margin); GetDouble(params, "pred_early_stop_margin", &pred_early_stop_margin);
GetBool(params, "use_missing", &use_missing); GetBool(params, "use_missing", &use_missing);
GetBool(params, "zero_as_missing", &zero_as_missing); GetBool(params, "zero_as_missing", &zero_as_missing);
device_type = GetDeviceType(params); GetDeviceType(params, &device_type);
} }
void ObjectiveConfig::Set(const std::unordered_map<std::string, std::string>& params) { void ObjectiveConfig::Set(const std::unordered_map<std::string, std::string>& params) {
...@@ -413,8 +401,8 @@ void BoostingConfig::Set(const std::unordered_map<std::string, std::string>& par ...@@ -413,8 +401,8 @@ void BoostingConfig::Set(const std::unordered_map<std::string, std::string>& par
GetBool(params, "boost_from_average", &boost_from_average); GetBool(params, "boost_from_average", &boost_from_average);
CHECK(drop_rate <= 1.0 && drop_rate >= 0.0); CHECK(drop_rate <= 1.0 && drop_rate >= 0.0);
CHECK(skip_drop <= 1.0 && skip_drop >= 0.0); CHECK(skip_drop <= 1.0 && skip_drop >= 0.0);
device_type = GetDeviceType(params); GetDeviceType(params, &device_type);
tree_learner_type = GetTreeLearnerType(params); GetTreeLearnerType(params, &tree_learner_type);
tree_config.Set(params); tree_config.Set(params);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment