Commit cf8c8c0b authored by OMOTO Tsukasa's avatar OMOTO Tsukasa Committed by Guolin Ke
Browse files

Fix detecting alias of objective in loading from model strings (#2565)

Fix #2564
parent 3e3d76f6
......@@ -964,6 +964,59 @@ struct ParameterAlias {
}
};
inline std::string ParseObjectiveAlias(const std::string& type) {
if (type == std::string("regression") || type == std::string("regression_l2")
|| type == std::string("mean_squared_error") || type == std::string("mse") || type == std::string("l2")
|| type == std::string("l2_root") || type == std::string("root_mean_squared_error") || type == std::string("rmse")) {
return "regression";
} else if (type == std::string("regression_l1") || type == std::string("mean_absolute_error")
|| type == std::string("l1") || type == std::string("mae")) {
return "regression_l1";
} else if (type == std::string("multiclass") || type == std::string("softmax")) {
return "multiclass";
} else if (type == std::string("multiclassova") || type == std::string("multiclass_ova") || type == std::string("ova") || type == std::string("ovr")) {
return "multiclassova";
} else if (type == std::string("xentropy") || type == std::string("cross_entropy")) {
return "cross_entropy";
} else if (type == std::string("xentlambda") || type == std::string("cross_entropy_lambda")) {
return "cross_entropy_lambda";
} else if (type == std::string("mean_absolute_percentage_error") || type == std::string("mape")) {
return "mape";
} else if (type == std::string("none") || type == std::string("null") || type == std::string("custom") || type == std::string("na")) {
return "custom";
}
return type;
}
inline std::string ParseMetricAlias(const std::string& type) {
if (type == std::string("regression") || type == std::string("regression_l2") || type == std::string("l2") || type == std::string("mean_squared_error") || type == std::string("mse")) {
return "l2";
} else if (type == std::string("l2_root") || type == std::string("root_mean_squared_error") || type == std::string("rmse")) {
return "rmse";
} else if (type == std::string("regression_l1") || type == std::string("l1") || type == std::string("mean_absolute_error") || type == std::string("mae")) {
return "l1";
} else if (type == std::string("binary_logloss") || type == std::string("binary")) {
return "binary_logloss";
} else if (type == std::string("ndcg") || type == std::string("lambdarank")) {
return "ndcg";
} else if (type == std::string("map") || type == std::string("mean_average_precision")) {
return "map";
} else if (type == std::string("multi_logloss") || type == std::string("multiclass") || type == std::string("softmax") || type == std::string("multiclassova") || type == std::string("multiclass_ova") || type == std::string("ova") || type == std::string("ovr")) {
return "multi_logloss";
} else if (type == std::string("xentropy") || type == std::string("cross_entropy")) {
return "cross_entropy";
} else if (type == std::string("xentlambda") || type == std::string("cross_entropy_lambda")) {
return "cross_entropy_lambda";
} else if (type == std::string("kldiv") || type == std::string("kullback_leibler")) {
return "kullback_leibler";
} else if (type == std::string("mean_absolute_percentage_error") || type == std::string("mape")) {
return "mape";
} else if (type == std::string("none") || type == std::string("null") || type == std::string("custom") || type == std::string("na")) {
return "custom";
}
return type;
}
} // namespace LightGBM
#endif // LightGBM_CONFIG_H_
#endif // LightGBM_CONFIG_H_
\ No newline at end of file
......@@ -2,6 +2,7 @@
* Copyright (c) 2017 Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See LICENSE file in the project root for license information.
*/
#include <LightGBM/config.h>
#include <LightGBM/metric.h>
#include <LightGBM/objective_function.h>
#include <LightGBM/utils/common.h>
......@@ -455,9 +456,10 @@ bool GBDT::LoadModelFromString(const char* buffer, size_t len) {
if (key_vals.count("objective")) {
auto str = key_vals["objective"];
loaded_objective_.reset(ObjectiveFunction::CreateObjectiveFunction(str));
loaded_objective_.reset(ObjectiveFunction::CreateObjectiveFunction(ParseObjectiveAlias(str)));
objective_function_ = loaded_objective_.get();
}
if (!key_vals.count("tree_sizes")) {
while (p < end) {
auto line_len = Common::GetLine(p);
......
......@@ -66,59 +66,6 @@ void GetBoostingType(const std::unordered_map<std::string, std::string>& params,
}
}
std::string ParseObjectiveAlias(const std::string& type) {
if (type == std::string("regression") || type == std::string("regression_l2")
|| type == std::string("mean_squared_error") || type == std::string("mse") || type == std::string("l2")
|| type == std::string("l2_root") || type == std::string("root_mean_squared_error") || type == std::string("rmse")) {
return "regression";
} else if (type == std::string("regression_l1") || type == std::string("mean_absolute_error")
|| type == std::string("l1") || type == std::string("mae")) {
return "regression_l1";
} else if (type == std::string("multiclass") || type == std::string("softmax")) {
return "multiclass";
} else if (type == std::string("multiclassova") || type == std::string("multiclass_ova") || type == std::string("ova") || type == std::string("ovr")) {
return "multiclassova";
} else if (type == std::string("xentropy") || type == std::string("cross_entropy")) {
return "cross_entropy";
} else if (type == std::string("xentlambda") || type == std::string("cross_entropy_lambda")) {
return "cross_entropy_lambda";
} else if (type == std::string("mean_absolute_percentage_error") || type == std::string("mape")) {
return "mape";
} else if (type == std::string("none") || type == std::string("null") || type == std::string("custom") || type == std::string("na")) {
return "custom";
}
return type;
}
std::string ParseMetricAlias(const std::string& type) {
if (type == std::string("regression") || type == std::string("regression_l2") || type == std::string("l2") || type == std::string("mean_squared_error") || type == std::string("mse")) {
return "l2";
} else if (type == std::string("l2_root") || type == std::string("root_mean_squared_error") || type == std::string("rmse")) {
return "rmse";
} else if (type == std::string("regression_l1") || type == std::string("l1") || type == std::string("mean_absolute_error") || type == std::string("mae")) {
return "l1";
} else if (type == std::string("binary_logloss") || type == std::string("binary")) {
return "binary_logloss";
} else if (type == std::string("ndcg") || type == std::string("lambdarank")) {
return "ndcg";
} else if (type == std::string("map") || type == std::string("mean_average_precision")) {
return "map";
} else if (type == std::string("multi_logloss") || type == std::string("multiclass") || type == std::string("softmax") || type == std::string("multiclassova") || type == std::string("multiclass_ova") || type == std::string("ova") || type == std::string("ovr")) {
return "multi_logloss";
} else if (type == std::string("xentropy") || type == std::string("cross_entropy")) {
return "cross_entropy";
} else if (type == std::string("xentlambda") || type == std::string("cross_entropy_lambda")) {
return "cross_entropy_lambda";
} else if (type == std::string("kldiv") || type == std::string("kullback_leibler")) {
return "kullback_leibler";
} else if (type == std::string("mean_absolute_percentage_error") || type == std::string("mape")) {
return "mape";
} else if (type == std::string("none") || type == std::string("null") || type == std::string("custom") || type == std::string("na")) {
return "custom";
}
return type;
}
void ParseMetrics(const std::string& value, std::vector<std::string>* out_metric) {
std::unordered_set<std::string> metric_sets;
out_metric->clear();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment