Unverified Commit 38554095 authored by James Lamb's avatar James Lamb Committed by GitHub
Browse files

[python-package] fix mypy errors about ctypes pointers (#5779)

parent 125ca033
...@@ -30,6 +30,18 @@ __all__ = [ ...@@ -30,6 +30,18 @@ __all__ = [
] ]
_DatasetHandle = ctypes.c_void_p _DatasetHandle = ctypes.c_void_p
_ctypes_int_ptr = Union[
"ctypes._Pointer[ctypes.c_int32]",
"ctypes._Pointer[ctypes.c_int64]"
]
_ctypes_float_ptr = Union[
"ctypes._Pointer[ctypes.c_float]",
"ctypes._Pointer[ctypes.c_double]"
]
_ctypes_float_array = Union[
"ctypes.Array[ctypes._Pointer[ctypes.c_float]]",
"ctypes.Array[ctypes._Pointer[ctypes.c_double]]"
]
_LGBM_EvalFunctionResultType = Tuple[str, float, bool] _LGBM_EvalFunctionResultType = Tuple[str, float, bool]
_LGBM_BoosterBestScoreType = Dict[str, Dict[str, float]] _LGBM_BoosterBestScoreType = Dict[str, Dict[str, float]]
_LGBM_BoosterEvalMethodResultType = Tuple[str, str, float, bool] _LGBM_BoosterEvalMethodResultType = Tuple[str, str, float, bool]
...@@ -1222,11 +1234,13 @@ class _InnerPredictor: ...@@ -1222,11 +1234,13 @@ class _InnerPredictor:
ptr_data, type_ptr_data, _ = _c_float_array(csr.data) ptr_data, type_ptr_data, _ = _c_float_array(csr.data)
csr_indices = csr.indices.astype(np.int32, copy=False) csr_indices = csr.indices.astype(np.int32, copy=False)
matrix_type = _C_API_MATRIX_TYPE_CSR matrix_type = _C_API_MATRIX_TYPE_CSR
out_ptr_indptr: _ctypes_int_ptr
if type_ptr_indptr == _C_API_DTYPE_INT32: if type_ptr_indptr == _C_API_DTYPE_INT32:
out_ptr_indptr = ctypes.POINTER(ctypes.c_int32)() out_ptr_indptr = ctypes.POINTER(ctypes.c_int32)()
else: else:
out_ptr_indptr = ctypes.POINTER(ctypes.c_int64)() out_ptr_indptr = ctypes.POINTER(ctypes.c_int64)()
out_ptr_indices = ctypes.POINTER(ctypes.c_int32)() out_ptr_indices = ctypes.POINTER(ctypes.c_int32)()
out_ptr_data: _ctypes_float_ptr
if type_ptr_data == _C_API_DTYPE_FLOAT32: if type_ptr_data == _C_API_DTYPE_FLOAT32:
out_ptr_data = ctypes.POINTER(ctypes.c_float)() out_ptr_data = ctypes.POINTER(ctypes.c_float)()
else: else:
...@@ -1317,11 +1331,13 @@ class _InnerPredictor: ...@@ -1317,11 +1331,13 @@ class _InnerPredictor:
ptr_data, type_ptr_data, _ = _c_float_array(csc.data) ptr_data, type_ptr_data, _ = _c_float_array(csc.data)
csc_indices = csc.indices.astype(np.int32, copy=False) csc_indices = csc.indices.astype(np.int32, copy=False)
matrix_type = _C_API_MATRIX_TYPE_CSC matrix_type = _C_API_MATRIX_TYPE_CSC
out_ptr_indptr: _ctypes_int_ptr
if type_ptr_indptr == _C_API_DTYPE_INT32: if type_ptr_indptr == _C_API_DTYPE_INT32:
out_ptr_indptr = ctypes.POINTER(ctypes.c_int32)() out_ptr_indptr = ctypes.POINTER(ctypes.c_int32)()
else: else:
out_ptr_indptr = ctypes.POINTER(ctypes.c_int64)() out_ptr_indptr = ctypes.POINTER(ctypes.c_int64)()
out_ptr_indices = ctypes.POINTER(ctypes.c_int32)() out_ptr_indices = ctypes.POINTER(ctypes.c_int32)()
out_ptr_data: _ctypes_float_ptr
if type_ptr_data == _C_API_DTYPE_FLOAT32: if type_ptr_data == _C_API_DTYPE_FLOAT32:
out_ptr_data = ctypes.POINTER(ctypes.c_float)() out_ptr_data = ctypes.POINTER(ctypes.c_float)()
else: else:
...@@ -1973,6 +1989,7 @@ class Dataset: ...@@ -1973,6 +1989,7 @@ class Dataset:
"""Initialize data from a list of 2-D numpy matrices.""" """Initialize data from a list of 2-D numpy matrices."""
ncol = mats[0].shape[1] ncol = mats[0].shape[1]
nrow = np.empty((len(mats),), np.int32) nrow = np.empty((len(mats),), np.int32)
ptr_data: _ctypes_float_array
if mats[0].dtype == np.float64: if mats[0].dtype == np.float64:
ptr_data = (ctypes.POINTER(ctypes.c_double) * len(mats))() ptr_data = (ctypes.POINTER(ctypes.c_double) * len(mats))()
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment