Unverified Commit ac37bf8a authored by James Lamb's avatar James Lamb Committed by GitHub
Browse files

[ci] [python-package] fix mypy error about return_cvbooster in cv() (#5845)

parent f74875ed
...@@ -545,7 +545,7 @@ def cv( ...@@ -545,7 +545,7 @@ def cv(
callbacks: Optional[List[Callable]] = None, callbacks: Optional[List[Callable]] = None,
eval_train_metric: bool = False, eval_train_metric: bool = False,
return_cvbooster: bool = False return_cvbooster: bool = False
) -> Dict[str, Any]: ) -> Dict[str, Union[List[float], CVBooster]]:
"""Perform the cross-validation with given parameters. """Perform the cross-validation with given parameters.
Parameters Parameters
...@@ -651,7 +651,7 @@ def cv( ...@@ -651,7 +651,7 @@ def cv(
{'metric1-mean': [values], 'metric1-stdv': [values], {'metric1-mean': [values], 'metric1-stdv': [values],
'metric2-mean': [values], 'metric2-stdv': [values], 'metric2-mean': [values], 'metric2-stdv': [values],
...}. ...}.
If ``return_cvbooster=True``, also returns trained boosters via ``cvbooster`` key. If ``return_cvbooster=True``, also returns trained boosters wrapped in a ``CVBooster`` object via ``cvbooster`` key.
""" """
if not isinstance(train_set, Dataset): if not isinstance(train_set, Dataset):
raise TypeError(f"cv() only accepts Dataset object, train_set has type '{type(train_set).__name__}'.") raise TypeError(f"cv() only accepts Dataset object, train_set has type '{type(train_set).__name__}'.")
...@@ -763,6 +763,6 @@ def cv( ...@@ -763,6 +763,6 @@ def cv(
break break
if return_cvbooster: if return_cvbooster:
results['cvbooster'] = cvfolds results['cvbooster'] = cvfolds # type: ignore[assignment]
return dict(results) return dict(results)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment