"git@developer.sourcefind.cn:tianlh/lightgbm-dcu.git" did not exist on "36732f23baa9e87617ea7700a9f59d7f53e313c6"
Unverified Commit 8670013d authored by James Lamb's avatar James Lamb Committed by GitHub
Browse files

[python-package] [ci] fix mypy errors in Booster.__inner_predict() (#5852)

parent ef5acfb4
...@@ -3110,7 +3110,7 @@ class Booster: ...@@ -3110,7 +3110,7 @@ class Booster:
ctypes.byref(out_num_class))) ctypes.byref(out_num_class)))
self.__num_class = out_num_class.value self.__num_class = out_num_class.value
# buffer for inner predict # buffer for inner predict
self.__inner_predict_buffer = [None] self.__inner_predict_buffer: List[Optional[np.ndarray]] = [None]
self.__is_predicted_cur_iter = [False] self.__is_predicted_cur_iter = [False]
self.__get_eval_info() self.__get_eval_info()
self.pandas_categorical = train_set.pandas_categorical self.pandas_categorical = train_set.pandas_categorical
...@@ -4518,16 +4518,16 @@ class Booster: ...@@ -4518,16 +4518,16 @@ class Booster:
# avoid to predict many time in one iteration # avoid to predict many time in one iteration
if not self.__is_predicted_cur_iter[data_idx]: if not self.__is_predicted_cur_iter[data_idx]:
tmp_out_len = ctypes.c_int64(0) tmp_out_len = ctypes.c_int64(0)
data_ptr = self.__inner_predict_buffer[data_idx].ctypes.data_as(ctypes.POINTER(ctypes.c_double)) data_ptr = self.__inner_predict_buffer[data_idx].ctypes.data_as(ctypes.POINTER(ctypes.c_double)) # type: ignore[union-attr]
_safe_call(_LIB.LGBM_BoosterGetPredict( _safe_call(_LIB.LGBM_BoosterGetPredict(
self.handle, self.handle,
ctypes.c_int(data_idx), ctypes.c_int(data_idx),
ctypes.byref(tmp_out_len), ctypes.byref(tmp_out_len),
data_ptr)) data_ptr))
if tmp_out_len.value != len(self.__inner_predict_buffer[data_idx]): if tmp_out_len.value != len(self.__inner_predict_buffer[data_idx]): # type: ignore[arg-type]
raise ValueError(f"Wrong length of predict results for data {data_idx}") raise ValueError(f"Wrong length of predict results for data {data_idx}")
self.__is_predicted_cur_iter[data_idx] = True self.__is_predicted_cur_iter[data_idx] = True
result = self.__inner_predict_buffer[data_idx] result: np.ndarray = self.__inner_predict_buffer[data_idx] # type: ignore[assignment]
if self.__num_class > 1: if self.__num_class > 1:
num_data = result.size // self.__num_class num_data = result.size // self.__num_class
result = result.reshape(num_data, self.__num_class, order='F') result = result.reshape(num_data, self.__num_class, order='F')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment