"tests/vscode:/vscode.git/clone" did not exist on "ddda85b06061fe11ce42e1f636c101b651a7ef19"
Unverified Commit 13fa6d95 authored by James Lamb's avatar James Lamb Committed by GitHub
Browse files

[python] add type hints on train() in engine.py (#4544)



* [python] add type hints on train() in engine.py

* revert dask.py and sklearn.py changes

* Apply suggestions from code review
Co-authored-by: default avatarNikita Titov <nekit94-08@mail.ru>

* update docs on evals_result contents

* Update python-package/lightgbm/engine.py
Co-authored-by: default avatarNikita Titov <nekit94-08@mail.ru>
parent f5925c3f
......@@ -4,6 +4,7 @@ import collections
import copy
from operator import attrgetter
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
......@@ -11,14 +12,34 @@ from . import callback
from .basic import Booster, Dataset, LightGBMError, _ConfigAliases, _InnerPredictor, _log_warning
from .compat import SKLEARN_INSTALLED, _LGBMGroupKFold, _LGBMStratifiedKFold
def train(params, train_set, num_boost_round=100,
valid_sets=None, valid_names=None,
fobj=None, feval=None, init_model=None,
feature_name='auto', categorical_feature='auto',
early_stopping_rounds=None, evals_result=None,
verbose_eval=True, learning_rates=None,
keep_training_booster=False, callbacks=None):
_LGBM_CustomObjectiveFunction = Callable[
[Union[List, np.ndarray], Dataset],
Tuple[Union[List, np.ndarray], Union[List, np.ndarray]]
]
_LGBM_CustomMetricFunction = Callable[
[Union[List, np.ndarray], Dataset],
Tuple[str, float, bool]
]
def train(
params: Dict[str, Any],
train_set: Dataset,
num_boost_round: int = 100,
valid_sets: Optional[List[Dataset]] = None,
valid_names: Optional[List[str]] = None,
fobj: Optional[_LGBM_CustomObjectiveFunction] = None,
feval: Optional[Union[_LGBM_CustomMetricFunction, List[_LGBM_CustomMetricFunction]]] = None,
init_model: Optional[Union[str, Path, Booster]] = None,
feature_name: Union[List[str], str] = 'auto',
categorical_feature: Union[List[str], List[int], str] = 'auto',
early_stopping_rounds: Optional[int] = None,
evals_result: Optional[Dict[str, Any]] = None,
verbose_eval: Union[bool, int] = True,
learning_rates: Optional[Union[List[float], Callable[[int], float]]] = None,
keep_training_booster: bool = False,
callbacks: Optional[List[Callable]] = None
) -> Booster:
"""Perform the training with given parameters.
Parameters
......@@ -101,7 +122,9 @@ def train(params, train_set, num_boost_round=100,
The index of iteration that has the best performance will be saved in the ``best_iteration`` field
if early stopping logic is enabled by setting ``early_stopping_rounds``.
evals_result: dict or None, optional (default=None)
This dictionary used to store all evaluation results of all the items in ``valid_sets``.
Dictionary used to store all evaluation results of all the items in ``valid_sets``.
This should be initialized outside of your call to ``train()`` and should be empty.
Any initial contents of the dictionary will be deleted by ``train()``.
.. rubric:: Example
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment