Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
4ea170f3
Unverified
Commit
4ea170f3
authored
Aug 21, 2023
by
James Lamb
Committed by
GitHub
Aug 21, 2023
Browse files
[python-package] use dataclass for CallbackEnv (#6048)
parent
5fe84f8f
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
22 additions
and
16 deletions
+22
-16
.ci/test-python-oldest.sh
.ci/test-python-oldest.sh
+1
-0
python-package/lightgbm/callback.py
python-package/lightgbm/callback.py
+16
-12
python-package/lightgbm/engine.py
python-package/lightgbm/engine.py
+4
-4
python-package/pyproject.toml
python-package/pyproject.toml
+1
-0
No files found.
.ci/test-python-oldest.sh
View file @
4ea170f3
...
...
@@ -7,6 +7,7 @@
#
echo
"installing lightgbm's dependencies"
pip
install
\
'dataclasses'
\
'numpy==1.12.0'
\
'pandas==0.24.0'
\
'scikit-learn==0.18.2'
\
...
...
python-package/lightgbm/callback.py
View file @
4ea170f3
# coding: utf-8
"""Callbacks library."""
import
collections
from
collections
import
OrderedDict
from
dataclasses
import
dataclass
from
functools
import
partial
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Tuple
,
Union
from
typing
import
TYPE_CHECKING
,
Any
,
Callable
,
Dict
,
List
,
Optional
,
Tuple
,
Union
from
.basic
import
_ConfigAliases
,
_LGBM_BoosterEvalMethodResultType
,
_log_info
,
_log_warning
from
.basic
import
Booster
,
_ConfigAliases
,
_LGBM_BoosterEvalMethodResultType
,
_log_info
,
_log_warning
if
TYPE_CHECKING
:
from
.engine
import
CVBooster
__all__
=
[
'early_stopping'
,
...
...
@@ -43,14 +47,14 @@ class EarlyStopException(Exception):
# Callback environment used by callbacks
CallbackEnv
=
collections
.
namedtuple
(
"
CallbackEnv
"
,
[
"
model
"
,
"
params
"
,
"
iteration
"
,
"
begin_iteration
"
,
"
end_iteration
"
,
"
evaluation_result_list
"
])
@
dataclass
class
CallbackEnv
:
model
:
Union
[
Booster
,
"CVBooster"
]
params
:
Dict
[
str
,
Any
]
iteration
:
int
begin_iteration
:
int
end_iteration
:
int
evaluation_result_list
:
Optional
[
List
[
_LGBM_BoosterEvalMethodResultType
]]
def
_format_eval_result
(
value
:
_EvalResultTuple
,
show_stdv
:
bool
)
->
str
:
...
...
@@ -126,7 +130,7 @@ class _RecordEvaluationCallback:
data_name
,
eval_name
=
item
[:
2
]
else
:
# cv
data_name
,
eval_name
=
item
[
1
].
split
()
self
.
eval_result
.
setdefault
(
data_name
,
collections
.
OrderedDict
())
self
.
eval_result
.
setdefault
(
data_name
,
OrderedDict
())
if
len
(
item
)
==
4
:
self
.
eval_result
[
data_name
].
setdefault
(
eval_name
,
[])
else
:
...
...
python-package/lightgbm/engine.py
View file @
4ea170f3
# coding: utf-8
"""Library with training routines of LightGBM."""
import
collections
import
copy
import
json
from
collections
import
OrderedDict
,
defaultdict
from
operator
import
attrgetter
from
pathlib
import
Path
from
typing
import
Any
,
Callable
,
Dict
,
Iterable
,
List
,
Optional
,
Tuple
,
Union
...
...
@@ -293,7 +293,7 @@ def train(
booster
.
best_iteration
=
earlyStopException
.
best_iteration
+
1
evaluation_result_list
=
earlyStopException
.
best_score
break
booster
.
best_score
=
collections
.
defaultdict
(
collections
.
OrderedDict
)
booster
.
best_score
=
defaultdict
(
OrderedDict
)
for
dataset_name
,
eval_name
,
score
,
_
in
evaluation_result_list
:
booster
.
best_score
[
dataset_name
][
eval_name
]
=
score
if
not
keep_training_booster
:
...
...
@@ -526,7 +526,7 @@ def _agg_cv_result(
raw_results
:
List
[
List
[
Tuple
[
str
,
str
,
float
,
bool
]]]
)
->
List
[
Tuple
[
str
,
str
,
float
,
bool
,
float
]]:
"""Aggregate cross-validation results."""
cvmap
:
Dict
[
str
,
List
[
float
]]
=
collections
.
OrderedDict
()
cvmap
:
Dict
[
str
,
List
[
float
]]
=
OrderedDict
()
metric_type
:
Dict
[
str
,
bool
]
=
{}
for
one_result
in
raw_results
:
for
one_line
in
one_result
:
...
...
@@ -717,7 +717,7 @@ def cv(
.
set_feature_name
(
feature_name
)
\
.
set_categorical_feature
(
categorical_feature
)
results
=
collections
.
defaultdict
(
list
)
results
=
defaultdict
(
list
)
cvfolds
=
_make_n_folds
(
full_data
=
train_set
,
folds
=
folds
,
nfold
=
nfold
,
params
=
params
,
seed
=
seed
,
fpreproc
=
fpreproc
,
stratified
=
stratified
,
shuffle
=
shuffle
,
...
...
python-package/pyproject.toml
View file @
4ea170f3
...
...
@@ -18,6 +18,7 @@ classifiers = [
"Topic :: Scientific/Engineering :: Artificial Intelligence"
]
dependencies
=
[
"dataclasses ; python_version < '3.7'"
,
"numpy"
,
"scipy"
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment