Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
ca066d49
Unverified
Commit
ca066d49
authored
Sep 02, 2020
by
Nikita Titov
Committed by
GitHub
Sep 02, 2020
Browse files
be compatible with check_is_fitted sklearn function (#3329)
parent
8fc80bb4
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
23 additions
and
0 deletions
+23
-0
python-package/lightgbm/sklearn.py
python-package/lightgbm/sklearn.py
+2
-0
tests/python_package_test/test_sklearn.py
tests/python_package_test/test_sklearn.py
+21
-0
No files found.
python-package/lightgbm/sklearn.py
View file @
ca066d49
...
...
@@ -607,6 +607,8 @@ class LGBMModel(_LGBMModelBase):
self
.
_best_score
=
self
.
_Booster
.
best_score
self
.
fitted_
=
True
# free dataset
self
.
_Booster
.
free_dataset
()
del
train_set
,
valid_sets
...
...
tests/python_package_test/test_sklearn.py
View file @
ca066d49
...
...
@@ -20,6 +20,7 @@ from sklearn.multioutput import (MultiOutputClassifier, ClassifierChain, MultiOu
RegressorChain
)
from
sklearn.utils.estimator_checks
import
(
_yield_all_checks
,
SkipTest
,
check_parameters_default_constructible
)
from
sklearn.utils.validation
import
check_is_fitted
decreasing_generator
=
itertools
.
count
(
0
,
-
1
)
...
...
@@ -1091,3 +1092,23 @@ class TestSklearn(unittest.TestCase):
self
.
assertEqual
(
len
(
init_gbm
.
evals_result_
[
'valid_0'
][
'multi_logloss'
]),
5
)
self
.
assertLess
(
gbm
.
evals_result_
[
'valid_0'
][
'multi_logloss'
][
-
1
],
init_gbm
.
evals_result_
[
'valid_0'
][
'multi_logloss'
][
-
1
])
# sklearn < 0.22 requires passing "attributes" argument
@
unittest
.
skipIf
(
sk_version
<
'0.22.0'
,
'scikit-learn version is less than 0.22'
)
def
test_check_is_fitted
(
self
):
X
,
y
=
load_digits
(
n_class
=
2
,
return_X_y
=
True
)
est
=
lgb
.
LGBMModel
(
n_estimators
=
5
,
objective
=
"binary"
)
clf
=
lgb
.
LGBMClassifier
(
n_estimators
=
5
)
reg
=
lgb
.
LGBMRegressor
(
n_estimators
=
5
)
rnk
=
lgb
.
LGBMRanker
(
n_estimators
=
5
)
models
=
(
est
,
clf
,
reg
,
rnk
)
for
model
in
models
:
self
.
assertRaises
(
lgb
.
compat
.
LGBMNotFittedError
,
check_is_fitted
,
model
)
est
.
fit
(
X
,
y
)
clf
.
fit
(
X
,
y
)
reg
.
fit
(
X
,
y
)
rnk
.
fit
(
X
,
y
,
group
=
np
.
ones
(
X
.
shape
[
0
]))
for
model
in
models
:
check_is_fitted
(
model
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment