Unverified Commit 4ea170f3 authored by James Lamb's avatar James Lamb Committed by GitHub
Browse files

[python-package] use dataclass for CallbackEnv (#6048)

parent 5fe84f8f
......@@ -7,6 +7,7 @@
#
echo "installing lightgbm's dependencies"
pip install \
'dataclasses' \
'numpy==1.12.0' \
'pandas==0.24.0' \
'scikit-learn==0.18.2' \
......
# coding: utf-8
"""Callbacks library."""
import collections
from collections import OrderedDict
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, Dict, List, Tuple, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
from .basic import _ConfigAliases, _LGBM_BoosterEvalMethodResultType, _log_info, _log_warning
from .basic import Booster, _ConfigAliases, _LGBM_BoosterEvalMethodResultType, _log_info, _log_warning
if TYPE_CHECKING:
from .engine import CVBooster
__all__ = [
'early_stopping',
......@@ -43,14 +47,14 @@ class EarlyStopException(Exception):
# Callback environment used by callbacks
CallbackEnv = collections.namedtuple(
"CallbackEnv",
["model",
"params",
"iteration",
"begin_iteration",
"end_iteration",
"evaluation_result_list"])
@dataclass
class CallbackEnv:
model: Union[Booster, "CVBooster"]
params: Dict[str, Any]
iteration: int
begin_iteration: int
end_iteration: int
evaluation_result_list: Optional[List[_LGBM_BoosterEvalMethodResultType]]
def _format_eval_result(value: _EvalResultTuple, show_stdv: bool) -> str:
......@@ -126,7 +130,7 @@ class _RecordEvaluationCallback:
data_name, eval_name = item[:2]
else: # cv
data_name, eval_name = item[1].split()
self.eval_result.setdefault(data_name, collections.OrderedDict())
self.eval_result.setdefault(data_name, OrderedDict())
if len(item) == 4:
self.eval_result[data_name].setdefault(eval_name, [])
else:
......
# coding: utf-8
"""Library with training routines of LightGBM."""
import collections
import copy
import json
from collections import OrderedDict, defaultdict
from operator import attrgetter
from pathlib import Path
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
......@@ -293,7 +293,7 @@ def train(
booster.best_iteration = earlyStopException.best_iteration + 1
evaluation_result_list = earlyStopException.best_score
break
booster.best_score = collections.defaultdict(collections.OrderedDict)
booster.best_score = defaultdict(OrderedDict)
for dataset_name, eval_name, score, _ in evaluation_result_list:
booster.best_score[dataset_name][eval_name] = score
if not keep_training_booster:
......@@ -526,7 +526,7 @@ def _agg_cv_result(
raw_results: List[List[Tuple[str, str, float, bool]]]
) -> List[Tuple[str, str, float, bool, float]]:
"""Aggregate cross-validation results."""
cvmap: Dict[str, List[float]] = collections.OrderedDict()
cvmap: Dict[str, List[float]] = OrderedDict()
metric_type: Dict[str, bool] = {}
for one_result in raw_results:
for one_line in one_result:
......@@ -717,7 +717,7 @@ def cv(
.set_feature_name(feature_name) \
.set_categorical_feature(categorical_feature)
results = collections.defaultdict(list)
results = defaultdict(list)
cvfolds = _make_n_folds(full_data=train_set, folds=folds, nfold=nfold,
params=params, seed=seed, fpreproc=fpreproc,
stratified=stratified, shuffle=shuffle,
......
......@@ -18,6 +18,7 @@ classifiers = [
"Topic :: Scientific/Engineering :: Artificial Intelligence"
]
dependencies = [
"dataclasses ; python_version < '3.7'",
"numpy",
"scipy"
]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment