"tests/git@developer.sourcefind.cn:tianlh/lightgbm-dcu.git" did not exist on "0e576575852fa543bea00056a2801b247edc8283"
Unverified commit 7c9a985a, authored by James Lamb, committed by GitHub
Browse files

[python-package] fix mypy errors in Dataset construction (#6106)

parent fe7f8fe6
...@@ -24,6 +24,13 @@ from .libpath import find_lib_path ...@@ -24,6 +24,13 @@ from .libpath import find_lib_path
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Literal from typing import Literal
# typing.TypeGuard was only introduced in Python 3.10
try:
from typing import TypeGuard
except ImportError:
from typing_extensions import TypeGuard
__all__ = [ __all__ = [
'Booster', 'Booster',
'Dataset', 'Dataset',
...@@ -279,6 +286,20 @@ def _is_1d_list(data: Any) -> bool: ...@@ -279,6 +286,20 @@ def _is_1d_list(data: Any) -> bool:
return isinstance(data, list) and (not data or _is_numeric(data[0])) return isinstance(data, list) and (not data or _is_numeric(data[0]))
def _is_list_of_numpy_arrays(data: Any) -> "TypeGuard[List[np.ndarray]]":
    """Check whether data is a list made up entirely of numpy arrays.

    Parameters
    ----------
    data : Any
        Object to inspect.

    Returns
    -------
    bool
        ``True`` only when ``data`` is a ``list`` and every element is an
        ``np.ndarray``. An empty list also yields ``True``, because ``all()``
        over an empty iterable is ``True``; callers that require a non-empty
        list must check the length themselves.

    The ``TypeGuard`` return annotation lets static type checkers narrow
    ``data`` to ``List[np.ndarray]`` in the branch where this returns ``True``.
    """
    return (
        isinstance(data, list)
        and all(isinstance(x, np.ndarray) for x in data)
    )
def _is_list_of_sequences(data: Any) -> "TypeGuard[List[Sequence]]":
    """Check whether data is a list made up entirely of ``Sequence`` objects.

    Parameters
    ----------
    data : Any
        Object to inspect.

    Returns
    -------
    bool
        ``True`` only when ``data`` is a ``list`` and every element is an
        instance of ``Sequence``. An empty list also yields ``True``, because
        ``all()`` over an empty iterable is ``True``; callers that require a
        non-empty list must check the length themselves.

    The ``TypeGuard`` return annotation lets static type checkers narrow
    ``data`` to ``List[Sequence]`` in the branch where this returns ``True``.
    """
    # NOTE(review): ``Sequence`` is whichever name this module has in scope;
    # presumably the package's own Sequence class rather than
    # ``collections.abc.Sequence`` - confirm against the module's definitions.
    return (
        isinstance(data, list)
        and all(isinstance(x, Sequence) for x in data)
    )
def _is_1d_collection(data: Any) -> bool: def _is_1d_collection(data: Any) -> bool:
"""Check whether data is a 1-D collection.""" """Check whether data is a 1-D collection."""
return ( return (
...@@ -1918,9 +1939,9 @@ class Dataset: ...@@ -1918,9 +1939,9 @@ class Dataset:
elif isinstance(data, np.ndarray): elif isinstance(data, np.ndarray):
self.__init_from_np2d(data, params_str, ref_dataset) self.__init_from_np2d(data, params_str, ref_dataset)
elif isinstance(data, list) and len(data) > 0: elif isinstance(data, list) and len(data) > 0:
if all(isinstance(x, np.ndarray) for x in data): if _is_list_of_numpy_arrays(data):
self.__init_from_list_np2d(data, params_str, ref_dataset) self.__init_from_list_np2d(data, params_str, ref_dataset)
elif all(isinstance(x, Sequence) for x in data): elif _is_list_of_sequences(data):
self.__init_from_seqs(data, ref_dataset) self.__init_from_seqs(data, ref_dataset)
else: else:
raise TypeError('Data list can only be of ndarray or Sequence') raise TypeError('Data list can only be of ndarray or Sequence')
...@@ -2870,7 +2891,7 @@ class Dataset: ...@@ -2870,7 +2891,7 @@ class Dataset:
self.data = self.data[self.used_indices, :] self.data = self.data[self.used_indices, :]
elif isinstance(self.data, Sequence): elif isinstance(self.data, Sequence):
self.data = self.data[self.used_indices] self.data = self.data[self.used_indices]
elif isinstance(self.data, list) and len(self.data) > 0 and all(isinstance(x, Sequence) for x in self.data): elif _is_list_of_sequences(self.data) and len(self.data) > 0:
self.data = np.array(list(self._yield_row_from_seqlist(self.data, self.used_indices))) self.data = np.array(list(self._yield_row_from_seqlist(self.data, self.used_indices)))
else: else:
_log_warning(f"Cannot subset {type(self.data).__name__} type of raw data.\n" _log_warning(f"Cannot subset {type(self.data).__name__} type of raw data.\n"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment