Unverified commit b27d81ea authored by James Lamb, committed by GitHub
Browse files

[ci] [python-package] check for untyped definitions with mypy (#6339)

parent 1a292f89
......@@ -74,6 +74,7 @@ if [[ $TASK == "lint" ]]; then
${CONDA_PYTHON_REQUIREMENT} \
cmakelint \
cpplint \
'matplotlib>=3.8.3' \
mypy \
'pre-commit>=3.6.0' \
'pyarrow>=14.0' \
......
......@@ -13,7 +13,7 @@ from os import SEEK_END, environ
from os.path import getsize
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, Iterator, List, Optional, Set, Tuple, Union
import numpy as np
import scipy.sparse
......@@ -537,13 +537,13 @@ def _param_dict_to_str(data: Optional[Dict[str, Any]]) -> str:
class _TempFile:
"""Proxy class to workaround errors on Windows."""
def __enter__(self):
def __enter__(self) -> "_TempFile":
with NamedTemporaryFile(prefix="lightgbm_tmp_", delete=True) as f:
self.name = f.name
self.path = Path(self.name)
return self
def __exit__(self, exc_type, exc_val, exc_tb):
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
if self.path.is_file():
self.path.unlink()
......@@ -595,7 +595,7 @@ class _ConfigAliases:
)
@classmethod
def get(cls, *args) -> Set[str]:
def get(cls, *args: str) -> Set[str]:
if cls.aliases is None:
cls.aliases = cls._get_all_param_aliases()
ret = set()
......@@ -610,7 +610,7 @@ class _ConfigAliases:
return cls.aliases.get(name, [name])
@classmethod
def get_by_alias(cls, *args) -> Set[str]:
def get_by_alias(cls, *args: str) -> Set[str]:
if cls.aliases is None:
cls.aliases = cls._get_all_param_aliases()
ret = set(args)
......@@ -1563,7 +1563,7 @@ class _InnerPredictor:
start_iteration: int,
num_iteration: int,
predict_type: int,
):
) -> Tuple[Union[List[scipy.sparse.csc_matrix], List[scipy.sparse.csr_matrix]], int]:
ptr_indptr, type_ptr_indptr, __ = _c_int_array(csc.indptr)
ptr_data, type_ptr_data, _ = _c_float_array(csc.data)
csc_indices = csc.indices.astype(np.int32, copy=False)
......@@ -1813,7 +1813,7 @@ class Dataset:
self._need_slice = True
self._predictor: Optional[_InnerPredictor] = None
self.pandas_categorical: Optional[List[List]] = None
self._params_back_up = None
self._params_back_up: Optional[Dict[str, Any]] = None
self.version = 0
self._start_row = 0 # Used when pushing rows one by one.
......@@ -2195,7 +2195,7 @@ class Dataset:
return self.set_feature_name(feature_name)
@staticmethod
def _yield_row_from_seqlist(seqs: List[Sequence], indices: Iterable[int]):
def _yield_row_from_seqlist(seqs: List[Sequence], indices: Iterable[int]) -> Iterator[np.ndarray]:
offset = 0
seq_id = 0
seq = seqs[seq_id]
......@@ -2697,7 +2697,7 @@ class Dataset:
return self
params = deepcopy(params)
def update():
def update() -> None:
if not self.params:
self.params = params
else:
......@@ -3704,7 +3704,7 @@ class Booster:
def __copy__(self) -> "Booster":
return self.__deepcopy__(None)
def __deepcopy__(self, _) -> "Booster":
def __deepcopy__(self, *args: Any, **kwargs: Any) -> "Booster":
model_str = self.model_to_string(num_iteration=-1)
return Booster(model_str=model_str)
......@@ -4757,7 +4757,7 @@ class Booster:
dataset_params: Optional[Dict[str, Any]] = None,
free_raw_data: bool = True,
validate_features: bool = False,
**kwargs,
**kwargs: Any,
) -> "Booster":
"""Refit the existing Booster by new data.
......
# coding: utf-8
"""Compatibility library."""
from typing import List
from typing import Any, List
"""pandas"""
try:
......@@ -20,19 +20,19 @@ except ImportError:
class pd_Series: # type: ignore
"""Dummy class for pandas.Series."""
def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any):
pass
class pd_DataFrame: # type: ignore
"""Dummy class for pandas.DataFrame."""
def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any):
pass
class pd_CategoricalDtype: # type: ignore
"""Dummy class for pandas.CategoricalDtype."""
def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any):
pass
concat = None
......@@ -45,7 +45,7 @@ except ImportError:
class np_random_Generator: # type: ignore
"""Dummy class for np.random.Generator."""
def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any):
pass
......@@ -80,7 +80,7 @@ except ImportError:
class dt_DataTable: # type: ignore
"""Dummy class for datatable.DataTable."""
def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any):
pass
......@@ -104,7 +104,7 @@ try:
from sklearn.utils.validation import check_consistent_length
# dummy function to support older version of scikit-learn
def _check_sample_weight(sample_weight, X, dtype=None):
def _check_sample_weight(sample_weight: Any, X: Any, dtype: Any = None) -> Any:
check_consistent_length(sample_weight, X)
return sample_weight
......@@ -176,31 +176,31 @@ except ImportError:
class Client: # type: ignore
"""Dummy class for dask.distributed.Client."""
def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any):
pass
class Future: # type: ignore
"""Dummy class for dask.distributed.Future."""
def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any):
pass
class dask_Array: # type: ignore
"""Dummy class for dask.array.Array."""
def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any):
pass
class dask_DataFrame: # type: ignore
"""Dummy class for dask.dataframe.DataFrame."""
def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any):
pass
class dask_Series: # type: ignore
"""Dummy class for dask.dataframe.Series."""
def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any):
pass
......@@ -222,19 +222,19 @@ except ImportError:
class pa_Array: # type: ignore
"""Dummy class for pa.Array."""
def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any):
pass
class pa_ChunkedArray: # type: ignore
"""Dummy class for pa.ChunkedArray."""
def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any):
pass
class pa_Table: # type: ignore
"""Dummy class for pa.Table."""
def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any):
pass
class arrow_cffi: # type: ignore
......@@ -245,7 +245,7 @@ except ImportError:
cast = None
new = None
def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any):
pass
class pa_compute: # type: ignore
......
......@@ -3,7 +3,7 @@
import math
from copy import deepcopy
from io import BytesIO
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import numpy as np
......@@ -19,6 +19,9 @@ __all__ = [
"plot_tree",
]
if TYPE_CHECKING:
import matplotlib
def _check_not_tuple_of_2_elements(obj: Any, obj_name: str) -> None:
"""Check object is not tuple or does not have 2 elements."""
......@@ -32,7 +35,7 @@ def _float2str(value: float, precision: Optional[int]) -> str:
def plot_importance(
booster: Union[Booster, LGBMModel],
ax=None,
ax: "Optional[matplotlib.axes.Axes]" = None,
height: float = 0.2,
xlim: Optional[Tuple[float, float]] = None,
ylim: Optional[Tuple[float, float]] = None,
......@@ -168,7 +171,7 @@ def plot_split_value_histogram(
booster: Union[Booster, LGBMModel],
feature: Union[int, str],
bins: Union[int, str, None] = None,
ax=None,
ax: "Optional[matplotlib.axes.Axes]" = None,
width_coef: float = 0.8,
xlim: Optional[Tuple[float, float]] = None,
ylim: Optional[Tuple[float, float]] = None,
......@@ -284,7 +287,7 @@ def plot_metric(
booster: Union[Dict, LGBMModel],
metric: Optional[str] = None,
dataset_names: Optional[List[str]] = None,
ax=None,
ax: "Optional[matplotlib.axes.Axes]" = None,
xlim: Optional[Tuple[float, float]] = None,
ylim: Optional[Tuple[float, float]] = None,
title: Optional[str] = "Metric during training",
......@@ -735,7 +738,7 @@ def create_tree_digraph(
def plot_tree(
booster: Union[Booster, LGBMModel],
ax=None,
ax: "Optional[matplotlib.axes.Axes]" = None,
tree_index: int = 0,
figsize: Optional[Tuple[float, float]] = None,
dpi: Optional[int] = None,
......
......@@ -478,7 +478,7 @@ class LGBMModel(_LGBMModelBase):
random_state: Optional[Union[int, np.random.RandomState, "np.random.Generator"]] = None,
n_jobs: Optional[int] = None,
importance_type: str = "split",
**kwargs,
**kwargs: Any,
):
r"""Construct a gradient boosting model.
......
......@@ -92,6 +92,7 @@ skip_glob = [
]
[tool.mypy]
disallow_untyped_defs = true
exclude = 'build/*|compile/*|docs/*|examples/*|external_libs/*|lightgbm-python/*|tests/*'
ignore_missing_imports = true
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment