"include/LightGBM/vscode:/vscode.git/clone" did not exist on "08be1e97bb7755944f5d5e1de67920c01503c76d"
Commit 9c57793e authored by wxchan's avatar wxchan Committed by Guolin Ke
Browse files

refine set params (#933)

parent e66a8a3c
......@@ -32,42 +32,37 @@ std::unordered_map<std::string, std::string> ConfigBase::Str2Map(const char* par
return params;
}
std::string GetBoostingType(const std::unordered_map<std::string, std::string>& params) {
void GetBoostingType(const std::unordered_map<std::string, std::string>& params, std::string* boosting_type) {
std::string value;
std::string boosting_type = kDefaultBoostingType;
if (ConfigBase::GetString(params, "boosting_type", &value)) {
std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
if (value == std::string("gbdt") || value == std::string("gbrt")) {
boosting_type = "gbdt";
*boosting_type = "gbdt";
} else if (value == std::string("dart")) {
boosting_type = "dart";
*boosting_type = "dart";
} else if (value == std::string("goss")) {
boosting_type = "goss";
*boosting_type = "goss";
} else if (value == std::string("rf") || value == std::string("randomforest")) {
boosting_type = "rf";
*boosting_type = "rf";
} else {
Log::Fatal("Unknown boosting type %s", value.c_str());
}
}
return boosting_type;
}
std::string GetObjectiveType(const std::unordered_map<std::string, std::string>& params) {
void GetObjectiveType(const std::unordered_map<std::string, std::string>& params, std::string* objective_type) {
std::string value;
std::string objective_type = kDefaultObjectiveType;
if (ConfigBase::GetString(params, "objective", &value)) {
std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
objective_type = value;
*objective_type = value;
}
return objective_type;
}
std::vector<std::string> GetMetricType(const std::unordered_map<std::string, std::string>& params) {
void GetMetricType(const std::unordered_map<std::string, std::string>& params, std::vector<std::string>* metric_types) {
std::string value;
std::vector<std::string> metric_types;
if (ConfigBase::GetString(params, "metric", &value)) {
// clear old metrics
metric_types.clear();
metric_types->clear();
// to lower
std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
// split
......@@ -81,66 +76,59 @@ std::vector<std::string> GetMetricType(const std::unordered_map<std::string, std
}
}
for (auto& metric : metric_sets) {
metric_types.push_back(metric);
metric_types->push_back(metric);
}
metric_types.shrink_to_fit();
metric_types->shrink_to_fit();
}
return metric_types;
}
TaskType GetTaskType(const std::unordered_map<std::string, std::string>& params) {
void GetTaskType(const std::unordered_map<std::string, std::string>& params, TaskType* task_type) {
std::string value;
TaskType task_type = TaskType::kTrain;
if (ConfigBase::GetString(params, "task", &value)) {
std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
if (value == std::string("train") || value == std::string("training")) {
task_type = TaskType::kTrain;
*task_type = TaskType::kTrain;
} else if (value == std::string("predict") || value == std::string("prediction")
|| value == std::string("test")) {
task_type = TaskType::kPredict;
*task_type = TaskType::kPredict;
} else if (value == std::string("convert_model")) {
task_type = TaskType::kConvertModel;
*task_type = TaskType::kConvertModel;
} else {
Log::Fatal("Unknown task type %s", value.c_str());
}
}
return task_type;
}
std::string GetDeviceType(const std::unordered_map<std::string, std::string>& params) {
void GetDeviceType(const std::unordered_map<std::string, std::string>& params, std::string* device_type) {
std::string value;
std::string device_type = kDefaultDevice;
if (ConfigBase::GetString(params, "device", &value)) {
std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
if (value == std::string("cpu")) {
device_type = "cpu";
*device_type = "cpu";
} else if (value == std::string("gpu")) {
device_type = "gpu";
*device_type = "gpu";
} else {
Log::Fatal("Unknown device type %s", value.c_str());
}
}
return device_type;
}
std::string GetTreeLearnerType(const std::unordered_map<std::string, std::string>& params) {
void GetTreeLearnerType(const std::unordered_map<std::string, std::string>& params, std::string* tree_learner_type) {
std::string value;
std::string tree_learner_type = kDefaultTreeLearnerType;
if (ConfigBase::GetString(params, "tree_learner", &value)) {
std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
if (value == std::string("serial")) {
tree_learner_type = "serial";
*tree_learner_type = "serial";
} else if (value == std::string("feature") || value == std::string("feature_parallel")) {
tree_learner_type = "feature";
*tree_learner_type = "feature";
} else if (value == std::string("data") || value == std::string("data_parallel")) {
tree_learner_type = "data";
*tree_learner_type = "data";
} else if (value == std::string("voting") || value == std::string("voting_parallel")) {
tree_learner_type = "voting";
*tree_learner_type = "voting";
} else {
Log::Fatal("Unknown tree learner type %s", value.c_str());
}
}
return tree_learner_type;
}
void OverallConfig::Set(const std::unordered_map<std::string, std::string>& params) {
......@@ -157,17 +145,17 @@ void OverallConfig::Set(const std::unordered_map<std::string, std::string>& para
boosting_config.drop_seed = static_cast<int>(rand.NextShort(0, int_max));
boosting_config.tree_config.feature_fraction_seed = static_cast<int>(rand.NextShort(0, int_max));
}
task_type = GetTaskType(params);
boosting_type = GetBoostingType(params);
GetTaskType(params, &task_type);
GetBoostingType(params, &boosting_type);
metric_types = GetMetricType(params);
GetMetricType(params, &metric_types);
// sub-config setup
network_config.Set(params);
io_config.Set(params);
boosting_config.Set(params);
objective_type = GetObjectiveType(params);
GetObjectiveType(params, &objective_type);
objective_config.Set(params);
metric_config.Set(params);
......@@ -298,7 +286,7 @@ void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) {
GetDouble(params, "pred_early_stop_margin", &pred_early_stop_margin);
GetBool(params, "use_missing", &use_missing);
GetBool(params, "zero_as_missing", &zero_as_missing);
device_type = GetDeviceType(params);
GetDeviceType(params, &device_type);
}
void ObjectiveConfig::Set(const std::unordered_map<std::string, std::string>& params) {
......@@ -413,8 +401,8 @@ void BoostingConfig::Set(const std::unordered_map<std::string, std::string>& par
GetBool(params, "boost_from_average", &boost_from_average);
CHECK(drop_rate <= 1.0 && drop_rate >= 0.0);
CHECK(skip_drop <= 1.0 && skip_drop >= 0.0);
device_type = GetDeviceType(params);
tree_learner_type = GetTreeLearnerType(params);
GetDeviceType(params, &device_type);
GetTreeLearnerType(params, &tree_learner_type);
tree_config.Set(params);
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment