Unverified Commit 6de94ef8 authored by James Lamb's avatar James Lamb Committed by GitHub
Browse files

[python-package] add type annotations on some array methods in basic.py (#5813)

parent 216eaff7
...@@ -34,6 +34,10 @@ _ctypes_int_ptr = Union[ ...@@ -34,6 +34,10 @@ _ctypes_int_ptr = Union[
"ctypes._Pointer[ctypes.c_int32]", "ctypes._Pointer[ctypes.c_int32]",
"ctypes._Pointer[ctypes.c_int64]" "ctypes._Pointer[ctypes.c_int64]"
] ]
_ctypes_int_array = Union[
"ctypes.Array[ctypes._Pointer[ctypes.c_int32]]",
"ctypes.Array[ctypes._Pointer[ctypes.c_int64]]"
]
_ctypes_float_ptr = Union[ _ctypes_float_ptr = Union[
"ctypes._Pointer[ctypes.c_float]", "ctypes._Pointer[ctypes.c_float]",
"ctypes._Pointer[ctypes.c_double]" "ctypes._Pointer[ctypes.c_double]"
...@@ -589,13 +593,16 @@ def _convert_from_sliced_object(data: np.ndarray) -> np.ndarray: ...@@ -589,13 +593,16 @@ def _convert_from_sliced_object(data: np.ndarray) -> np.ndarray:
return data return data
def _c_float_array(data): def _c_float_array(
data: np.ndarray
) -> Tuple[_ctypes_float_ptr, int, np.ndarray]:
"""Get pointer of float numpy array / list.""" """Get pointer of float numpy array / list."""
if _is_1d_list(data): if _is_1d_list(data):
data = np.array(data, copy=False) data = np.array(data, copy=False)
if _is_numpy_1d_array(data): if _is_numpy_1d_array(data):
data = _convert_from_sliced_object(data) data = _convert_from_sliced_object(data)
assert data.flags.c_contiguous assert data.flags.c_contiguous
ptr_data: _ctypes_float_ptr
if data.dtype == np.float32: if data.dtype == np.float32:
ptr_data = data.ctypes.data_as(ctypes.POINTER(ctypes.c_float)) ptr_data = data.ctypes.data_as(ctypes.POINTER(ctypes.c_float))
type_data = _C_API_DTYPE_FLOAT32 type_data = _C_API_DTYPE_FLOAT32
...@@ -609,13 +616,16 @@ def _c_float_array(data): ...@@ -609,13 +616,16 @@ def _c_float_array(data):
return (ptr_data, type_data, data) # return `data` to avoid the temporary copy is freed return (ptr_data, type_data, data) # return `data` to avoid the temporary copy is freed
def _c_int_array(data): def _c_int_array(
data: np.ndarray
) -> Tuple[_ctypes_int_ptr, int, np.ndarray]:
"""Get pointer of int numpy array / list.""" """Get pointer of int numpy array / list."""
if _is_1d_list(data): if _is_1d_list(data):
data = np.array(data, copy=False) data = np.array(data, copy=False)
if _is_numpy_1d_array(data): if _is_numpy_1d_array(data):
data = _convert_from_sliced_object(data) data = _convert_from_sliced_object(data)
assert data.flags.c_contiguous assert data.flags.c_contiguous
ptr_data: _ctypes_int_ptr
if data.dtype == np.int32: if data.dtype == np.int32:
ptr_data = data.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)) ptr_data = data.ctypes.data_as(ctypes.POINTER(ctypes.c_int32))
type_data = _C_API_DTYPE_INT32 type_data = _C_API_DTYPE_INT32
...@@ -1624,10 +1634,10 @@ class Dataset: ...@@ -1624,10 +1634,10 @@ class Dataset:
# c type: double** # c type: double**
# each double* element points to start of each column of sample data. # each double* element points to start of each column of sample data.
sample_col_ptr = (ctypes.POINTER(ctypes.c_double) * ncol)() sample_col_ptr: _ctypes_float_array = (ctypes.POINTER(ctypes.c_double) * ncol)()
# c type int** # c type int**
# each int* points to start of indices for each column # each int* points to start of indices for each column
indices_col_ptr = (ctypes.POINTER(ctypes.c_int32) * ncol)() indices_col_ptr: _ctypes_int_array = (ctypes.POINTER(ctypes.c_int32) * ncol)()
for i in range(ncol): for i in range(ncol):
sample_col_ptr[i] = _c_float_array(sample_data[i])[0] sample_col_ptr[i] = _c_float_array(sample_data[i])[0]
indices_col_ptr[i] = _c_int_array(sample_indices[i])[0] indices_col_ptr[i] = _c_int_array(sample_indices[i])[0]
...@@ -2374,6 +2384,7 @@ class Dataset: ...@@ -2374,6 +2384,7 @@ class Dataset:
dtype = np.int32 if field_name == 'group' else np.float32 dtype = np.int32 if field_name == 'group' else np.float32
data = _list_to_1d_numpy(data, dtype, name=field_name) data = _list_to_1d_numpy(data, dtype, name=field_name)
ptr_data: Union[_ctypes_float_ptr, _ctypes_int_ptr]
if data.dtype == np.float32 or data.dtype == np.float64: if data.dtype == np.float32 or data.dtype == np.float64:
ptr_data, type_data, _ = _c_float_array(data) ptr_data, type_data, _ = _c_float_array(data)
elif data.dtype == np.int32: elif data.dtype == np.int32:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment