Unverified Commit 6e0b0a8b authored by James Lamb's avatar James Lamb Committed by GitHub
Browse files

[python-package] simplify scikit-learn 1.6+ tags support (#6735)

parent ea04c66c
...@@ -14,14 +14,6 @@ try: ...@@ -14,14 +14,6 @@ try:
from sklearn.utils.multiclass import check_classification_targets from sklearn.utils.multiclass import check_classification_targets
from sklearn.utils.validation import assert_all_finite, check_array, check_X_y from sklearn.utils.validation import assert_all_finite, check_array, check_X_y
# sklearn.utils Tags types can be imported unconditionally once
# lightgbm's minimum scikit-learn version is 1.6 or higher
try:
from sklearn.utils import ClassifierTags as _sklearn_ClassifierTags
from sklearn.utils import RegressorTags as _sklearn_RegressorTags
except ImportError:
_sklearn_ClassifierTags = None
_sklearn_RegressorTags = None
try: try:
from sklearn.exceptions import NotFittedError from sklearn.exceptions import NotFittedError
from sklearn.model_selection import BaseCrossValidator, GroupKFold, StratifiedKFold from sklearn.model_selection import BaseCrossValidator, GroupKFold, StratifiedKFold
...@@ -148,8 +140,6 @@ except ImportError: ...@@ -148,8 +140,6 @@ except ImportError:
_LGBMCheckClassificationTargets = None _LGBMCheckClassificationTargets = None
_LGBMComputeSampleWeight = None _LGBMComputeSampleWeight = None
_LGBMValidateData = None _LGBMValidateData = None
_sklearn_ClassifierTags = None
_sklearn_RegressorTags = None
_sklearn_version = None _sklearn_version = None
# additional scikit-learn imports only for type hints # additional scikit-learn imports only for type hints
......
...@@ -40,8 +40,6 @@ from .compat import ( ...@@ -40,8 +40,6 @@ from .compat import (
_LGBMModelBase, _LGBMModelBase,
_LGBMRegressorBase, _LGBMRegressorBase,
_LGBMValidateData, _LGBMValidateData,
_sklearn_ClassifierTags,
_sklearn_RegressorTags,
_sklearn_version, _sklearn_version,
dt_DataTable, dt_DataTable,
pd_DataFrame, pd_DataFrame,
...@@ -726,7 +724,7 @@ class LGBMModel(_LGBMModelBase): ...@@ -726,7 +724,7 @@ class LGBMModel(_LGBMModelBase):
# take whatever tags are provided by BaseEstimator, then modify # take whatever tags are provided by BaseEstimator, then modify
# them with LightGBM-specific values # them with LightGBM-specific values
return self._update_sklearn_tags_from_dict( return self._update_sklearn_tags_from_dict(
tags=_LGBMModelBase.__sklearn_tags__(self), tags=super().__sklearn_tags__(),
tags_dict=self._more_tags(), tags_dict=self._more_tags(),
) )
...@@ -1298,10 +1296,7 @@ class LGBMRegressor(_LGBMRegressorBase, LGBMModel): ...@@ -1298,10 +1296,7 @@ class LGBMRegressor(_LGBMRegressorBase, LGBMModel):
return tags return tags
def __sklearn_tags__(self) -> "_sklearn_Tags": def __sklearn_tags__(self) -> "_sklearn_Tags":
tags = LGBMModel.__sklearn_tags__(self) return super().__sklearn_tags__()
tags.estimator_type = "regressor"
tags.regressor_tags = _sklearn_RegressorTags(multi_label=False)
return tags
def fit( # type: ignore[override] def fit( # type: ignore[override]
self, self,
...@@ -1360,9 +1355,9 @@ class LGBMClassifier(_LGBMClassifierBase, LGBMModel): ...@@ -1360,9 +1355,9 @@ class LGBMClassifier(_LGBMClassifierBase, LGBMModel):
return tags return tags
def __sklearn_tags__(self) -> "_sklearn_Tags": def __sklearn_tags__(self) -> "_sklearn_Tags":
tags = LGBMModel.__sklearn_tags__(self) tags = super().__sklearn_tags__()
tags.estimator_type = "classifier" tags.classifier_tags.multi_class = True
tags.classifier_tags = _sklearn_ClassifierTags(multi_class=True, multi_label=False) tags.classifier_tags.multi_label = False
return tags return tags
def fit( # type: ignore[override] def fit( # type: ignore[override]
......
...@@ -1488,6 +1488,12 @@ def test_sklearn_tags_should_correctly_reflect_lightgbm_specific_values(estimato ...@@ -1488,6 +1488,12 @@ def test_sklearn_tags_should_correctly_reflect_lightgbm_specific_values(estimato
assert sklearn_tags.input_tags.allow_nan is True assert sklearn_tags.input_tags.allow_nan is True
assert sklearn_tags.input_tags.sparse is True assert sklearn_tags.input_tags.sparse is True
assert sklearn_tags.target_tags.one_d_labels is True assert sklearn_tags.target_tags.one_d_labels is True
if estimator_class is lgb.LGBMClassifier:
assert sklearn_tags.estimator_type == "classifier"
assert sklearn_tags.classifier_tags.multi_class is True
assert sklearn_tags.classifier_tags.multi_label is False
elif estimator_class is lgb.LGBMRegressor:
assert sklearn_tags.estimator_type == "regressor"
@pytest.mark.parametrize("task", all_tasks) @pytest.mark.parametrize("task", all_tasks)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment