Commit c6d90bc7 in tianlh/LightGBM-DCU (unverified)
Authored Feb 12, 2025 by James Lamb; committed by GitHub on Feb 12, 2025
[python-package] support sub-classing scikit-learn estimators (#6783)
Co-authored-by: Nikita Titov <nekit94-08@mail.ru>
Parent: 768f6423
Showing 5 changed files with 467 additions and 11 deletions (+467 -11):

  docs/FAQ.rst                                +39    -0
  python-package/lightgbm/dask.py              +3    -0
  python-package/lightgbm/sklearn.py          +182   -0
  tests/python_package_test/test_dask.py      +26   -10
  tests/python_package_test/test_sklearn.py   +217   -1
docs/FAQ.rst  (view file @ c6d90bc7)

@@ -377,3 +377,42 @@ We strongly recommend installation from the ``conda-forge`` channel and not from
 For some specific examples, see `this comment <https://github.com/microsoft/LightGBM/issues/4948#issuecomment-1013766397>`__.

 In addition, as of ``lightgbm==4.4.0``, the ``conda-forge`` package automatically supports CUDA-based GPU acceleration.
+5. How do I subclass ``scikit-learn`` estimators?
+-------------------------------------------------
+
+For ``lightgbm <= 4.5.0``, copy all of the constructor arguments from the corresponding
+``lightgbm`` class into the constructor of your custom estimator.
+
+For later versions, just ensure that the constructor of your custom estimator calls
+``super().__init__()``.
+
+Consider the example below, which implements a regressor that allows creation of truncated
+predictions. This pattern will work with ``lightgbm > 4.5.0``.
+.. code-block:: python
+
+    import numpy as np
+    from lightgbm import LGBMRegressor
+    from sklearn.datasets import make_regression
+
+    class TruncatedRegressor(LGBMRegressor):
+
+        def __init__(self, **kwargs):
+            super().__init__(**kwargs)
+
+        def predict(self, X, max_score: float = np.inf):
+            preds = super().predict(X)
+            np.clip(preds, a_min=None, a_max=max_score, out=preds)
+            return preds
+
+    X, y = make_regression(n_samples=1_000, n_features=4)
+
+    reg_trunc = TruncatedRegressor().fit(X, y)
+
+    preds = reg_trunc.predict(X)
+    print(f"mean: {preds.mean():.2f}, max: {preds.max():.2f}")
+    # mean: -6.81, max: 345.10
+
+    preds_trunc = reg_trunc.predict(X, max_score=preds.mean())
+    print(f"mean: {preds_trunc.mean():.2f}, max: {preds_trunc.max():.2f}")
+    # mean: -56.50, max: -6.81
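For comparison, the ``lightgbm <= 4.5.0`` pattern mentioned above ("copy all of the constructor arguments") would look roughly like the sketch below. It is abbreviated for illustration: a real implementation would have to repeat every keyword argument of ``LGBMRegressor.__init__()``, not just the three shown, so that scikit-learn could discover them by inspecting the subclass's signature.

    from lightgbm import LGBMRegressor

    class TruncatedRegressorLegacy(LGBMRegressor):
        # Abbreviated sketch: under lightgbm<=4.5.0, *every* constructor argument
        # of LGBMRegressor had to be repeated here, because scikit-learn's
        # get_params() only inspects the signature of type(self).__init__.
        def __init__(self, boosting_type="gbdt", num_leaves=31, max_depth=-1, **kwargs):
            super().__init__(
                boosting_type=boosting_type,
                num_leaves=num_leaves,
                max_depth=max_depth,
                **kwargs,
            )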
python-package/lightgbm/dask.py  (view file @ c6d90bc7)

@@ -1115,6 +1115,7 @@ class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):

     def __init__(
         self,
+        *,
         boosting_type: str = "gbdt",
         num_leaves: int = 31,
         max_depth: int = -1,

@@ -1318,6 +1319,7 @@ class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel):

     def __init__(
         self,
+        *,
         boosting_type: str = "gbdt",
         num_leaves: int = 31,
         max_depth: int = -1,

@@ -1485,6 +1487,7 @@ class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel):

     def __init__(
         self,
+        *,
         boosting_type: str = "gbdt",
         num_leaves: int = 31,
         max_depth: int = -1,
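The only functional change in this file is the added ``*``, which makes every constructor argument keyword-only; the same change is applied to the base estimators in sklearn.py below. A minimal sketch of the user-visible effect:

    from lightgbm import LGBMRegressor

    # keyword arguments work exactly as before
    reg = LGBMRegressor(boosting_type="gbdt", num_leaves=31)

    # positional arguments are now rejected at call time, e.g.
    #   LGBMRegressor("gbdt", 31)
    # raises roughly: TypeError: __init__() takes 1 positional argument but 3 were given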
python-package/lightgbm/sklearn.py  (view file @ c6d90bc7)

@@ -488,6 +488,7 @@ class LGBMModel(_LGBMModelBase):

     def __init__(
         self,
+        *,
         boosting_type: str = "gbdt",
         num_leaves: int = 31,
         max_depth: int = -1,
@@ -745,7 +746,35 @@ class LGBMModel(_LGBMModelBase):

         params : dict
             Parameter names mapped to their values.
         """
+        # Based on: https://github.com/dmlc/xgboost/blob/bd92b1c9c0db3e75ec3dfa513e1435d518bb535d/python-package/xgboost/sklearn.py#L941
+        # which was based on: https://stackoverflow.com/questions/59248211
+        #
+        # `get_params()` flows like this:
+        #
+        # 0. Get parameters in subclass (self.__class__) first, by using inspect.
+        # 1. Get parameters in all parent classes (especially `LGBMModel`).
+        # 2. Get whatever was passed via `**kwargs`.
+        # 3. Merge them.
+        #
+        # This needs to accommodate being called recursively in the following
+        # inheritance graphs (and similar for classification and ranking):
+        #
+        #   DaskLGBMRegressor -> LGBMRegressor -> LGBMModel -> BaseEstimator
+        #   (custom subclass) -> LGBMRegressor -> LGBMModel -> BaseEstimator
+        #   LGBMRegressor -> LGBMModel -> BaseEstimator
+        #   (custom subclass) -> LGBMModel -> BaseEstimator
+        #   LGBMModel -> BaseEstimator
+        #
         params = super().get_params(deep=deep)
+        cp = copy.copy(self)
+        # If the immediate parent defines get_params(), use that.
+        if callable(getattr(cp.__class__.__bases__[0], "get_params", None)):
+            cp.__class__ = cp.__class__.__bases__[0]
+        # Otherwise, skip it and assume the next class will have it.
+        # This is here primarily for cases where the first class in MRO is a scikit-learn mixin.
+        else:
+            cp.__class__ = cp.__class__.__bases__[1]
+        params.update(cp.__class__.get_params(cp, deep))
         params.update(self._other_params)
         return params
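Why the class-reassignment trick is needed: ``BaseEstimator.get_params()`` discovers parameter names by inspecting the signature of ``type(self).__init__`` only, so a subclass whose constructor is just ``**kwargs`` would otherwise hide all of ``LGBMModel``'s parameters. Re-typing a shallow copy one base class at a time lets every level of the hierarchy contribute its own signature. A sketch of the resulting behavior, using a hypothetical subclass (``MyRegressor`` is illustrative, not part of this diff):

    from lightgbm import LGBMRegressor

    class MyRegressor(LGBMRegressor):
        def __init__(self, *, my_param: int = 7, **kwargs):
            self.my_param = my_param
            super().__init__(**kwargs)

    params = MyRegressor(n_estimators=13).get_params()
    assert params["my_param"] == 7        # contributed by MyRegressor.__init__
    assert params["n_estimators"] == 13   # passed through **kwargs
    assert params["num_leaves"] == 31     # inherited LGBMModel default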
@@ -1285,6 +1314,57 @@ class LGBMModel(_LGBMModelBase):

 class LGBMRegressor(_LGBMRegressorBase, LGBMModel):
     """LightGBM regressor."""

+    # NOTE: all args from LGBMModel.__init__() are intentionally repeated here for
+    # docs, help(), and tab completion.
+    def __init__(
+        self,
+        *,
+        boosting_type: str = "gbdt",
+        num_leaves: int = 31,
+        max_depth: int = -1,
+        learning_rate: float = 0.1,
+        n_estimators: int = 100,
+        subsample_for_bin: int = 200000,
+        objective: Optional[Union[str, _LGBM_ScikitCustomObjectiveFunction]] = None,
+        class_weight: Optional[Union[Dict, str]] = None,
+        min_split_gain: float = 0.0,
+        min_child_weight: float = 1e-3,
+        min_child_samples: int = 20,
+        subsample: float = 1.0,
+        subsample_freq: int = 0,
+        colsample_bytree: float = 1.0,
+        reg_alpha: float = 0.0,
+        reg_lambda: float = 0.0,
+        random_state: Optional[Union[int, np.random.RandomState, np.random.Generator]] = None,
+        n_jobs: Optional[int] = None,
+        importance_type: str = "split",
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(
+            boosting_type=boosting_type,
+            num_leaves=num_leaves,
+            max_depth=max_depth,
+            learning_rate=learning_rate,
+            n_estimators=n_estimators,
+            subsample_for_bin=subsample_for_bin,
+            objective=objective,
+            class_weight=class_weight,
+            min_split_gain=min_split_gain,
+            min_child_weight=min_child_weight,
+            min_child_samples=min_child_samples,
+            subsample=subsample,
+            subsample_freq=subsample_freq,
+            colsample_bytree=colsample_bytree,
+            reg_alpha=reg_alpha,
+            reg_lambda=reg_lambda,
+            random_state=random_state,
+            n_jobs=n_jobs,
+            importance_type=importance_type,
+            **kwargs,
+        )
+
+    __init__.__doc__ = LGBMModel.__init__.__doc__
+
     def _more_tags(self) -> Dict[str, Any]:
         # handle the case where RegressorMixin possibly provides _more_tags()
         if callable(getattr(_LGBMRegressorBase, "_more_tags", None)):
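The ``__init__.__doc__ = LGBMModel.__init__.__doc__`` assignment (repeated for the classifier and ranker below) exists because each subclass now defines its own ``__init__`` purely to expose the full signature; reassigning ``__doc__`` keeps the single canonical parameter docstring visible from ``help()``. A quick check of the effect:

    from lightgbm import LGBMModel, LGBMRegressor

    # the subclass shares the exact docstring object of the base class
    assert LGBMRegressor.__init__.__doc__ is LGBMModel.__init__.__doc__
    help(LGBMRegressor.__init__)  # prints the shared parameter documentation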
@@ -1344,6 +1424,57 @@ class LGBMRegressor(_LGBMRegressorBase, LGBMModel):

 class LGBMClassifier(_LGBMClassifierBase, LGBMModel):
     """LightGBM classifier."""

+    # NOTE: all args from LGBMModel.__init__() are intentionally repeated here for
+    # docs, help(), and tab completion.
+    def __init__(
+        self,
+        *,
+        boosting_type: str = "gbdt",
+        num_leaves: int = 31,
+        max_depth: int = -1,
+        learning_rate: float = 0.1,
+        n_estimators: int = 100,
+        subsample_for_bin: int = 200000,
+        objective: Optional[Union[str, _LGBM_ScikitCustomObjectiveFunction]] = None,
+        class_weight: Optional[Union[Dict, str]] = None,
+        min_split_gain: float = 0.0,
+        min_child_weight: float = 1e-3,
+        min_child_samples: int = 20,
+        subsample: float = 1.0,
+        subsample_freq: int = 0,
+        colsample_bytree: float = 1.0,
+        reg_alpha: float = 0.0,
+        reg_lambda: float = 0.0,
+        random_state: Optional[Union[int, np.random.RandomState, np.random.Generator]] = None,
+        n_jobs: Optional[int] = None,
+        importance_type: str = "split",
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(
+            boosting_type=boosting_type,
+            num_leaves=num_leaves,
+            max_depth=max_depth,
+            learning_rate=learning_rate,
+            n_estimators=n_estimators,
+            subsample_for_bin=subsample_for_bin,
+            objective=objective,
+            class_weight=class_weight,
+            min_split_gain=min_split_gain,
+            min_child_weight=min_child_weight,
+            min_child_samples=min_child_samples,
+            subsample=subsample,
+            subsample_freq=subsample_freq,
+            colsample_bytree=colsample_bytree,
+            reg_alpha=reg_alpha,
+            reg_lambda=reg_lambda,
+            random_state=random_state,
+            n_jobs=n_jobs,
+            importance_type=importance_type,
+            **kwargs,
+        )
+
+    __init__.__doc__ = LGBMModel.__init__.__doc__
+
     def _more_tags(self) -> Dict[str, Any]:
         # handle the case where ClassifierMixin possibly provides _more_tags()
         if callable(getattr(_LGBMClassifierBase, "_more_tags", None)):
@@ -1554,6 +1685,57 @@ class LGBMRanker(LGBMModel):

     Please use this class mainly for training and applying ranking models in common sklearnish way.
     """

+    # NOTE: all args from LGBMModel.__init__() are intentionally repeated here for
+    # docs, help(), and tab completion.
+    def __init__(
+        self,
+        *,
+        boosting_type: str = "gbdt",
+        num_leaves: int = 31,
+        max_depth: int = -1,
+        learning_rate: float = 0.1,
+        n_estimators: int = 100,
+        subsample_for_bin: int = 200000,
+        objective: Optional[Union[str, _LGBM_ScikitCustomObjectiveFunction]] = None,
+        class_weight: Optional[Union[Dict, str]] = None,
+        min_split_gain: float = 0.0,
+        min_child_weight: float = 1e-3,
+        min_child_samples: int = 20,
+        subsample: float = 1.0,
+        subsample_freq: int = 0,
+        colsample_bytree: float = 1.0,
+        reg_alpha: float = 0.0,
+        reg_lambda: float = 0.0,
+        random_state: Optional[Union[int, np.random.RandomState, np.random.Generator]] = None,
+        n_jobs: Optional[int] = None,
+        importance_type: str = "split",
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(
+            boosting_type=boosting_type,
+            num_leaves=num_leaves,
+            max_depth=max_depth,
+            learning_rate=learning_rate,
+            n_estimators=n_estimators,
+            subsample_for_bin=subsample_for_bin,
+            objective=objective,
+            class_weight=class_weight,
+            min_split_gain=min_split_gain,
+            min_child_weight=min_child_weight,
+            min_child_samples=min_child_samples,
+            subsample=subsample,
+            subsample_freq=subsample_freq,
+            colsample_bytree=colsample_bytree,
+            reg_alpha=reg_alpha,
+            reg_lambda=reg_lambda,
+            random_state=random_state,
+            n_jobs=n_jobs,
+            importance_type=importance_type,
+            **kwargs,
+        )
+
+    __init__.__doc__ = LGBMModel.__init__.__doc__
+
     def fit(  # type: ignore[override]
         self,
         X: _LGBM_ScikitMatrixLike,
tests/python_package_test/test_dask.py  (view file @ c6d90bc7)

@@ -1373,26 +1373,42 @@ def test_machines_should_be_used_if_provided(task, cluster):

 @pytest.mark.parametrize(
-    "classes",
+    "dask_est,sklearn_est",
     [
         (lgb.DaskLGBMClassifier, lgb.LGBMClassifier),
         (lgb.DaskLGBMRegressor, lgb.LGBMRegressor),
         (lgb.DaskLGBMRanker, lgb.LGBMRanker),
     ],
 )
-def test_dask_classes_and_sklearn_equivalents_have_identical_constructors_except_client_arg(classes):
-    dask_spec = inspect.getfullargspec(classes[0])
-    sklearn_spec = inspect.getfullargspec(classes[1])
+def test_dask_classes_and_sklearn_equivalents_have_identical_constructors_except_client_arg(dask_est, sklearn_est):
+    dask_spec = inspect.getfullargspec(dask_est)
+    sklearn_spec = inspect.getfullargspec(sklearn_est)
+    # should not allow for any varargs
     assert dask_spec.varargs == sklearn_spec.varargs
+    assert dask_spec.varargs is None
+    # the only varkw should be **kwargs,
+    # for pass-through to parent classes' __init__()
     assert dask_spec.varkw == sklearn_spec.varkw
-    assert dask_spec.kwonlyargs == sklearn_spec.kwonlyargs
-    assert dask_spec.kwonlydefaults == sklearn_spec.kwonlydefaults
+    assert dask_spec.varkw == "kwargs"
     # "client" should be the only difference, and the final argument
-    assert dask_spec.args[:-1] == sklearn_spec.args
-    assert dask_spec.defaults[:-1] == sklearn_spec.defaults
-    assert dask_spec.args[-1] == "client"
-    assert dask_spec.defaults[-1] is None
+    assert dask_spec.kwonlyargs == [*sklearn_spec.kwonlyargs, "client"]
+    # default values for all constructor arguments should be identical
+    #
+    # NOTE: if LGBMClassifier / LGBMRanker / LGBMRegressor ever override
+    # any of LGBMModel's constructor arguments, this will need to be updated
+    assert dask_spec.kwonlydefaults == {**sklearn_spec.kwonlydefaults, "client": None}
+    # only positional argument should be 'self'
+    assert dask_spec.args == sklearn_spec.args
+    assert dask_spec.args == ["self"]
+    assert dask_spec.defaults is None
+    # get_params() should be identical, except for "client"
+    assert dask_est().get_params() == {**sklearn_est().get_params(), "client": None}

 @pytest.mark.parametrize(
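The test leans on ``inspect.getfullargspec()`` reporting keyword-only arguments separately from positional ones. A standalone illustration of the fields being compared (``f`` is a hypothetical function, not from the diff):

    import inspect

    def f(self, *, a=1, b=2, **kwargs):
        pass

    spec = inspect.getfullargspec(f)
    print(spec.args)            # ['self']           (positional)
    print(spec.kwonlyargs)      # ['a', 'b']         (keyword-only)
    print(spec.kwonlydefaults)  # {'a': 1, 'b': 2}
    print(spec.varkw)           # 'kwargs'
    print(spec.varargs)         # None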
tests/python_package_test/test_sklearn.py  (view file @ c6d90bc7)

 # coding: utf-8
+import inspect
 import itertools
 import math
 import re

@@ -22,6 +23,7 @@ from sklearn.utils.validation import check_is_fitted

 import lightgbm as lgb
 from lightgbm.compat import (
+    DASK_INSTALLED,
     DATATABLE_INSTALLED,
     PANDAS_INSTALLED,
     _sklearn_version,

@@ -83,6 +85,30 @@ class UnpicklableCallback:

         env.model.attr_set_inside_callback = env.iteration * 10
+class ExtendedLGBMClassifier(lgb.LGBMClassifier):
+    """Class for testing that inheriting from LGBMClassifier works"""
+
+    def __init__(self, *, some_other_param: str = "lgbm-classifier", **kwargs):
+        self.some_other_param = some_other_param
+        super().__init__(**kwargs)
+
+
+class ExtendedLGBMRanker(lgb.LGBMRanker):
+    """Class for testing that inheriting from LGBMRanker works"""
+
+    def __init__(self, *, some_other_param: str = "lgbm-ranker", **kwargs):
+        self.some_other_param = some_other_param
+        super().__init__(**kwargs)
+
+
+class ExtendedLGBMRegressor(lgb.LGBMRegressor):
+    """Class for testing that inheriting from LGBMRegressor works"""
+
+    def __init__(self, *, some_other_param: str = "lgbm-regressor", **kwargs):
+        self.some_other_param = some_other_param
+        super().__init__(**kwargs)
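These helper classes follow the scikit-learn convention of storing every constructor argument, unmodified, as an attribute of the same name; that is what lets ``get_params()``, ``set_params()``, and ``sklearn.base.clone()`` round-trip the subclasses. A hedged usage sketch (not part of the diff):

    from sklearn.base import clone

    est = ExtendedLGBMRegressor(some_other_param="custom", n_estimators=5)
    est_clone = clone(est)  # rebuilds the estimator from get_params()
    assert est_clone.some_other_param == "custom"
    assert est_clone.get_params()["n_estimators"] == 5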
 def custom_asymmetric_obj(y_true, y_pred):
     residual = (y_true - y_pred).astype(np.float64)
     grad = np.where(residual < 0, -2 * 10.0 * residual, -2 * residual)
@@ -475,6 +501,193 @@ def test_clone_and_property():

     assert isinstance(clf.feature_importances_, np.ndarray)

+@pytest.mark.parametrize("estimator", (lgb.LGBMClassifier, lgb.LGBMRegressor, lgb.LGBMRanker))
+def test_estimators_all_have_the_same_kwargs_and_defaults(estimator):
+    base_spec = inspect.getfullargspec(lgb.LGBMModel)
+    subclass_spec = inspect.getfullargspec(estimator)
+
+    # should not allow for any varargs
+    assert subclass_spec.varargs == base_spec.varargs
+    assert subclass_spec.varargs is None
+
+    # the only varkw should be **kwargs,
+    assert subclass_spec.varkw == base_spec.varkw
+    assert subclass_spec.varkw == "kwargs"
+
+    # default values for all constructor arguments should be identical
+    #
+    # NOTE: if LGBMClassifier / LGBMRanker / LGBMRegressor ever override
+    # any of LGBMModel's constructor arguments, this will need to be updated
+    assert subclass_spec.kwonlydefaults == base_spec.kwonlydefaults
+
+    # only positional argument should be 'self'
+    assert subclass_spec.args == base_spec.args
+    assert subclass_spec.args == ["self"]
+    assert subclass_spec.defaults is None
+
+    # get_params() should be identical
+    assert estimator().get_params() == lgb.LGBMModel().get_params()
+def test_subclassing_get_params_works():
+    expected_params = {
+        "boosting_type": "gbdt",
+        "class_weight": None,
+        "colsample_bytree": 1.0,
+        "importance_type": "split",
+        "learning_rate": 0.1,
+        "max_depth": -1,
+        "min_child_samples": 20,
+        "min_child_weight": 0.001,
+        "min_split_gain": 0.0,
+        "n_estimators": 100,
+        "n_jobs": None,
+        "num_leaves": 31,
+        "objective": None,
+        "random_state": None,
+        "reg_alpha": 0.0,
+        "reg_lambda": 0.0,
+        "subsample": 1.0,
+        "subsample_for_bin": 200000,
+        "subsample_freq": 0,
+    }
+
+    # Overrides, used to test that passing through **kwargs works as expected.
+    #
+    # why these?
+    #
+    #  - 'n_estimators' directly matches a keyword arg for the scikit-learn estimators
+    #  - 'eta' is a parameter alias for 'learning_rate'
+    overrides = {"n_estimators": 13, "eta": 0.07}
+
+    # lightgbm-official classes
+    for est in [lgb.LGBMModel, lgb.LGBMClassifier, lgb.LGBMRanker, lgb.LGBMRegressor]:
+        assert est().get_params() == expected_params
+        assert est(**overrides).get_params() == {
+            **expected_params,
+            "eta": 0.07,
+            "n_estimators": 13,
+            "learning_rate": 0.1,
+        }
+
+    if DASK_INSTALLED:
+        for est in [lgb.DaskLGBMClassifier, lgb.DaskLGBMRanker, lgb.DaskLGBMRegressor]:
+            assert est().get_params() == {
+                **expected_params,
+                "client": None,
+            }
+            assert est(**overrides).get_params() == {
+                **expected_params,
+                "eta": 0.07,
+                "n_estimators": 13,
+                "learning_rate": 0.1,
+                "client": None,
+            }
+
+    # custom sub-classes
+    assert ExtendedLGBMClassifier().get_params() == {**expected_params, "some_other_param": "lgbm-classifier"}
+    assert ExtendedLGBMClassifier(**overrides).get_params() == {
+        **expected_params,
+        "eta": 0.07,
+        "n_estimators": 13,
+        "learning_rate": 0.1,
+        "some_other_param": "lgbm-classifier",
+    }
+    assert ExtendedLGBMRanker().get_params() == {
+        **expected_params,
+        "some_other_param": "lgbm-ranker",
+    }
+    assert ExtendedLGBMRanker(**overrides).get_params() == {
+        **expected_params,
+        "eta": 0.07,
+        "n_estimators": 13,
+        "learning_rate": 0.1,
+        "some_other_param": "lgbm-ranker",
+    }
+    assert ExtendedLGBMRegressor().get_params() == {
+        **expected_params,
+        "some_other_param": "lgbm-regressor",
+    }
+    assert ExtendedLGBMRegressor(**overrides).get_params() == {
+        **expected_params,
+        "eta": 0.07,
+        "n_estimators": 13,
+        "learning_rate": 0.1,
+        "some_other_param": "lgbm-regressor",
+    }
+@pytest.mark.parametrize("task", all_tasks)
+def test_subclassing_works(task):
+    # param values to make training deterministic and
+    # just train a small, cheap model
+    params = {
+        "deterministic": True,
+        "force_row_wise": True,
+        "n_jobs": 1,
+        "n_estimators": 5,
+        "num_leaves": 11,
+        "random_state": 708,
+    }
+
+    X, y, g = _create_data(task=task)
+    if task == "ranking":
+        est = lgb.LGBMRanker(**params).fit(X, y, group=g)
+        est_sub = ExtendedLGBMRanker(**params).fit(X, y, group=g)
+    elif task.endswith("classification"):
+        est = lgb.LGBMClassifier(**params).fit(X, y)
+        est_sub = ExtendedLGBMClassifier(**params).fit(X, y)
+    else:
+        est = lgb.LGBMRegressor(**params).fit(X, y)
+        est_sub = ExtendedLGBMRegressor(**params).fit(X, y)
+
+    np.testing.assert_allclose(est.predict(X), est_sub.predict(X))
+@pytest.mark.parametrize(
+    "estimator_to_task",
+    [
+        (lgb.LGBMClassifier, "binary-classification"),
+        (ExtendedLGBMClassifier, "binary-classification"),
+        (lgb.LGBMRanker, "ranking"),
+        (ExtendedLGBMRanker, "ranking"),
+        (lgb.LGBMRegressor, "regression"),
+        (ExtendedLGBMRegressor, "regression"),
+    ],
+)
+def test_parameter_aliases_are_handled_correctly(estimator_to_task):
+    estimator, task = estimator_to_task
+
+    # scikit-learn estimators should remember every parameter passed
+    # via keyword arguments in the estimator constructor, but then
+    # only pass the correct value down to LightGBM's C++ side
+    params = {
+        "eta": 0.08,
+        "num_iterations": 3,
+        "num_leaves": 5,
+    }
+
+    X, y, g = _create_data(task=task)
+    mod = estimator(**params)
+    if task == "ranking":
+        mod.fit(X, y, group=g)
+    else:
+        mod.fit(X, y)
+
+    # scikit-learn get_params()
+    p = mod.get_params()
+    assert p["eta"] == 0.08
+    assert p["learning_rate"] == 0.1
+
+    # lgb.Booster's 'params' attribute
+    p = mod.booster_.params
+    assert p["eta"] == 0.08
+    assert p["learning_rate"] == 0.1
+
+    # Config in the 'LightGBM::Booster' on the C++ side
+    p = mod.booster_._get_loaded_param()
+    assert p["learning_rate"] == 0.1
+    assert "eta" not in p
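In short: the wrapper remembers the alias exactly as supplied, while only the resolved main parameter reaches the C++ Booster. A quick interactive check of the same behavior through public APIs (hypothetical session; values mirror the test's params):

    import lightgbm as lgb
    from sklearn.datasets import make_regression

    X, y = make_regression(n_samples=200, n_features=4)
    reg = lgb.LGBMRegressor(eta=0.08, num_iterations=3, num_leaves=5).fit(X, y)
    print(reg.get_params()["eta"])            # 0.08 -- kwarg remembered verbatim
    print(reg.get_params()["learning_rate"])  # 0.1  -- constructor default kept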
 def test_joblib(tmp_path):
     X, y = make_synthetic_regression()
     X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)

@@ -1463,7 +1676,10 @@ def _get_expected_failed_tests(estimator):

     return estimator._more_tags()["_xfail_checks"]
-@parametrize_with_checks([lgb.LGBMClassifier(), lgb.LGBMRegressor()], expected_failed_checks=_get_expected_failed_tests)
+@parametrize_with_checks(
+    [ExtendedLGBMClassifier(), ExtendedLGBMRegressor(), lgb.LGBMClassifier(), lgb.LGBMRegressor()],
+    expected_failed_checks=_get_expected_failed_tests,
+)
 def test_sklearn_integration(estimator, check):
     estimator.set_params(min_child_samples=1, min_data_in_bin=1)
     check(estimator)