Unverified Commit b27d81ea authored by James Lamb's avatar James Lamb Committed by GitHub
Browse files

[ci] [python-package] check for untyped definitions with mypy (#6339)

parent 1a292f89
...@@ -74,6 +74,7 @@ if [[ $TASK == "lint" ]]; then ...@@ -74,6 +74,7 @@ if [[ $TASK == "lint" ]]; then
${CONDA_PYTHON_REQUIREMENT} \ ${CONDA_PYTHON_REQUIREMENT} \
cmakelint \ cmakelint \
cpplint \ cpplint \
'matplotlib>=3.8.3' \
mypy \ mypy \
'pre-commit>=3.6.0' \ 'pre-commit>=3.6.0' \
'pyarrow>=14.0' \ 'pyarrow>=14.0' \
......
...@@ -13,7 +13,7 @@ from os import SEEK_END, environ ...@@ -13,7 +13,7 @@ from os import SEEK_END, environ
from os.path import getsize from os.path import getsize
from pathlib import Path from pathlib import Path
from tempfile import NamedTemporaryFile from tempfile import NamedTemporaryFile
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Union from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, Iterator, List, Optional, Set, Tuple, Union
import numpy as np import numpy as np
import scipy.sparse import scipy.sparse
...@@ -537,13 +537,13 @@ def _param_dict_to_str(data: Optional[Dict[str, Any]]) -> str: ...@@ -537,13 +537,13 @@ def _param_dict_to_str(data: Optional[Dict[str, Any]]) -> str:
class _TempFile: class _TempFile:
"""Proxy class to workaround errors on Windows.""" """Proxy class to workaround errors on Windows."""
def __enter__(self): def __enter__(self) -> "_TempFile":
with NamedTemporaryFile(prefix="lightgbm_tmp_", delete=True) as f: with NamedTemporaryFile(prefix="lightgbm_tmp_", delete=True) as f:
self.name = f.name self.name = f.name
self.path = Path(self.name) self.path = Path(self.name)
return self return self
def __exit__(self, exc_type, exc_val, exc_tb): def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
if self.path.is_file(): if self.path.is_file():
self.path.unlink() self.path.unlink()
...@@ -595,7 +595,7 @@ class _ConfigAliases: ...@@ -595,7 +595,7 @@ class _ConfigAliases:
) )
@classmethod @classmethod
def get(cls, *args) -> Set[str]: def get(cls, *args: str) -> Set[str]:
if cls.aliases is None: if cls.aliases is None:
cls.aliases = cls._get_all_param_aliases() cls.aliases = cls._get_all_param_aliases()
ret = set() ret = set()
...@@ -610,7 +610,7 @@ class _ConfigAliases: ...@@ -610,7 +610,7 @@ class _ConfigAliases:
return cls.aliases.get(name, [name]) return cls.aliases.get(name, [name])
@classmethod @classmethod
def get_by_alias(cls, *args) -> Set[str]: def get_by_alias(cls, *args: str) -> Set[str]:
if cls.aliases is None: if cls.aliases is None:
cls.aliases = cls._get_all_param_aliases() cls.aliases = cls._get_all_param_aliases()
ret = set(args) ret = set(args)
...@@ -1563,7 +1563,7 @@ class _InnerPredictor: ...@@ -1563,7 +1563,7 @@ class _InnerPredictor:
start_iteration: int, start_iteration: int,
num_iteration: int, num_iteration: int,
predict_type: int, predict_type: int,
): ) -> Tuple[Union[List[scipy.sparse.csc_matrix], List[scipy.sparse.csr_matrix]], int]:
ptr_indptr, type_ptr_indptr, __ = _c_int_array(csc.indptr) ptr_indptr, type_ptr_indptr, __ = _c_int_array(csc.indptr)
ptr_data, type_ptr_data, _ = _c_float_array(csc.data) ptr_data, type_ptr_data, _ = _c_float_array(csc.data)
csc_indices = csc.indices.astype(np.int32, copy=False) csc_indices = csc.indices.astype(np.int32, copy=False)
...@@ -1813,7 +1813,7 @@ class Dataset: ...@@ -1813,7 +1813,7 @@ class Dataset:
self._need_slice = True self._need_slice = True
self._predictor: Optional[_InnerPredictor] = None self._predictor: Optional[_InnerPredictor] = None
self.pandas_categorical: Optional[List[List]] = None self.pandas_categorical: Optional[List[List]] = None
self._params_back_up = None self._params_back_up: Optional[Dict[str, Any]] = None
self.version = 0 self.version = 0
self._start_row = 0 # Used when pushing rows one by one. self._start_row = 0 # Used when pushing rows one by one.
...@@ -2195,7 +2195,7 @@ class Dataset: ...@@ -2195,7 +2195,7 @@ class Dataset:
return self.set_feature_name(feature_name) return self.set_feature_name(feature_name)
@staticmethod @staticmethod
def _yield_row_from_seqlist(seqs: List[Sequence], indices: Iterable[int]): def _yield_row_from_seqlist(seqs: List[Sequence], indices: Iterable[int]) -> Iterator[np.ndarray]:
offset = 0 offset = 0
seq_id = 0 seq_id = 0
seq = seqs[seq_id] seq = seqs[seq_id]
...@@ -2697,7 +2697,7 @@ class Dataset: ...@@ -2697,7 +2697,7 @@ class Dataset:
return self return self
params = deepcopy(params) params = deepcopy(params)
def update(): def update() -> None:
if not self.params: if not self.params:
self.params = params self.params = params
else: else:
...@@ -3704,7 +3704,7 @@ class Booster: ...@@ -3704,7 +3704,7 @@ class Booster:
def __copy__(self) -> "Booster": def __copy__(self) -> "Booster":
return self.__deepcopy__(None) return self.__deepcopy__(None)
def __deepcopy__(self, _) -> "Booster": def __deepcopy__(self, *args: Any, **kwargs: Any) -> "Booster":
model_str = self.model_to_string(num_iteration=-1) model_str = self.model_to_string(num_iteration=-1)
return Booster(model_str=model_str) return Booster(model_str=model_str)
...@@ -4757,7 +4757,7 @@ class Booster: ...@@ -4757,7 +4757,7 @@ class Booster:
dataset_params: Optional[Dict[str, Any]] = None, dataset_params: Optional[Dict[str, Any]] = None,
free_raw_data: bool = True, free_raw_data: bool = True,
validate_features: bool = False, validate_features: bool = False,
**kwargs, **kwargs: Any,
) -> "Booster": ) -> "Booster":
"""Refit the existing Booster by new data. """Refit the existing Booster by new data.
......
# coding: utf-8 # coding: utf-8
"""Compatibility library.""" """Compatibility library."""
from typing import List from typing import Any, List
"""pandas""" """pandas"""
try: try:
...@@ -20,19 +20,19 @@ except ImportError: ...@@ -20,19 +20,19 @@ except ImportError:
class pd_Series: # type: ignore class pd_Series: # type: ignore
"""Dummy class for pandas.Series.""" """Dummy class for pandas.Series."""
def __init__(self, *args, **kwargs): def __init__(self, *args: Any, **kwargs: Any):
pass pass
class pd_DataFrame: # type: ignore class pd_DataFrame: # type: ignore
"""Dummy class for pandas.DataFrame.""" """Dummy class for pandas.DataFrame."""
def __init__(self, *args, **kwargs): def __init__(self, *args: Any, **kwargs: Any):
pass pass
class pd_CategoricalDtype: # type: ignore class pd_CategoricalDtype: # type: ignore
"""Dummy class for pandas.CategoricalDtype.""" """Dummy class for pandas.CategoricalDtype."""
def __init__(self, *args, **kwargs): def __init__(self, *args: Any, **kwargs: Any):
pass pass
concat = None concat = None
...@@ -45,7 +45,7 @@ except ImportError: ...@@ -45,7 +45,7 @@ except ImportError:
class np_random_Generator: # type: ignore class np_random_Generator: # type: ignore
"""Dummy class for np.random.Generator.""" """Dummy class for np.random.Generator."""
def __init__(self, *args, **kwargs): def __init__(self, *args: Any, **kwargs: Any):
pass pass
...@@ -80,7 +80,7 @@ except ImportError: ...@@ -80,7 +80,7 @@ except ImportError:
class dt_DataTable: # type: ignore class dt_DataTable: # type: ignore
"""Dummy class for datatable.DataTable.""" """Dummy class for datatable.DataTable."""
def __init__(self, *args, **kwargs): def __init__(self, *args: Any, **kwargs: Any):
pass pass
...@@ -104,7 +104,7 @@ try: ...@@ -104,7 +104,7 @@ try:
from sklearn.utils.validation import check_consistent_length from sklearn.utils.validation import check_consistent_length
# dummy function to support older version of scikit-learn # dummy function to support older version of scikit-learn
def _check_sample_weight(sample_weight, X, dtype=None): def _check_sample_weight(sample_weight: Any, X: Any, dtype: Any = None) -> Any:
check_consistent_length(sample_weight, X) check_consistent_length(sample_weight, X)
return sample_weight return sample_weight
...@@ -176,31 +176,31 @@ except ImportError: ...@@ -176,31 +176,31 @@ except ImportError:
class Client: # type: ignore class Client: # type: ignore
"""Dummy class for dask.distributed.Client.""" """Dummy class for dask.distributed.Client."""
def __init__(self, *args, **kwargs): def __init__(self, *args: Any, **kwargs: Any):
pass pass
class Future: # type: ignore class Future: # type: ignore
"""Dummy class for dask.distributed.Future.""" """Dummy class for dask.distributed.Future."""
def __init__(self, *args, **kwargs): def __init__(self, *args: Any, **kwargs: Any):
pass pass
class dask_Array: # type: ignore class dask_Array: # type: ignore
"""Dummy class for dask.array.Array.""" """Dummy class for dask.array.Array."""
def __init__(self, *args, **kwargs): def __init__(self, *args: Any, **kwargs: Any):
pass pass
class dask_DataFrame: # type: ignore class dask_DataFrame: # type: ignore
"""Dummy class for dask.dataframe.DataFrame.""" """Dummy class for dask.dataframe.DataFrame."""
def __init__(self, *args, **kwargs): def __init__(self, *args: Any, **kwargs: Any):
pass pass
class dask_Series: # type: ignore class dask_Series: # type: ignore
"""Dummy class for dask.dataframe.Series.""" """Dummy class for dask.dataframe.Series."""
def __init__(self, *args, **kwargs): def __init__(self, *args: Any, **kwargs: Any):
pass pass
...@@ -222,19 +222,19 @@ except ImportError: ...@@ -222,19 +222,19 @@ except ImportError:
class pa_Array: # type: ignore class pa_Array: # type: ignore
"""Dummy class for pa.Array.""" """Dummy class for pa.Array."""
def __init__(self, *args, **kwargs): def __init__(self, *args: Any, **kwargs: Any):
pass pass
class pa_ChunkedArray: # type: ignore class pa_ChunkedArray: # type: ignore
"""Dummy class for pa.ChunkedArray.""" """Dummy class for pa.ChunkedArray."""
def __init__(self, *args, **kwargs): def __init__(self, *args: Any, **kwargs: Any):
pass pass
class pa_Table: # type: ignore class pa_Table: # type: ignore
"""Dummy class for pa.Table.""" """Dummy class for pa.Table."""
def __init__(self, *args, **kwargs): def __init__(self, *args: Any, **kwargs: Any):
pass pass
class arrow_cffi: # type: ignore class arrow_cffi: # type: ignore
...@@ -245,7 +245,7 @@ except ImportError: ...@@ -245,7 +245,7 @@ except ImportError:
cast = None cast = None
new = None new = None
def __init__(self, *args, **kwargs): def __init__(self, *args: Any, **kwargs: Any):
pass pass
class pa_compute: # type: ignore class pa_compute: # type: ignore
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
import math import math
from copy import deepcopy from copy import deepcopy
from io import BytesIO from io import BytesIO
from typing import Any, Dict, List, Optional, Tuple, Union from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import numpy as np import numpy as np
...@@ -19,6 +19,9 @@ __all__ = [ ...@@ -19,6 +19,9 @@ __all__ = [
"plot_tree", "plot_tree",
] ]
if TYPE_CHECKING:
import matplotlib
def _check_not_tuple_of_2_elements(obj: Any, obj_name: str) -> None: def _check_not_tuple_of_2_elements(obj: Any, obj_name: str) -> None:
"""Check object is not tuple or does not have 2 elements.""" """Check object is not tuple or does not have 2 elements."""
...@@ -32,7 +35,7 @@ def _float2str(value: float, precision: Optional[int]) -> str: ...@@ -32,7 +35,7 @@ def _float2str(value: float, precision: Optional[int]) -> str:
def plot_importance( def plot_importance(
booster: Union[Booster, LGBMModel], booster: Union[Booster, LGBMModel],
ax=None, ax: "Optional[matplotlib.axes.Axes]" = None,
height: float = 0.2, height: float = 0.2,
xlim: Optional[Tuple[float, float]] = None, xlim: Optional[Tuple[float, float]] = None,
ylim: Optional[Tuple[float, float]] = None, ylim: Optional[Tuple[float, float]] = None,
...@@ -168,7 +171,7 @@ def plot_split_value_histogram( ...@@ -168,7 +171,7 @@ def plot_split_value_histogram(
booster: Union[Booster, LGBMModel], booster: Union[Booster, LGBMModel],
feature: Union[int, str], feature: Union[int, str],
bins: Union[int, str, None] = None, bins: Union[int, str, None] = None,
ax=None, ax: "Optional[matplotlib.axes.Axes]" = None,
width_coef: float = 0.8, width_coef: float = 0.8,
xlim: Optional[Tuple[float, float]] = None, xlim: Optional[Tuple[float, float]] = None,
ylim: Optional[Tuple[float, float]] = None, ylim: Optional[Tuple[float, float]] = None,
...@@ -284,7 +287,7 @@ def plot_metric( ...@@ -284,7 +287,7 @@ def plot_metric(
booster: Union[Dict, LGBMModel], booster: Union[Dict, LGBMModel],
metric: Optional[str] = None, metric: Optional[str] = None,
dataset_names: Optional[List[str]] = None, dataset_names: Optional[List[str]] = None,
ax=None, ax: "Optional[matplotlib.axes.Axes]" = None,
xlim: Optional[Tuple[float, float]] = None, xlim: Optional[Tuple[float, float]] = None,
ylim: Optional[Tuple[float, float]] = None, ylim: Optional[Tuple[float, float]] = None,
title: Optional[str] = "Metric during training", title: Optional[str] = "Metric during training",
...@@ -735,7 +738,7 @@ def create_tree_digraph( ...@@ -735,7 +738,7 @@ def create_tree_digraph(
def plot_tree( def plot_tree(
booster: Union[Booster, LGBMModel], booster: Union[Booster, LGBMModel],
ax=None, ax: "Optional[matplotlib.axes.Axes]" = None,
tree_index: int = 0, tree_index: int = 0,
figsize: Optional[Tuple[float, float]] = None, figsize: Optional[Tuple[float, float]] = None,
dpi: Optional[int] = None, dpi: Optional[int] = None,
......
...@@ -478,7 +478,7 @@ class LGBMModel(_LGBMModelBase): ...@@ -478,7 +478,7 @@ class LGBMModel(_LGBMModelBase):
random_state: Optional[Union[int, np.random.RandomState, "np.random.Generator"]] = None, random_state: Optional[Union[int, np.random.RandomState, "np.random.Generator"]] = None,
n_jobs: Optional[int] = None, n_jobs: Optional[int] = None,
importance_type: str = "split", importance_type: str = "split",
**kwargs, **kwargs: Any,
): ):
r"""Construct a gradient boosting model. r"""Construct a gradient boosting model.
......
...@@ -92,6 +92,7 @@ skip_glob = [ ...@@ -92,6 +92,7 @@ skip_glob = [
] ]
[tool.mypy] [tool.mypy]
disallow_untyped_defs = true
exclude = 'build/*|compile/*|docs/*|examples/*|external_libs/*|lightgbm-python/*|tests/*' exclude = 'build/*|compile/*|docs/*|examples/*|external_libs/*|lightgbm-python/*|tests/*'
ignore_missing_imports = true ignore_missing_imports = true
......
Markdown is supported
0% — or attach a file by drag and drop or click to upload.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment