Unverified Commit f2143cf7 authored by Guolin Ke's avatar Guolin Ke Committed by GitHub
Browse files

fix the name of custom objective function (#1234)

* fix the name of custom objective function

* fix multi-class alias

* Update GPU-Windows.rst
parent 6522f538
...@@ -578,7 +578,7 @@ And open an issue in GitHub `here`_ with that log. ...@@ -578,7 +578,7 @@ And open an issue in GitHub `here`_ with that log.
.. _Boost: http://www.boost.org/users/history/version_1_63_0.html .. _Boost: http://www.boost.org/users/history/version_1_63_0.html
.. _link: https://git-for-windows.github.io/ .. _link: https://git-scm.com/download/win
.. _CMake: https://cmake.org/download/ .. _CMake: https://cmake.org/download/
......
...@@ -192,12 +192,20 @@ void OverallConfig::Set(const std::unordered_map<std::string, std::string>& para ...@@ -192,12 +192,20 @@ void OverallConfig::Set(const std::unordered_map<std::string, std::string>& para
} }
} }
bool CheckMultiClassObjective(const std::string& objective_type) {
return (objective_type == std::string("multiclass")
|| objective_type == std::string("multiclassova")
|| objective_type == std::string("softmax")
|| objective_type == std::string("multiclass_ova")
|| objective_type == std::string("ova")
|| objective_type == std::string("ovr"));
}
void OverallConfig::CheckParamConflict() { void OverallConfig::CheckParamConflict() {
// check if objective_type, metric_type, and num_class match // check if objective_type, metric_type, and num_class match
int num_class_check = boosting_config.num_class; int num_class_check = boosting_config.num_class;
bool objective_type_multiclass = (objective_type == std::string("multiclass") bool objective_custom = objective_type == std::string("none") || objective_type == std::string("null") || objective_type == std::string("custom");
|| objective_type == std::string("multiclassova") bool objective_type_multiclass = CheckMultiClassObjective(objective_type) || (objective_custom && num_class_check > 1);
|| (objective_type == std::string("none") && num_class_check > 1));
if (objective_type_multiclass) { if (objective_type_multiclass) {
if (num_class_check <= 1) { if (num_class_check <= 1) {
...@@ -210,7 +218,8 @@ void OverallConfig::CheckParamConflict() { ...@@ -210,7 +218,8 @@ void OverallConfig::CheckParamConflict() {
} }
if (boosting_config.is_provide_training_metric || !io_config.valid_data_filenames.empty()) { if (boosting_config.is_provide_training_metric || !io_config.valid_data_filenames.empty()) {
for (std::string metric_type : metric_types) { for (std::string metric_type : metric_types) {
bool metric_type_multiclass = (metric_type == std::string("multi_logloss") bool metric_type_multiclass = (CheckMultiClassObjective(metric_type)
|| metric_type == std::string("multi_logloss")
|| metric_type == std::string("multi_error")); || metric_type == std::string("multi_error"));
if ((objective_type_multiclass && !metric_type_multiclass) if ((objective_type_multiclass && !metric_type_multiclass)
|| (!objective_type_multiclass && metric_type_multiclass)) { || (!objective_type_multiclass && metric_type_multiclass)) {
......
...@@ -40,8 +40,10 @@ ObjectiveFunction* ObjectiveFunction::CreateObjectiveFunction(const std::string& ...@@ -40,8 +40,10 @@ ObjectiveFunction* ObjectiveFunction::CreateObjectiveFunction(const std::string&
return new RegressionGammaLoss(config); return new RegressionGammaLoss(config);
} else if (type == std::string("tweedie")) { } else if (type == std::string("tweedie")) {
return new RegressionTweedieLoss(config); return new RegressionTweedieLoss(config);
} else if (type == std::string("none") || type == std::string("null") || type == std::string("custom")) {
return nullptr;
} }
return nullptr; Log::Fatal("Unknown objective type name: %s", type.c_str());
} }
ObjectiveFunction* ObjectiveFunction::CreateObjectiveFunction(const std::string& str) { ObjectiveFunction* ObjectiveFunction::CreateObjectiveFunction(const std::string& str) {
...@@ -75,8 +77,10 @@ ObjectiveFunction* ObjectiveFunction::CreateObjectiveFunction(const std::string& ...@@ -75,8 +77,10 @@ ObjectiveFunction* ObjectiveFunction::CreateObjectiveFunction(const std::string&
return new RegressionGammaLoss(strs); return new RegressionGammaLoss(strs);
} else if (type == std::string("tweedie")) { } else if (type == std::string("tweedie")) {
return new RegressionTweedieLoss(strs); return new RegressionTweedieLoss(strs);
} else if (type == std::string("none") || type == std::string("null") || type == std::string("custom")) {
return nullptr;
} }
return nullptr; Log::Fatal("Unknown objective type name: %s", type.c_str());
} }
} // namespace LightGBM } // namespace LightGBM
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment