Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
5f79626f
Unverified
Commit
5f79626f
authored
Mar 30, 2023
by
James Lamb
Committed by
GitHub
Mar 30, 2023
Browse files
[python-package] fix type annotations for eval result tracking (#5793)
parent
42a42670
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
18 additions
and
9 deletions
+18
-9
python-package/lightgbm/callback.py
python-package/lightgbm/callback.py
+14
-6
python-package/lightgbm/engine.py
python-package/lightgbm/engine.py
+4
-3
No files found.
python-package/lightgbm/callback.py
View file @
5f79626f
...
@@ -15,6 +15,10 @@ __all__ = [
...
@@ -15,6 +15,10 @@ __all__ = [
_EvalResultDict
=
Dict
[
str
,
Dict
[
str
,
List
[
Any
]]]
_EvalResultDict
=
Dict
[
str
,
Dict
[
str
,
List
[
Any
]]]
_EvalResultTuple
=
Union
[
_EvalResultTuple
=
Union
[
_LGBM_BoosterEvalMethodResultType
,
Tuple
[
str
,
str
,
float
,
bool
,
float
]
]
_ListOfEvalResultTuples
=
Union
[
List
[
_LGBM_BoosterEvalMethodResultType
],
List
[
_LGBM_BoosterEvalMethodResultType
],
List
[
Tuple
[
str
,
str
,
float
,
bool
,
float
]]
List
[
Tuple
[
str
,
str
,
float
,
bool
,
float
]]
]
]
...
@@ -23,7 +27,7 @@ _EvalResultTuple = Union[
...
@@ -23,7 +27,7 @@ _EvalResultTuple = Union[
class
EarlyStopException
(
Exception
):
class
EarlyStopException
(
Exception
):
"""Exception of early stopping."""
"""Exception of early stopping."""
def
__init__
(
self
,
best_iteration
:
int
,
best_score
:
_EvalResultTuple
)
->
None
:
def
__init__
(
self
,
best_iteration
:
int
,
best_score
:
_
ListOf
EvalResultTuple
s
)
->
None
:
"""Create early stopping exception.
"""Create early stopping exception.
Parameters
Parameters
...
@@ -55,7 +59,7 @@ def _format_eval_result(value: _EvalResultTuple, show_stdv: bool) -> str:
...
@@ -55,7 +59,7 @@ def _format_eval_result(value: _EvalResultTuple, show_stdv: bool) -> str:
return
f
"
{
value
[
0
]
}
's
{
value
[
1
]
}
:
{
value
[
2
]:
g
}
"
return
f
"
{
value
[
0
]
}
's
{
value
[
1
]
}
:
{
value
[
2
]:
g
}
"
elif
len
(
value
)
==
5
:
elif
len
(
value
)
==
5
:
if
show_stdv
:
if
show_stdv
:
return
f
"
{
value
[
0
]
}
's
{
value
[
1
]
}
:
{
value
[
2
]:
g
}
+
{
value
[
4
]:
g
}
"
return
f
"
{
value
[
0
]
}
's
{
value
[
1
]
}
:
{
value
[
2
]:
g
}
+
{
value
[
4
]:
g
}
"
# type: ignore[misc]
else
:
else
:
return
f
"
{
value
[
0
]
}
's
{
value
[
1
]
}
:
{
value
[
2
]:
g
}
"
return
f
"
{
value
[
0
]
}
's
{
value
[
1
]
}
:
{
value
[
2
]:
g
}
"
else
:
else
:
...
@@ -256,7 +260,7 @@ class _EarlyStoppingCallback:
...
@@ -256,7 +260,7 @@ class _EarlyStoppingCallback:
def
_reset_storages
(
self
)
->
None
:
def
_reset_storages
(
self
)
->
None
:
self
.
best_score
:
List
[
float
]
=
[]
self
.
best_score
:
List
[
float
]
=
[]
self
.
best_iter
:
List
[
int
]
=
[]
self
.
best_iter
:
List
[
int
]
=
[]
self
.
best_score_list
:
List
[
Union
[
_
EvalResultTuple
,
None
]
]
=
[]
self
.
best_score_list
:
List
[
_ListOf
EvalResultTuple
s
]
=
[]
self
.
cmp_op
:
List
[
Callable
[[
float
,
float
],
bool
]]
=
[]
self
.
cmp_op
:
List
[
Callable
[[
float
,
float
],
bool
]]
=
[]
self
.
first_metric
=
''
self
.
first_metric
=
''
...
@@ -327,7 +331,6 @@ class _EarlyStoppingCallback:
...
@@ -327,7 +331,6 @@ class _EarlyStoppingCallback:
self
.
first_metric
=
env
.
evaluation_result_list
[
0
][
1
].
split
(
" "
)[
-
1
]
self
.
first_metric
=
env
.
evaluation_result_list
[
0
][
1
].
split
(
" "
)[
-
1
]
for
eval_ret
,
delta
in
zip
(
env
.
evaluation_result_list
,
deltas
):
for
eval_ret
,
delta
in
zip
(
env
.
evaluation_result_list
,
deltas
):
self
.
best_iter
.
append
(
0
)
self
.
best_iter
.
append
(
0
)
self
.
best_score_list
.
append
(
None
)
if
eval_ret
[
3
]:
# greater is better
if
eval_ret
[
3
]:
# greater is better
self
.
best_score
.
append
(
float
(
'-inf'
))
self
.
best_score
.
append
(
float
(
'-inf'
))
self
.
cmp_op
.
append
(
partial
(
self
.
_gt_delta
,
delta
=
delta
))
self
.
cmp_op
.
append
(
partial
(
self
.
_gt_delta
,
delta
=
delta
))
...
@@ -350,11 +353,16 @@ class _EarlyStoppingCallback:
...
@@ -350,11 +353,16 @@ class _EarlyStoppingCallback:
self
.
_init
(
env
)
self
.
_init
(
env
)
if
not
self
.
enabled
:
if
not
self
.
enabled
:
return
return
# self.best_score_list is initialized to an empty list
first_time_updating_best_score_list
=
(
self
.
best_score_list
==
[])
for
i
in
range
(
len
(
env
.
evaluation_result_list
)):
for
i
in
range
(
len
(
env
.
evaluation_result_list
)):
score
=
env
.
evaluation_result_list
[
i
][
2
]
score
=
env
.
evaluation_result_list
[
i
][
2
]
if
self
.
best_score_list
[
i
]
is
None
or
self
.
cmp_op
[
i
](
score
,
self
.
best_score
[
i
]):
if
first_time_updating_
best_score_list
or
self
.
cmp_op
[
i
](
score
,
self
.
best_score
[
i
]):
self
.
best_score
[
i
]
=
score
self
.
best_score
[
i
]
=
score
self
.
best_iter
[
i
]
=
env
.
iteration
self
.
best_iter
[
i
]
=
env
.
iteration
if
first_time_updating_best_score_list
:
self
.
best_score_list
.
append
(
env
.
evaluation_result_list
)
else
:
self
.
best_score_list
[
i
]
=
env
.
evaluation_result_list
self
.
best_score_list
[
i
]
=
env
.
evaluation_result_list
# split is needed for "<dataset type> <metric>" case (e.g. "train l1")
# split is needed for "<dataset type> <metric>" case (e.g. "train l1")
eval_name_splitted
=
env
.
evaluation_result_list
[
i
][
1
].
split
(
" "
)
eval_name_splitted
=
env
.
evaluation_result_list
[
i
][
1
].
split
(
" "
)
...
...
python-package/lightgbm/engine.py
View file @
5f79626f
...
@@ -11,8 +11,9 @@ import numpy as np
...
@@ -11,8 +11,9 @@ import numpy as np
from
.
import
callback
from
.
import
callback
from
.basic
import
(
Booster
,
Dataset
,
LightGBMError
,
_choose_param_value
,
_ConfigAliases
,
_InnerPredictor
,
from
.basic
import
(
Booster
,
Dataset
,
LightGBMError
,
_choose_param_value
,
_ConfigAliases
,
_InnerPredictor
,
_LGBM_CategoricalFeatureConfiguration
,
_LGBM_CustomObjectiveFunction
,
_LGBM_EvalFunctionResultType
,
_LGBM_BoosterEvalMethodResultType
,
_LGBM_CategoricalFeatureConfiguration
,
_LGBM_FeatureNameConfiguration
,
_log_warning
)
_LGBM_CustomObjectiveFunction
,
_LGBM_EvalFunctionResultType
,
_LGBM_FeatureNameConfiguration
,
_log_warning
)
from
.compat
import
SKLEARN_INSTALLED
,
_LGBMBaseCrossValidator
,
_LGBMGroupKFold
,
_LGBMStratifiedKFold
from
.compat
import
SKLEARN_INSTALLED
,
_LGBMBaseCrossValidator
,
_LGBMGroupKFold
,
_LGBMStratifiedKFold
__all__
=
[
__all__
=
[
...
@@ -257,7 +258,7 @@ def train(
...
@@ -257,7 +258,7 @@ def train(
booster
.
update
(
fobj
=
fobj
)
booster
.
update
(
fobj
=
fobj
)
evaluation_result_list
=
[]
evaluation_result_list
:
List
[
_LGBM_BoosterEvalMethodResultType
]
=
[]
# check evaluation result.
# check evaluation result.
if
valid_sets
is
not
None
:
if
valid_sets
is
not
None
:
if
is_valid_contain_train
:
if
is_valid_contain_train
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment