Unverified Commit 4887b3b0 authored by James Lamb's avatar James Lamb Committed by GitHub
Browse files

[python-package] fix mypy errors around eval metrics in sklearn.py (#5719)

parent bacb33d1
......@@ -1059,21 +1059,28 @@ class LGBMClassifier(_LGBMClassifierBase, LGBMModel):
self._classes = self._le.classes_
self._n_classes = len(self._classes)
# adjust eval metrics to match whether binary or multiclass
# classification is being performed
if not callable(eval_metric):
if isinstance(eval_metric, (str, type(None))):
eval_metric = [eval_metric]
if isinstance(eval_metric, list):
eval_metric_list = eval_metric
elif isinstance(eval_metric, str):
eval_metric_list = [eval_metric]
else:
eval_metric_list = []
if self._n_classes > 2:
for index, metric in enumerate(eval_metric):
for index, metric in enumerate(eval_metric_list):
if metric in {'logloss', 'binary_logloss'}:
eval_metric[index] = "multi_logloss"
eval_metric_list[index] = "multi_logloss"
elif metric in {'error', 'binary_error'}:
eval_metric[index] = "multi_error"
eval_metric_list[index] = "multi_error"
else:
for index, metric in enumerate(eval_metric):
for index, metric in enumerate(eval_metric_list):
if metric in {'logloss', 'multi_logloss'}:
eval_metric[index] = 'binary_logloss'
eval_metric_list[index] = 'binary_logloss'
elif metric in {'error', 'multi_error'}:
eval_metric[index] = 'binary_error'
eval_metric_list[index] = 'binary_error'
eval_metric = eval_metric_list
# do not modify args, as it causes errors in model selection tools
valid_sets = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment