Commit fc167922 authored by Guolin Ke's avatar Guolin Ke
Browse files

support use objective name to create metric.

parent 169a2712
...@@ -90,6 +90,13 @@ void GetMetricType(const std::unordered_map<std::string, std::string>& params, s ...@@ -90,6 +90,13 @@ void GetMetricType(const std::unordered_map<std::string, std::string>& params, s
} }
metric_types->shrink_to_fit(); metric_types->shrink_to_fit();
} }
// add names of objective function if not providing metric
if (metric_types->empty()) {
if (ConfigBase::GetString(params, "objective", &value)) {
std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
metric_types->push_back(value);
}
}
} }
void GetTaskType(const std::unordered_map<std::string, std::string>& params, TaskType* task_type) { void GetTaskType(const std::unordered_map<std::string, std::string>& params, TaskType* task_type) {
......
...@@ -9,11 +9,11 @@ ...@@ -9,11 +9,11 @@
namespace LightGBM { namespace LightGBM {
Metric* Metric::CreateMetric(const std::string& type, const MetricConfig& config) { Metric* Metric::CreateMetric(const std::string& type, const MetricConfig& config) {
if (type == std::string("l2") || type == std::string("mean_squared_error") || type == std::string("mse")) { if (type == std::string("regression") || type == std::string("regression_l2") || type == std::string("l2") || type == std::string("mean_squared_error") || type == std::string("mse")) {
return new L2Metric(config); return new L2Metric(config);
} else if (type == std::string("l2_root") || type == std::string("root_mean_squared_error") || type == std::string("rmse")) { } else if (type == std::string("l2_root") || type == std::string("root_mean_squared_error") || type == std::string("rmse")) {
return new RMSEMetric(config); return new RMSEMetric(config);
} else if (type == std::string("l1") || type == std::string("mean_absolute_error") || type == std::string("mae")) { } else if (type == std::string("regression_l1") || type == std::string("l1") || type == std::string("mean_absolute_error") || type == std::string("mae")) {
return new L1Metric(config); return new L1Metric(config);
} else if (type == std::string("quantile")) { } else if (type == std::string("quantile")) {
return new QuantileMetric(config); return new QuantileMetric(config);
...@@ -23,7 +23,7 @@ Metric* Metric::CreateMetric(const std::string& type, const MetricConfig& config ...@@ -23,7 +23,7 @@ Metric* Metric::CreateMetric(const std::string& type, const MetricConfig& config
return new FairLossMetric(config); return new FairLossMetric(config);
} else if (type == std::string("poisson")) { } else if (type == std::string("poisson")) {
return new PoissonMetric(config); return new PoissonMetric(config);
} else if (type == std::string("binary_logloss")) { } else if (type == std::string("binary_logloss") || type == std::string("binary")) {
return new BinaryLoglossMetric(config); return new BinaryLoglossMetric(config);
} else if (type == std::string("binary_error")) { } else if (type == std::string("binary_error")) {
return new BinaryErrorMetric(config); return new BinaryErrorMetric(config);
...@@ -33,7 +33,7 @@ Metric* Metric::CreateMetric(const std::string& type, const MetricConfig& config ...@@ -33,7 +33,7 @@ Metric* Metric::CreateMetric(const std::string& type, const MetricConfig& config
return new NDCGMetric(config); return new NDCGMetric(config);
} else if (type == std::string("map") || type == std::string("mean_average_precision")) { } else if (type == std::string("map") || type == std::string("mean_average_precision")) {
return new MapMetric(config); return new MapMetric(config);
} else if (type == std::string("multi_logloss")) { } else if (type == std::string("multi_logloss") || type == std::string("multiclass") || type == std::string("multiclass_ova")) {
return new MultiSoftmaxLoglossMetric(config); return new MultiSoftmaxLoglossMetric(config);
} else if (type == std::string("multi_error")) { } else if (type == std::string("multi_error")) {
return new MultiErrorMetric(config); return new MultiErrorMetric(config);
......
...@@ -9,7 +9,8 @@ namespace LightGBM { ...@@ -9,7 +9,8 @@ namespace LightGBM {
ObjectiveFunction* ObjectiveFunction::CreateObjectiveFunction(const std::string& type, const ObjectiveConfig& config) { ObjectiveFunction* ObjectiveFunction::CreateObjectiveFunction(const std::string& type, const ObjectiveConfig& config) {
if (type == std::string("regression") || type == std::string("regression_l2") if (type == std::string("regression") || type == std::string("regression_l2")
|| type == std::string("mean_squared_error") || type == std::string("mse")) { || type == std::string("mean_squared_error") || type == std::string("mse")
|| type == std::string("l2_root") || type == std::string("root_mean_squared_error") || type == std::string("rmse")) {
return new RegressionL2loss(config); return new RegressionL2loss(config);
} else if (type == std::string("regression_l1") || type == std::string("mean_absolute_error") || type == std::string("mae")) { } else if (type == std::string("regression_l1") || type == std::string("mean_absolute_error") || type == std::string("mae")) {
return new RegressionL1loss(config); return new RegressionL1loss(config);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment