"googlemock/vscode:/vscode.git/clone" did not exist on "fa87209829d47f3fd7c6faf4e99d7d8799e3563d"
Unverified Commit f8ec57b8 authored by RektPunk's avatar RektPunk Committed by GitHub
Browse files

[python-package] Correctly recognize LGBMClassifier(num_class=2,...

[python-package] Correctly recognize LGBMClassifier(num_class=2, objective="multiclass") as multiclass classification (#6524)
parent 3d026629
...@@ -157,6 +157,8 @@ _LGBM_SetFieldType = Union[ ...@@ -157,6 +157,8 @@ _LGBM_SetFieldType = Union[
ZERO_THRESHOLD = 1e-35 ZERO_THRESHOLD = 1e-35
_MULTICLASS_OBJECTIVES = {"multiclass", "multiclassova", "multiclass_ova", "ova", "ovr", "softmax"}
def _is_zero(x: float) -> bool: def _is_zero(x: float) -> bool:
return -ZERO_THRESHOLD <= x <= ZERO_THRESHOLD return -ZERO_THRESHOLD <= x <= ZERO_THRESHOLD
......
...@@ -10,6 +10,7 @@ import numpy as np ...@@ -10,6 +10,7 @@ import numpy as np
import scipy.sparse import scipy.sparse
from .basic import ( from .basic import (
_MULTICLASS_OBJECTIVES,
Booster, Booster,
Dataset, Dataset,
LightGBMError, LightGBMError,
...@@ -467,7 +468,7 @@ def _extract_evaluation_meta_data( ...@@ -467,7 +468,7 @@ def _extract_evaluation_meta_data(
# It's possible, for example, to pass 3 eval sets through `eval_set`, # It's possible, for example, to pass 3 eval sets through `eval_set`,
# but only 1 init_score through `eval_init_score`. # but only 1 init_score through `eval_init_score`.
# #
# This if-else accounts for that possiblity. # This if-else accounts for that possibility.
if len(collection) > i: if len(collection) > i:
return collection[i] return collection[i]
else: else:
...@@ -1011,7 +1012,7 @@ class LGBMModel(_LGBMModelBase): ...@@ -1011,7 +1012,7 @@ class LGBMModel(_LGBMModelBase):
f"match the input. Model n_features_ is {self._n_features} and " f"match the input. Model n_features_ is {self._n_features} and "
f"input n_features is {n_features}" f"input n_features is {n_features}"
) )
# retrive original params that possibly can be used in both training and prediction # retrieve original params that possibly can be used in both training and prediction
# and then overwrite them (considering aliases) with params that were passed directly in prediction # and then overwrite them (considering aliases) with params that were passed directly in prediction
predict_params = self._process_params(stage="predict") predict_params = self._process_params(stage="predict")
for alias in _ConfigAliases.get_by_alias( for alias in _ConfigAliases.get_by_alias(
...@@ -1251,7 +1252,7 @@ class LGBMClassifier(_LGBMClassifierBase, LGBMModel): ...@@ -1251,7 +1252,7 @@ class LGBMClassifier(_LGBMClassifierBase, LGBMModel):
eval_metric_list = [eval_metric] eval_metric_list = [eval_metric]
else: else:
eval_metric_list = [] eval_metric_list = []
if self._n_classes > 2: if self.__is_multiclass:
for index, metric in enumerate(eval_metric_list): for index, metric in enumerate(eval_metric_list):
if metric in {"logloss", "binary_logloss"}: if metric in {"logloss", "binary_logloss"}:
eval_metric_list[index] = "multi_logloss" eval_metric_list[index] = "multi_logloss"
...@@ -1361,7 +1362,7 @@ class LGBMClassifier(_LGBMClassifierBase, LGBMModel): ...@@ -1361,7 +1362,7 @@ class LGBMClassifier(_LGBMClassifierBase, LGBMModel):
"Returning raw scores instead." "Returning raw scores instead."
) )
return result return result
elif self._n_classes > 2 or raw_score or pred_leaf or pred_contrib: # type: ignore [operator] elif self.__is_multiclass or raw_score or pred_leaf or pred_contrib: # type: ignore [operator]
return result return result
else: else:
return np.vstack((1.0 - result, result)).transpose() return np.vstack((1.0 - result, result)).transpose()
...@@ -1389,6 +1390,11 @@ class LGBMClassifier(_LGBMClassifierBase, LGBMModel): ...@@ -1389,6 +1390,11 @@ class LGBMClassifier(_LGBMClassifierBase, LGBMModel):
raise LGBMNotFittedError("No classes found. Need to call fit beforehand.") raise LGBMNotFittedError("No classes found. Need to call fit beforehand.")
return self._n_classes return self._n_classes
@property
def __is_multiclass(self) -> bool:
""":obj:`bool`: Indicator of whether the classifier is used for multiclass."""
return self._n_classes > 2 or (isinstance(self._objective, str) and self._objective in _MULTICLASS_OBJECTIVES)
class LGBMRanker(LGBMModel): class LGBMRanker(LGBMModel):
"""LightGBM ranker. """LightGBM ranker.
......
...@@ -719,6 +719,25 @@ def test_predict(): ...@@ -719,6 +719,25 @@ def test_predict():
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
np.testing.assert_allclose(res_engine, res_sklearn_params) np.testing.assert_allclose(res_engine, res_sklearn_params)
# Test multiclass binary classification
num_samples = 100
num_classes = 2
X_train = np.linspace(start=0, stop=10, num=num_samples * 3).reshape(num_samples, 3)
y_train = np.concatenate([np.zeros(int(num_samples / 2 - 10)), np.ones(int(num_samples / 2 + 10))])
gbm = lgb.train({"objective": "multiclass", "num_class": num_classes, "verbose": -1}, lgb.Dataset(X_train, y_train))
clf = lgb.LGBMClassifier(objective="multiclass", num_classes=num_classes).fit(X_train, y_train)
res_engine = gbm.predict(X_train)
res_sklearn = clf.predict_proba(X_train)
assert res_engine.shape == (num_samples, num_classes)
assert res_sklearn.shape == (num_samples, num_classes)
np.testing.assert_allclose(res_engine, res_sklearn)
res_class_sklearn = clf.predict(X_train)
np.testing.assert_allclose(res_class_sklearn, y_train)
def test_predict_with_params_from_init(): def test_predict_with_params_from_init():
X, y = load_iris(return_X_y=True) X, y = load_iris(return_X_y=True)
...@@ -1035,6 +1054,20 @@ def test_metrics(): ...@@ -1035,6 +1054,20 @@ def test_metrics():
assert len(gbm.evals_result_["training"]) == 1 assert len(gbm.evals_result_["training"]) == 1
assert "binary_logloss" in gbm.evals_result_["training"] assert "binary_logloss" in gbm.evals_result_["training"]
# the evaluation metric changes to multiclass metric even num classes is 2 for multiclass objective
gbm = lgb.LGBMClassifier(objective="multiclass", num_classes=2, **params).fit(
eval_metric="binary_logloss", **params_fit
)
assert len(gbm._evals_result["training"]) == 1
assert "multi_logloss" in gbm.evals_result_["training"]
# the evaluation metric changes to multiclass metric even num classes is 2 for ovr objective
gbm = lgb.LGBMClassifier(objective="ovr", num_classes=2, **params).fit(eval_metric="binary_error", **params_fit)
assert gbm.objective_ == "ovr"
assert len(gbm.evals_result_["training"]) == 2
assert "multi_logloss" in gbm.evals_result_["training"]
assert "multi_error" in gbm.evals_result_["training"]
def test_multiple_eval_metrics(): def test_multiple_eval_metrics():
X, y = load_breast_cancer(return_X_y=True) X, y = load_breast_cancer(return_X_y=True)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment