"git@developer.sourcefind.cn:gaoqiong/pybind11.git" did not exist on "aa304c9c7d725ffb9d10af08a3b34cb372307020"
Unverified Commit 6e0b0a8b authored by James Lamb's avatar James Lamb Committed by GitHub
Browse files

[python-package] simplify scikit-learn 1.6+ tags support (#6735)

parent ea04c66c
......@@ -14,14 +14,6 @@ try:
from sklearn.utils.multiclass import check_classification_targets
from sklearn.utils.validation import assert_all_finite, check_array, check_X_y
# sklearn.utils Tags types can be imported unconditionally once
# lightgbm's minimum scikit-learn version is 1.6 or higher
try:
from sklearn.utils import ClassifierTags as _sklearn_ClassifierTags
from sklearn.utils import RegressorTags as _sklearn_RegressorTags
except ImportError:
_sklearn_ClassifierTags = None
_sklearn_RegressorTags = None
try:
from sklearn.exceptions import NotFittedError
from sklearn.model_selection import BaseCrossValidator, GroupKFold, StratifiedKFold
......@@ -148,8 +140,6 @@ except ImportError:
_LGBMCheckClassificationTargets = None
_LGBMComputeSampleWeight = None
_LGBMValidateData = None
_sklearn_ClassifierTags = None
_sklearn_RegressorTags = None
_sklearn_version = None
# additional scikit-learn imports only for type hints
......
......@@ -40,8 +40,6 @@ from .compat import (
_LGBMModelBase,
_LGBMRegressorBase,
_LGBMValidateData,
_sklearn_ClassifierTags,
_sklearn_RegressorTags,
_sklearn_version,
dt_DataTable,
pd_DataFrame,
......@@ -726,7 +724,7 @@ class LGBMModel(_LGBMModelBase):
# take whatever tags are provided by BaseEstimator, then modify
# them with LightGBM-specific values
return self._update_sklearn_tags_from_dict(
tags=_LGBMModelBase.__sklearn_tags__(self),
tags=super().__sklearn_tags__(),
tags_dict=self._more_tags(),
)
......@@ -1298,10 +1296,7 @@ class LGBMRegressor(_LGBMRegressorBase, LGBMModel):
return tags
def __sklearn_tags__(self) -> "_sklearn_Tags":
tags = LGBMModel.__sklearn_tags__(self)
tags.estimator_type = "regressor"
tags.regressor_tags = _sklearn_RegressorTags(multi_label=False)
return tags
return super().__sklearn_tags__()
def fit( # type: ignore[override]
self,
......@@ -1360,9 +1355,9 @@ class LGBMClassifier(_LGBMClassifierBase, LGBMModel):
return tags
def __sklearn_tags__(self) -> "_sklearn_Tags":
tags = LGBMModel.__sklearn_tags__(self)
tags.estimator_type = "classifier"
tags.classifier_tags = _sklearn_ClassifierTags(multi_class=True, multi_label=False)
tags = super().__sklearn_tags__()
tags.classifier_tags.multi_class = True
tags.classifier_tags.multi_label = False
return tags
def fit( # type: ignore[override]
......
......@@ -1488,6 +1488,12 @@ def test_sklearn_tags_should_correctly_reflect_lightgbm_specific_values(estimato
assert sklearn_tags.input_tags.allow_nan is True
assert sklearn_tags.input_tags.sparse is True
assert sklearn_tags.target_tags.one_d_labels is True
if estimator_class is lgb.LGBMClassifier:
assert sklearn_tags.estimator_type == "classifier"
assert sklearn_tags.classifier_tags.multi_class is True
assert sklearn_tags.classifier_tags.multi_label is False
elif estimator_class is lgb.LGBMRegressor:
assert sklearn_tags.estimator_type == "regressor"
@pytest.mark.parametrize("task", all_tasks)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment