Unverified Commit 0cff4f83 authored by James Lamb's avatar James Lamb Committed by GitHub
Browse files

[python-package] add type hints for functions accepting dtypes (#5773)

parent 1585ee15
...@@ -240,7 +240,7 @@ def _is_numpy_column_array(data: Any) -> bool: ...@@ -240,7 +240,7 @@ def _is_numpy_column_array(data: Any) -> bool:
return len(shape) == 2 and shape[1] == 1 return len(shape) == 2 and shape[1] == 1
def _cast_numpy_array_to_dtype(array: np.ndarray, dtype: np.dtype) -> np.ndarray: def _cast_numpy_array_to_dtype(array: np.ndarray, dtype: "np.typing.DTypeLike") -> np.ndarray:
"""Cast numpy array to given dtype.""" """Cast numpy array to given dtype."""
if array.dtype == dtype: if array.dtype == dtype:
return array return array
...@@ -264,7 +264,7 @@ def _is_1d_collection(data: Any) -> bool: ...@@ -264,7 +264,7 @@ def _is_1d_collection(data: Any) -> bool:
def _list_to_1d_numpy( def _list_to_1d_numpy(
data: Any, data: Any,
dtype=np.float32, dtype: "np.typing.DTypeLike" = np.float32,
name: str = 'list' name: str = 'list'
) -> np.ndarray: ) -> np.ndarray:
"""Convert data to numpy 1-D array.""" """Convert data to numpy 1-D array."""
...@@ -303,7 +303,11 @@ def _is_2d_collection(data: Any) -> bool: ...@@ -303,7 +303,11 @@ def _is_2d_collection(data: Any) -> bool:
) )
def _data_to_2d_numpy(data: Any, dtype: type = np.float32, name: str = 'list') -> np.ndarray: def _data_to_2d_numpy(
data: Any,
dtype: "np.typing.DTypeLike" = np.float32,
name: str = 'list'
) -> np.ndarray:
"""Convert data to numpy 2-D array.""" """Convert data to numpy 2-D array."""
if _is_numpy_2d_array(data): if _is_numpy_2d_array(data):
return _cast_numpy_array_to_dtype(data, dtype) return _cast_numpy_array_to_dtype(data, dtype)
...@@ -612,7 +616,7 @@ def _c_int_array(data): ...@@ -612,7 +616,7 @@ def _c_int_array(data):
return (ptr_data, type_data, data) # return `data` to avoid the temporary copy is freed return (ptr_data, type_data, data) # return `data` to avoid the temporary copy is freed
def _is_allowed_numpy_dtype(dtype) -> bool: def _is_allowed_numpy_dtype(dtype: type) -> bool:
float128 = getattr(np, 'float128', type(None)) float128 = getattr(np, 'float128', type(None))
return ( return (
issubclass(dtype, (np.integer, np.floating, np.bool_)) issubclass(dtype, (np.integer, np.floating, np.bool_))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment