"vscode:/vscode.git/clone" did not exist on "b33a12ea3883f306388e69f12ceb421b1ee7ec29"
Commit fc167922 authored by Guolin Ke's avatar Guolin Ke
Browse files

support use objective name to create metric.

parent 169a2712
......@@ -90,6 +90,13 @@ void GetMetricType(const std::unordered_map<std::string, std::string>& params, s
}
metric_types->shrink_to_fit();
}
// add names of objective function if not providing metric
if (metric_types->empty()) {
if (ConfigBase::GetString(params, "objective", &value)) {
std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
metric_types->push_back(value);
}
}
}
void GetTaskType(const std::unordered_map<std::string, std::string>& params, TaskType* task_type) {
......
......@@ -9,11 +9,11 @@
namespace LightGBM {
Metric* Metric::CreateMetric(const std::string& type, const MetricConfig& config) {
if (type == std::string("l2") || type == std::string("mean_squared_error") || type == std::string("mse")) {
if (type == std::string("regression") || type == std::string("regression_l2") || type == std::string("l2") || type == std::string("mean_squared_error") || type == std::string("mse")) {
return new L2Metric(config);
} else if (type == std::string("l2_root") || type == std::string("root_mean_squared_error") || type == std::string("rmse")) {
return new RMSEMetric(config);
} else if (type == std::string("l1") || type == std::string("mean_absolute_error") || type == std::string("mae")) {
} else if (type == std::string("regression_l1") || type == std::string("l1") || type == std::string("mean_absolute_error") || type == std::string("mae")) {
return new L1Metric(config);
} else if (type == std::string("quantile")) {
return new QuantileMetric(config);
......@@ -23,7 +23,7 @@ Metric* Metric::CreateMetric(const std::string& type, const MetricConfig& config
return new FairLossMetric(config);
} else if (type == std::string("poisson")) {
return new PoissonMetric(config);
} else if (type == std::string("binary_logloss")) {
} else if (type == std::string("binary_logloss") || type == std::string("binary")) {
return new BinaryLoglossMetric(config);
} else if (type == std::string("binary_error")) {
return new BinaryErrorMetric(config);
......@@ -33,7 +33,7 @@ Metric* Metric::CreateMetric(const std::string& type, const MetricConfig& config
return new NDCGMetric(config);
} else if (type == std::string("map") || type == std::string("mean_average_precision")) {
return new MapMetric(config);
} else if (type == std::string("multi_logloss")) {
} else if (type == std::string("multi_logloss") || type == std::string("multiclass") || type == std::string("multiclass_ova")) {
return new MultiSoftmaxLoglossMetric(config);
} else if (type == std::string("multi_error")) {
return new MultiErrorMetric(config);
......
......@@ -9,7 +9,8 @@ namespace LightGBM {
ObjectiveFunction* ObjectiveFunction::CreateObjectiveFunction(const std::string& type, const ObjectiveConfig& config) {
if (type == std::string("regression") || type == std::string("regression_l2")
|| type == std::string("mean_squared_error") || type == std::string("mse")) {
|| type == std::string("mean_squared_error") || type == std::string("mse")
|| type == std::string("l2_root") || type == std::string("root_mean_squared_error") || type == std::string("rmse")) {
return new RegressionL2loss(config);
} else if (type == std::string("regression_l1") || type == std::string("mean_absolute_error") || type == std::string("mae")) {
return new RegressionL1loss(config);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment