Unverified Commit c6d90bc7 authored by James Lamb, committed by GitHub

[python-package] support sub-classing scikit-learn estimators (#6783)


Co-authored-by: Nikita Titov <nekit94-08@mail.ru>
parent 768f6423
@@ -377,3 +377,42 @@ We strongly recommend installation from the ``conda-forge`` channel and not from

For some specific examples, see `this comment <https://github.com/microsoft/LightGBM/issues/4948#issuecomment-1013766397>`__.
In addition, as of ``lightgbm==4.4.0``, the ``conda-forge`` package automatically supports CUDA-based GPU acceleration.

5. How do I subclass ``scikit-learn`` estimators?
-------------------------------------------------
For ``lightgbm <= 4.5.0``, copy all of the constructor arguments from the corresponding
``lightgbm`` class into the constructor of your custom estimator.
For later versions, just ensure that the constructor of your custom estimator calls ``super().__init__()``.
Consider the example below, which implements a regressor that truncates predictions at a caller-supplied maximum score.
This pattern works with ``lightgbm > 4.5.0``.

.. code-block:: python

    import numpy as np
    from lightgbm import LGBMRegressor
    from sklearn.datasets import make_regression

    class TruncatedRegressor(LGBMRegressor):

        def __init__(self, **kwargs):
            super().__init__(**kwargs)

        def predict(self, X, max_score: float = np.inf):
            preds = super().predict(X)
            np.clip(preds, a_min=None, a_max=max_score, out=preds)
            return preds

    X, y = make_regression(n_samples=1_000, n_features=4)

    reg_trunc = TruncatedRegressor().fit(X, y)

    preds = reg_trunc.predict(X)
    print(f"mean: {preds.mean():.2f}, max: {preds.max():.2f}")
    # mean: -6.81, max: 345.10

    preds_trunc = reg_trunc.predict(X, max_score=preds.mean())
    print(f"mean: {preds_trunc.mean():.2f}, max: {preds_trunc.max():.2f}")
    # mean: -56.50, max: -6.81
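
For comparison, with ``lightgbm <= 4.5.0`` the same subclass had to repeat the parent's constructor arguments explicitly.
Below is a minimal sketch of that older pattern, abbreviated to three arguments (a real subclass would copy all of them; ``TruncatedRegressorLegacy`` is an illustrative name, not part of the library):

.. code-block:: python

    # Sketch of the pre-4.5.0 pattern (illustrative, abbreviated):
    # every constructor argument of LGBMRegressor is repeated and forwarded.
    class TruncatedRegressorLegacy(LGBMRegressor):

        def __init__(self, boosting_type="gbdt", num_leaves=31, max_depth=-1, **kwargs):
            super().__init__(
                boosting_type=boosting_type,
                num_leaves=num_leaves,
                max_depth=max_depth,
                **kwargs,
            )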

@@ -1115,6 +1115,7 @@ class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):

    def __init__(
        self,
        *,
        boosting_type: str = "gbdt",
        num_leaves: int = 31,
        max_depth: int = -1,

@@ -1318,6 +1319,7 @@ class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel):

    def __init__(
        self,
        *,
        boosting_type: str = "gbdt",
        num_leaves: int = 31,
        max_depth: int = -1,

@@ -1485,6 +1487,7 @@ class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel):

    def __init__(
        self,
        *,
        boosting_type: str = "gbdt",
        num_leaves: int = 31,
        max_depth: int = -1,
...

@@ -488,6 +488,7 @@ class LGBMModel(_LGBMModelBase):

    def __init__(
        self,
        *,
        boosting_type: str = "gbdt",
        num_leaves: int = 31,
        max_depth: int = -1,

@@ -745,7 +746,35 @@ class LGBMModel(_LGBMModelBase):

        params : dict
            Parameter names mapped to their values.
        """
        # Based on: https://github.com/dmlc/xgboost/blob/bd92b1c9c0db3e75ec3dfa513e1435d518bb535d/python-package/xgboost/sklearn.py#L941
        # which was based on: https://stackoverflow.com/questions/59248211
        #
        # `get_params()` flows like this:
        #
        # 0. Get parameters in subclass (self.__class__) first, by using inspect.
        # 1. Get parameters in all parent classes (especially `LGBMModel`).
        # 2. Get whatever was passed via `**kwargs`.
        # 3. Merge them.
        #
        # This needs to accommodate being called recursively in the following
        # inheritance graphs (and similar for classification and ranking):
        #
        #   DaskLGBMRegressor -> LGBMRegressor -> LGBMModel -> BaseEstimator
        #   (custom subclass) -> LGBMRegressor -> LGBMModel -> BaseEstimator
        #   LGBMRegressor -> LGBMModel -> BaseEstimator
        #   (custom subclass) -> LGBMModel -> BaseEstimator
        #   LGBMModel -> BaseEstimator
        #
        params = super().get_params(deep=deep)
        cp = copy.copy(self)
        # If the immediate parent defines get_params(), use that.
        if callable(getattr(cp.__class__.__bases__[0], "get_params", None)):
            cp.__class__ = cp.__class__.__bases__[0]
        # Otherwise, skip it and assume the next class will have it.
        # This is here primarily for cases where the first class in MRO is a scikit-learn mixin.
        else:
            cp.__class__ = cp.__class__.__bases__[1]
        params.update(cp.__class__.get_params(cp, deep))
        params.update(self._other_params)
        return params
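
To make the effect of that MRO walk concrete, here is a small illustrative sketch (not part of the commit; ``WeightedRegressor`` and its ``prior_weight`` parameter are invented for demonstration, though the behavior shown matches the tests added below):

    from lightgbm import LGBMRegressor

    class WeightedRegressor(LGBMRegressor):
        """Hypothetical subclass used only to illustrate get_params()."""

        def __init__(self, *, prior_weight: float = 1.0, **kwargs):
            self.prior_weight = prior_weight  # subclass-only parameter
            super().__init__(**kwargs)        # parents handle everything else

    params = WeightedRegressor(n_estimators=10, eta=0.05).get_params()
    # the walk visits WeightedRegressor -> LGBMRegressor -> LGBMModel -> BaseEstimator,
    # then merges in **kwargs, so all of these are present:
    assert params["prior_weight"] == 1.0   # from the subclass
    assert params["n_estimators"] == 10    # explicit keyword override
    assert params["eta"] == 0.05           # alias passed through **kwargs
    assert params["learning_rate"] == 0.1  # LGBMModel default, kept alongside the alias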

@@ -1285,6 +1314,57 @@ class LGBMModel(_LGBMModelBase):

class LGBMRegressor(_LGBMRegressorBase, LGBMModel):
    """LightGBM regressor."""

    # NOTE: all args from LGBMModel.__init__() are intentionally repeated here for
    # docs, help(), and tab completion.
    def __init__(
        self,
        *,
        boosting_type: str = "gbdt",
        num_leaves: int = 31,
        max_depth: int = -1,
        learning_rate: float = 0.1,
        n_estimators: int = 100,
        subsample_for_bin: int = 200000,
        objective: Optional[Union[str, _LGBM_ScikitCustomObjectiveFunction]] = None,
        class_weight: Optional[Union[Dict, str]] = None,
        min_split_gain: float = 0.0,
        min_child_weight: float = 1e-3,
        min_child_samples: int = 20,
        subsample: float = 1.0,
        subsample_freq: int = 0,
        colsample_bytree: float = 1.0,
        reg_alpha: float = 0.0,
        reg_lambda: float = 0.0,
        random_state: Optional[Union[int, np.random.RandomState, np.random.Generator]] = None,
        n_jobs: Optional[int] = None,
        importance_type: str = "split",
        **kwargs: Any,
    ) -> None:
        super().__init__(
            boosting_type=boosting_type,
            num_leaves=num_leaves,
            max_depth=max_depth,
            learning_rate=learning_rate,
            n_estimators=n_estimators,
            subsample_for_bin=subsample_for_bin,
            objective=objective,
            class_weight=class_weight,
            min_split_gain=min_split_gain,
            min_child_weight=min_child_weight,
            min_child_samples=min_child_samples,
            subsample=subsample,
            subsample_freq=subsample_freq,
            colsample_bytree=colsample_bytree,
            reg_alpha=reg_alpha,
            reg_lambda=reg_lambda,
            random_state=random_state,
            n_jobs=n_jobs,
            importance_type=importance_type,
            **kwargs,
        )

    __init__.__doc__ = LGBMModel.__init__.__doc__

    def _more_tags(self) -> Dict[str, Any]:
        # handle the case where RegressorMixin possibly provides _more_tags()
        if callable(getattr(_LGBMRegressorBase, "_more_tags", None)):

@@ -1344,6 +1424,57 @@ class LGBMRegressor(_LGBMRegressorBase, LGBMModel):

class LGBMClassifier(_LGBMClassifierBase, LGBMModel):
    """LightGBM classifier."""

    # NOTE: all args from LGBMModel.__init__() are intentionally repeated here for
    # docs, help(), and tab completion.
    def __init__(
        self,
        *,
        boosting_type: str = "gbdt",
        num_leaves: int = 31,
        max_depth: int = -1,
        learning_rate: float = 0.1,
        n_estimators: int = 100,
        subsample_for_bin: int = 200000,
        objective: Optional[Union[str, _LGBM_ScikitCustomObjectiveFunction]] = None,
        class_weight: Optional[Union[Dict, str]] = None,
        min_split_gain: float = 0.0,
        min_child_weight: float = 1e-3,
        min_child_samples: int = 20,
        subsample: float = 1.0,
        subsample_freq: int = 0,
        colsample_bytree: float = 1.0,
        reg_alpha: float = 0.0,
        reg_lambda: float = 0.0,
        random_state: Optional[Union[int, np.random.RandomState, np.random.Generator]] = None,
        n_jobs: Optional[int] = None,
        importance_type: str = "split",
        **kwargs: Any,
    ) -> None:
        super().__init__(
            boosting_type=boosting_type,
            num_leaves=num_leaves,
            max_depth=max_depth,
            learning_rate=learning_rate,
            n_estimators=n_estimators,
            subsample_for_bin=subsample_for_bin,
            objective=objective,
            class_weight=class_weight,
            min_split_gain=min_split_gain,
            min_child_weight=min_child_weight,
            min_child_samples=min_child_samples,
            subsample=subsample,
            subsample_freq=subsample_freq,
            colsample_bytree=colsample_bytree,
            reg_alpha=reg_alpha,
            reg_lambda=reg_lambda,
            random_state=random_state,
            n_jobs=n_jobs,
            importance_type=importance_type,
            **kwargs,
        )

    __init__.__doc__ = LGBMModel.__init__.__doc__

    def _more_tags(self) -> Dict[str, Any]:
        # handle the case where ClassifierMixin possibly provides _more_tags()
        if callable(getattr(_LGBMClassifierBase, "_more_tags", None)):

@@ -1554,6 +1685,57 @@ class LGBMRanker(LGBMModel):

    Please use this class mainly for training and applying ranking models in common sklearnish way.
    """

    # NOTE: all args from LGBMModel.__init__() are intentionally repeated here for
    # docs, help(), and tab completion.
    def __init__(
        self,
        *,
        boosting_type: str = "gbdt",
        num_leaves: int = 31,
        max_depth: int = -1,
        learning_rate: float = 0.1,
        n_estimators: int = 100,
        subsample_for_bin: int = 200000,
        objective: Optional[Union[str, _LGBM_ScikitCustomObjectiveFunction]] = None,
        class_weight: Optional[Union[Dict, str]] = None,
        min_split_gain: float = 0.0,
        min_child_weight: float = 1e-3,
        min_child_samples: int = 20,
        subsample: float = 1.0,
        subsample_freq: int = 0,
        colsample_bytree: float = 1.0,
        reg_alpha: float = 0.0,
        reg_lambda: float = 0.0,
        random_state: Optional[Union[int, np.random.RandomState, np.random.Generator]] = None,
        n_jobs: Optional[int] = None,
        importance_type: str = "split",
        **kwargs: Any,
    ) -> None:
        super().__init__(
            boosting_type=boosting_type,
            num_leaves=num_leaves,
            max_depth=max_depth,
            learning_rate=learning_rate,
            n_estimators=n_estimators,
            subsample_for_bin=subsample_for_bin,
            objective=objective,
            class_weight=class_weight,
            min_split_gain=min_split_gain,
            min_child_weight=min_child_weight,
            min_child_samples=min_child_samples,
            subsample=subsample,
            subsample_freq=subsample_freq,
            colsample_bytree=colsample_bytree,
            reg_alpha=reg_alpha,
            reg_lambda=reg_lambda,
            random_state=random_state,
            n_jobs=n_jobs,
            importance_type=importance_type,
            **kwargs,
        )

    __init__.__doc__ = LGBMModel.__init__.__doc__

    def fit(  # type: ignore[override]
        self,
        X: _LGBM_ScikitMatrixLike,
...

@@ -1373,26 +1373,42 @@ def test_machines_should_be_used_if_provided(task, cluster):

@pytest.mark.parametrize(
    "dask_est,sklearn_est",
    [
        (lgb.DaskLGBMClassifier, lgb.LGBMClassifier),
        (lgb.DaskLGBMRegressor, lgb.LGBMRegressor),
        (lgb.DaskLGBMRanker, lgb.LGBMRanker),
    ],
)
def test_dask_classes_and_sklearn_equivalents_have_identical_constructors_except_client_arg(dask_est, sklearn_est):
    dask_spec = inspect.getfullargspec(dask_est)
    sklearn_spec = inspect.getfullargspec(sklearn_est)

    # should not allow for any varargs
    assert dask_spec.varargs == sklearn_spec.varargs
    assert dask_spec.varargs is None

    # the only varkw should be **kwargs,
    # for pass-through to parent classes' __init__()
    assert dask_spec.varkw == sklearn_spec.varkw
    assert dask_spec.varkw == "kwargs"

    # "client" should be the only difference, and the final keyword-only argument
    assert dask_spec.kwonlyargs == [*sklearn_spec.kwonlyargs, "client"]

    # default values for all constructor arguments should be identical
    #
    # NOTE: if LGBMClassifier / LGBMRanker / LGBMRegressor ever override
    # any of LGBMModel's constructor arguments, this will need to be updated
    assert dask_spec.kwonlydefaults == {**sklearn_spec.kwonlydefaults, "client": None}

    # only positional argument should be 'self'
    assert dask_spec.args == sklearn_spec.args
    assert dask_spec.args == ["self"]
    assert dask_spec.defaults is None

    # get_params() should be identical, except for "client"
    assert dask_est().get_params() == {**sklearn_est().get_params(), "client": None}

@pytest.mark.parametrize(
...

# coding: utf-8
import inspect
import itertools
import math
import re
@@ -22,6 +23,7 @@ from sklearn.utils.validation import check_is_fitted

import lightgbm as lgb
from lightgbm.compat import (
    DASK_INSTALLED,
    DATATABLE_INSTALLED,
    PANDAS_INSTALLED,
    _sklearn_version,

@@ -83,6 +85,30 @@ class UnpicklableCallback:

        env.model.attr_set_inside_callback = env.iteration * 10

class ExtendedLGBMClassifier(lgb.LGBMClassifier):
    """Class for testing that inheriting from LGBMClassifier works."""

    def __init__(self, *, some_other_param: str = "lgbm-classifier", **kwargs):
        self.some_other_param = some_other_param
        super().__init__(**kwargs)


class ExtendedLGBMRanker(lgb.LGBMRanker):
    """Class for testing that inheriting from LGBMRanker works."""

    def __init__(self, *, some_other_param: str = "lgbm-ranker", **kwargs):
        self.some_other_param = some_other_param
        super().__init__(**kwargs)


class ExtendedLGBMRegressor(lgb.LGBMRegressor):
    """Class for testing that inheriting from LGBMRegressor works."""

    def __init__(self, *, some_other_param: str = "lgbm-regressor", **kwargs):
        self.some_other_param = some_other_param
        super().__init__(**kwargs)

def custom_asymmetric_obj(y_true, y_pred):
    residual = (y_true - y_pred).astype(np.float64)
    grad = np.where(residual < 0, -2 * 10.0 * residual, -2 * residual)
@@ -475,6 +501,193 @@ def test_clone_and_property():

    assert isinstance(clf.feature_importances_, np.ndarray)
@pytest.mark.parametrize("estimator", (lgb.LGBMClassifier, lgb.LGBMRegressor, lgb.LGBMRanker))
def test_estimators_all_have_the_same_kwargs_and_defaults(estimator):
base_spec = inspect.getfullargspec(lgb.LGBMModel)
subclass_spec = inspect.getfullargspec(estimator)
# should not allow for any varargs
assert subclass_spec.varargs == base_spec.varargs
assert subclass_spec.varargs is None
# the only varkw should be **kwargs,
assert subclass_spec.varkw == base_spec.varkw
assert subclass_spec.varkw == "kwargs"

    # default values for all constructor arguments should be identical
    #
    # NOTE: if LGBMClassifier / LGBMRanker / LGBMRegressor ever override
    # any of LGBMModel's constructor arguments, this will need to be updated
    assert subclass_spec.kwonlydefaults == base_spec.kwonlydefaults

    # only positional argument should be 'self'
    assert subclass_spec.args == base_spec.args
    assert subclass_spec.args == ["self"]
    assert subclass_spec.defaults is None

    # get_params() should be identical
    assert estimator().get_params() == lgb.LGBMModel().get_params()

def test_subclassing_get_params_works():
    expected_params = {
        "boosting_type": "gbdt",
        "class_weight": None,
        "colsample_bytree": 1.0,
        "importance_type": "split",
        "learning_rate": 0.1,
        "max_depth": -1,
        "min_child_samples": 20,
        "min_child_weight": 0.001,
        "min_split_gain": 0.0,
        "n_estimators": 100,
        "n_jobs": None,
        "num_leaves": 31,
        "objective": None,
        "random_state": None,
        "reg_alpha": 0.0,
        "reg_lambda": 0.0,
        "subsample": 1.0,
        "subsample_for_bin": 200000,
        "subsample_freq": 0,
    }

    # Overrides, used to test that passing through **kwargs works as expected.
    #
    # why these?
    #
    #   - 'n_estimators' directly matches a keyword arg for the scikit-learn estimators
    #   - 'eta' is a parameter alias for 'learning_rate'
    overrides = {"n_estimators": 13, "eta": 0.07}

    # lightgbm-official classes
    for est in [lgb.LGBMModel, lgb.LGBMClassifier, lgb.LGBMRanker, lgb.LGBMRegressor]:
        assert est().get_params() == expected_params
        assert est(**overrides).get_params() == {
            **expected_params,
            "eta": 0.07,
            "n_estimators": 13,
            "learning_rate": 0.1,
        }

    if DASK_INSTALLED:
        for est in [lgb.DaskLGBMClassifier, lgb.DaskLGBMRanker, lgb.DaskLGBMRegressor]:
            assert est().get_params() == {
                **expected_params,
                "client": None,
            }
            assert est(**overrides).get_params() == {
                **expected_params,
                "eta": 0.07,
                "n_estimators": 13,
                "learning_rate": 0.1,
                "client": None,
            }

    # custom sub-classes
    assert ExtendedLGBMClassifier().get_params() == {**expected_params, "some_other_param": "lgbm-classifier"}
    assert ExtendedLGBMClassifier(**overrides).get_params() == {
        **expected_params,
        "eta": 0.07,
        "n_estimators": 13,
        "learning_rate": 0.1,
        "some_other_param": "lgbm-classifier",
    }
    assert ExtendedLGBMRanker().get_params() == {
        **expected_params,
        "some_other_param": "lgbm-ranker",
    }
    assert ExtendedLGBMRanker(**overrides).get_params() == {
        **expected_params,
        "eta": 0.07,
        "n_estimators": 13,
        "learning_rate": 0.1,
        "some_other_param": "lgbm-ranker",
    }
    assert ExtendedLGBMRegressor().get_params() == {
        **expected_params,
        "some_other_param": "lgbm-regressor",
    }
    assert ExtendedLGBMRegressor(**overrides).get_params() == {
        **expected_params,
        "eta": 0.07,
        "n_estimators": 13,
        "learning_rate": 0.1,
        "some_other_param": "lgbm-regressor",
    }
@pytest.mark.parametrize("task", all_tasks)
def test_subclassing_works(task):
# param values to make training deterministic and
# just train a small, cheap model
params = {
"deterministic": True,
"force_row_wise": True,
"n_jobs": 1,
"n_estimators": 5,
"num_leaves": 11,
"random_state": 708,
}
X, y, g = _create_data(task=task)
if task == "ranking":
est = lgb.LGBMRanker(**params).fit(X, y, group=g)
est_sub = ExtendedLGBMRanker(**params).fit(X, y, group=g)
elif task.endswith("classification"):
est = lgb.LGBMClassifier(**params).fit(X, y)
est_sub = ExtendedLGBMClassifier(**params).fit(X, y)
else:
est = lgb.LGBMRegressor(**params).fit(X, y)
est_sub = ExtendedLGBMRegressor(**params).fit(X, y)
np.testing.assert_allclose(est.predict(X), est_sub.predict(X))

@pytest.mark.parametrize(
    "estimator_to_task",
    [
        (lgb.LGBMClassifier, "binary-classification"),
        (ExtendedLGBMClassifier, "binary-classification"),
        (lgb.LGBMRanker, "ranking"),
        (ExtendedLGBMRanker, "ranking"),
        (lgb.LGBMRegressor, "regression"),
        (ExtendedLGBMRegressor, "regression"),
    ],
)
def test_parameter_aliases_are_handled_correctly(estimator_to_task):
    estimator, task = estimator_to_task

    # scikit-learn estimators should remember every parameter passed
    # via keyword arguments in the estimator constructor, but then
    # only pass the correct value down to LightGBM's C++ side
    params = {
        "eta": 0.08,
        "num_iterations": 3,
        "num_leaves": 5,
    }
    X, y, g = _create_data(task=task)
    mod = estimator(**params)
    if task == "ranking":
        mod.fit(X, y, group=g)
    else:
        mod.fit(X, y)

    # scikit-learn get_params()
    p = mod.get_params()
    assert p["eta"] == 0.08
    assert p["learning_rate"] == 0.1

    # lgb.Booster's 'params' attribute
    p = mod.booster_.params
    assert p["eta"] == 0.08
    assert p["learning_rate"] == 0.1

    # Config in the 'LightGBM::Booster' on the C++ side
    p = mod.booster_._get_loaded_param()
    assert p["learning_rate"] == 0.1
    assert "eta" not in p

def test_joblib(tmp_path):
    X, y = make_synthetic_regression()
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
@@ -1463,7 +1676,10 @@ def _get_expected_failed_tests(estimator):

    return estimator._more_tags()["_xfail_checks"]


@parametrize_with_checks(
    [ExtendedLGBMClassifier(), ExtendedLGBMRegressor(), lgb.LGBMClassifier(), lgb.LGBMRegressor()],
    expected_failed_checks=_get_expected_failed_tests,
)
def test_sklearn_integration(estimator, check):
    estimator.set_params(min_child_samples=1, min_data_in_bin=1)
    check(estimator)
...