"git@developer.sourcefind.cn:tianlh/lightgbm-dcu.git" did not exist on "494fef34eb390db59dafdaa964baac112cb08949"
Unverified Commit 4ea170f3 authored by James Lamb's avatar James Lamb Committed by GitHub
Browse files

[python-package] use dataclass for CallbackEnv (#6048)

parent 5fe84f8f
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
# #
echo "installing lightgbm's dependencies" echo "installing lightgbm's dependencies"
pip install \ pip install \
'dataclasses' \
'numpy==1.12.0' \ 'numpy==1.12.0' \
'pandas==0.24.0' \ 'pandas==0.24.0' \
'scikit-learn==0.18.2' \ 'scikit-learn==0.18.2' \
......
# coding: utf-8 # coding: utf-8
"""Callbacks library.""" """Callbacks library."""
import collections from collections import OrderedDict
from dataclasses import dataclass
from functools import partial from functools import partial
from typing import Any, Callable, Dict, List, Tuple, Union from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
from .basic import _ConfigAliases, _LGBM_BoosterEvalMethodResultType, _log_info, _log_warning from .basic import Booster, _ConfigAliases, _LGBM_BoosterEvalMethodResultType, _log_info, _log_warning
if TYPE_CHECKING:
from .engine import CVBooster
__all__ = [ __all__ = [
'early_stopping', 'early_stopping',
...@@ -43,14 +47,14 @@ class EarlyStopException(Exception): ...@@ -43,14 +47,14 @@ class EarlyStopException(Exception):
# Callback environment used by callbacks
@dataclass
class CallbackEnv:
    """Snapshot of training state handed to each callback once per boosting iteration.

    Replaces the earlier ``collections.namedtuple`` definition with a dataclass,
    which keeps field access by name while adding type annotations.
    """

    # The model being trained: a single Booster, or a CVBooster during cv().
    model: Union[Booster, "CVBooster"]
    # Training parameters in effect for this run.
    params: Dict[str, Any]
    # Zero-based index of the current boosting iteration.
    iteration: int
    # First iteration of this training session (non-zero when continuing training).
    begin_iteration: int
    # Iteration at which training will stop (exclusive upper bound).
    end_iteration: int
    # Per-dataset evaluation results for this iteration, or None when
    # no evaluation was performed.
    evaluation_result_list: Optional[List[_LGBM_BoosterEvalMethodResultType]]
def _format_eval_result(value: _EvalResultTuple, show_stdv: bool) -> str: def _format_eval_result(value: _EvalResultTuple, show_stdv: bool) -> str:
...@@ -126,7 +130,7 @@ class _RecordEvaluationCallback: ...@@ -126,7 +130,7 @@ class _RecordEvaluationCallback:
data_name, eval_name = item[:2] data_name, eval_name = item[:2]
else: # cv else: # cv
data_name, eval_name = item[1].split() data_name, eval_name = item[1].split()
self.eval_result.setdefault(data_name, collections.OrderedDict()) self.eval_result.setdefault(data_name, OrderedDict())
if len(item) == 4: if len(item) == 4:
self.eval_result[data_name].setdefault(eval_name, []) self.eval_result[data_name].setdefault(eval_name, [])
else: else:
......
# coding: utf-8 # coding: utf-8
"""Library with training routines of LightGBM.""" """Library with training routines of LightGBM."""
import collections
import copy import copy
import json import json
from collections import OrderedDict, defaultdict
from operator import attrgetter from operator import attrgetter
from pathlib import Path from pathlib import Path
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
...@@ -293,7 +293,7 @@ def train( ...@@ -293,7 +293,7 @@ def train(
booster.best_iteration = earlyStopException.best_iteration + 1 booster.best_iteration = earlyStopException.best_iteration + 1
evaluation_result_list = earlyStopException.best_score evaluation_result_list = earlyStopException.best_score
break break
booster.best_score = collections.defaultdict(collections.OrderedDict) booster.best_score = defaultdict(OrderedDict)
for dataset_name, eval_name, score, _ in evaluation_result_list: for dataset_name, eval_name, score, _ in evaluation_result_list:
booster.best_score[dataset_name][eval_name] = score booster.best_score[dataset_name][eval_name] = score
if not keep_training_booster: if not keep_training_booster:
...@@ -526,7 +526,7 @@ def _agg_cv_result( ...@@ -526,7 +526,7 @@ def _agg_cv_result(
raw_results: List[List[Tuple[str, str, float, bool]]] raw_results: List[List[Tuple[str, str, float, bool]]]
) -> List[Tuple[str, str, float, bool, float]]: ) -> List[Tuple[str, str, float, bool, float]]:
"""Aggregate cross-validation results.""" """Aggregate cross-validation results."""
cvmap: Dict[str, List[float]] = collections.OrderedDict() cvmap: Dict[str, List[float]] = OrderedDict()
metric_type: Dict[str, bool] = {} metric_type: Dict[str, bool] = {}
for one_result in raw_results: for one_result in raw_results:
for one_line in one_result: for one_line in one_result:
...@@ -717,7 +717,7 @@ def cv( ...@@ -717,7 +717,7 @@ def cv(
.set_feature_name(feature_name) \ .set_feature_name(feature_name) \
.set_categorical_feature(categorical_feature) .set_categorical_feature(categorical_feature)
results = collections.defaultdict(list) results = defaultdict(list)
cvfolds = _make_n_folds(full_data=train_set, folds=folds, nfold=nfold, cvfolds = _make_n_folds(full_data=train_set, folds=folds, nfold=nfold,
params=params, seed=seed, fpreproc=fpreproc, params=params, seed=seed, fpreproc=fpreproc,
stratified=stratified, shuffle=shuffle, stratified=stratified, shuffle=shuffle,
......
...@@ -18,6 +18,7 @@ classifiers = [ ...@@ -18,6 +18,7 @@ classifiers = [
"Topic :: Scientific/Engineering :: Artificial Intelligence" "Topic :: Scientific/Engineering :: Artificial Intelligence"
] ]
dependencies = [ dependencies = [
"dataclasses ; python_version < '3.7'",
"numpy", "numpy",
"scipy" "scipy"
] ]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment