Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
5e592fe6
Unverified
Commit
5e592fe6
authored
Sep 12, 2023
by
david-cortes
Committed by
GitHub
Sep 11, 2023
Browse files
[python-package] Fix misdetected objective after multiple calls to `LGBMClassifier.fit` (#6002)
parent
501ce1cb
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
19 additions
and
0 deletions
+19
-0
python-package/lightgbm/sklearn.py
python-package/lightgbm/sklearn.py
+2
-0
tests/python_package_test/test_sklearn.py
tests/python_package_test/test_sklearn.py
+17
-0
No files found.
python-package/lightgbm/sklearn.py
View file @
5e592fe6
...
@@ -1103,6 +1103,8 @@ class LGBMClassifier(_LGBMClassifierBase, LGBMModel):
...
@@ -1103,6 +1103,8 @@ class LGBMClassifier(_LGBMClassifierBase, LGBMModel):
self
.
_classes
=
self
.
_le
.
classes_
self
.
_classes
=
self
.
_le
.
classes_
self
.
_n_classes
=
len
(
self
.
_classes
)
# type: ignore[arg-type]
self
.
_n_classes
=
len
(
self
.
_classes
)
# type: ignore[arg-type]
if
self
.
objective
is
None
:
self
.
_objective
=
None
# adjust eval metrics to match whether binary or multiclass
# adjust eval metrics to match whether binary or multiclass
# classification is being performed
# classification is being performed
...
...
tests/python_package_test/test_sklearn.py
View file @
5e592fe6
...
@@ -1561,3 +1561,20 @@ def test_ranking_minimally_works_with_all_all_accepted_data_types(X_type, y_type
...
@@ -1561,3 +1561,20 @@ def test_ranking_minimally_works_with_all_all_accepted_data_types(X_type, y_type
)
)
preds
=
model
.
predict
(
X
)
preds
=
model
.
predict
(
X
)
assert
spearmanr
(
preds
,
y
).
correlation
>=
0.99
assert
spearmanr
(
preds
,
y
).
correlation
>=
0.99
def
test_classifier_fit_detects_classes_every_time
():
rng
=
np
.
random
.
default_rng
(
seed
=
123
)
nrows
=
1000
ncols
=
20
X
=
rng
.
standard_normal
(
size
=
(
nrows
,
ncols
))
y_bin
=
(
rng
.
random
(
size
=
nrows
)
<=
.
3
).
astype
(
np
.
float64
)
y_multi
=
rng
.
integers
(
4
,
size
=
nrows
)
model
=
lgb
.
LGBMClassifier
(
verbose
=-
1
)
for
_
in
range
(
2
):
model
.
fit
(
X
,
y_multi
)
assert
model
.
objective_
==
"multiclass"
model
.
fit
(
X
,
y_bin
)
assert
model
.
objective_
==
"binary"
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment