Unverified Commit 7b6f80f3 authored by Guolin Ke's avatar Guolin Ke Committed by GitHub
Browse files

fix custom metric for multiclass (#1505)

* fix custom metric for multiclass

* fix alias

* fix bug

* fix indent
parent 4df7b21d
...@@ -731,7 +731,7 @@ Metric Parameters ...@@ -731,7 +731,7 @@ Metric Parameters
- ``""`` (empty string or not specified) means that metric corresponding to specified ``objective`` will be used (this is possible only for pre-defined objective functions, otherwise no evaluation metric will be added) - ``""`` (empty string or not specified) means that metric corresponding to specified ``objective`` will be used (this is possible only for pre-defined objective functions, otherwise no evaluation metric will be added)
- ``"None"`` (string, **not** a ``None`` value) means that no metric will be registered, aliases: ``na`` - ``"None"`` (string, **not** a ``None`` value) means that no metric will be registered, aliases: ``na``, ``null``, ``custom``
- ``l1``, absolute loss, aliases: ``mean_absolute_error``, ``mae``, ``regression_l1`` - ``l1``, absolute loss, aliases: ``mean_absolute_error``, ``mae``, ``regression_l1``
......
...@@ -663,7 +663,7 @@ public: ...@@ -663,7 +663,7 @@ public:
// type = multi-enum // type = multi-enum
// desc = metric(s) to be evaluated on the evaluation sets **in addition** to what is provided in the training arguments // desc = metric(s) to be evaluated on the evaluation sets **in addition** to what is provided in the training arguments
// descl2 = ``""`` (empty string or not specified) means that metric corresponding to specified ``objective`` will be used (this is possible only for pre-defined objective functions, otherwise no evaluation metric will be added) // descl2 = ``""`` (empty string or not specified) means that metric corresponding to specified ``objective`` will be used (this is possible only for pre-defined objective functions, otherwise no evaluation metric will be added)
// descl2 = ``"None"`` (string, **not** a ``None`` value) means that no metric will be registered, aliases: ``na`` // descl2 = ``"None"`` (string, **not** a ``None`` value) means that no metric will be registered, aliases: ``na``, ``null``, ``custom``
// descl2 = ``l1``, absolute loss, aliases: ``mean_absolute_error``, ``mae``, ``regression_l1`` // descl2 = ``l1``, absolute loss, aliases: ``mean_absolute_error``, ``mae``, ``regression_l1``
// descl2 = ``l2``, square loss, aliases: ``mean_squared_error``, ``mse``, ``regression_l2``, ``regression`` // descl2 = ``l2``, square loss, aliases: ``mean_squared_error``, ``mse``, ``regression_l2``, ``regression``
// descl2 = ``l2_root``, root square loss, aliases: ``root_mean_squared_error``, ``rmse`` // descl2 = ``l2_root``, root square loss, aliases: ``root_mean_squared_error``, ``rmse``
......
...@@ -202,7 +202,8 @@ bool CheckMultiClassObjective(const std::string& objective) { ...@@ -202,7 +202,8 @@ bool CheckMultiClassObjective(const std::string& objective) {
void Config::CheckParamConflict() { void Config::CheckParamConflict() {
// check if objective, metric, and num_class match // check if objective, metric, and num_class match
int num_class_check = num_class; int num_class_check = num_class;
bool objective_custom = objective == std::string("none") || objective == std::string("null") || objective == std::string("custom"); bool objective_custom = objective == std::string("none") || objective == std::string("null")
|| objective == std::string("custom") || objective == std::string("na");
bool objective_type_multiclass = CheckMultiClassObjective(objective) || (objective_custom && num_class_check > 1); bool objective_type_multiclass = CheckMultiClassObjective(objective) || (objective_custom && num_class_check > 1);
if (objective_type_multiclass) { if (objective_type_multiclass) {
...@@ -215,12 +216,15 @@ void Config::CheckParamConflict() { ...@@ -215,12 +216,15 @@ void Config::CheckParamConflict() {
} }
} }
for (std::string metric_type : metric) { for (std::string metric_type : metric) {
bool metric_custom_or_none = metric_type == std::string("none") || metric_type == std::string("null")
|| metric_type == std::string("custom") || metric_type == std::string("na");
bool metric_type_multiclass = (CheckMultiClassObjective(metric_type) bool metric_type_multiclass = (CheckMultiClassObjective(metric_type)
|| metric_type == std::string("multi_logloss") || metric_type == std::string("multi_logloss")
|| metric_type == std::string("multi_error")); || metric_type == std::string("multi_error")
|| (metric_custom_or_none && num_class_check > 1));
if ((objective_type_multiclass && !metric_type_multiclass) if ((objective_type_multiclass && !metric_type_multiclass)
|| (!objective_type_multiclass && metric_type_multiclass)) { || (!objective_type_multiclass && metric_type_multiclass)) {
Log::Fatal("Multiclass qbjective and metrics don't match"); Log::Fatal("Multiclass objective and metrics don't match");
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment