Unverified Commit 7d4d8975 authored by James Lamb's avatar James Lamb Committed by GitHub
Browse files

[python-package] make _InnerPredictor construction stricter (#5961)

parent ed28c84a
...@@ -33,6 +33,7 @@ __all__ = [ ...@@ -33,6 +33,7 @@ __all__ = [
'Sequence', 'Sequence',
] ]
_BoosterHandle = ctypes.c_void_p
_DatasetHandle = ctypes.c_void_p _DatasetHandle = ctypes.c_void_p
_ctypes_int_ptr = Union[ _ctypes_int_ptr = Union[
"ctypes._Pointer[ctypes.c_int32]", "ctypes._Pointer[ctypes.c_int32]",
...@@ -837,52 +838,98 @@ class _InnerPredictor: ...@@ -837,52 +838,98 @@ class _InnerPredictor:
def __init__( def __init__(
self, self,
model_file: Optional[Union[str, Path]] = None, booster_handle: _BoosterHandle,
booster_handle: Optional[ctypes.c_void_p] = None, pandas_categorical: Optional[List[List]],
pred_parameter: Optional[Dict[str, Any]] = None pred_parameter: Dict[str, Any],
manage_handle: bool
): ):
"""Initialize the _InnerPredictor. """Initialize the _InnerPredictor.
Parameters Parameters
---------- ----------
model_file : str, pathlib.Path or None, optional (default=None) booster_handle : object
Path to the model file.
booster_handle : object or None, optional (default=None)
Handle of Booster. Handle of Booster.
pred_parameter: dict or None, optional (default=None) pandas_categorical : list of list, or None
If provided, list of categories for ``pandas`` categorical columns.
Where the ``i``th element of the list contains the categories for the ``i``th categorical feature.
pred_parameter : dict
Other parameters for the prediction. Other parameters for the prediction.
manage_handle : bool
If ``True``, free the corresponding Booster on the C++ side when this Python object is deleted.
""" """
self._handle = ctypes.c_void_p()
self.__is_manage_handle = True
if model_file is not None:
"""Prediction task"""
out_num_iterations = ctypes.c_int(0)
_safe_call(_LIB.LGBM_BoosterCreateFromModelfile(
_c_str(str(model_file)),
ctypes.byref(out_num_iterations),
ctypes.byref(self._handle)))
out_num_class = ctypes.c_int(0)
_safe_call(_LIB.LGBM_BoosterGetNumClasses(
self._handle,
ctypes.byref(out_num_class)))
self.num_class = out_num_class.value
self.num_total_iteration = out_num_iterations.value
self.pandas_categorical = _load_pandas_categorical(file_name=model_file)
elif booster_handle is not None:
self.__is_manage_handle = False
self._handle = booster_handle self._handle = booster_handle
self.__is_manage_handle = manage_handle
self.pandas_categorical = pandas_categorical
self.pred_parameter = _param_dict_to_str(pred_parameter)
out_num_class = ctypes.c_int(0) out_num_class = ctypes.c_int(0)
_safe_call(_LIB.LGBM_BoosterGetNumClasses( _safe_call(
_LIB.LGBM_BoosterGetNumClasses(
self._handle, self._handle,
ctypes.byref(out_num_class))) ctypes.byref(out_num_class)
)
)
self.num_class = out_num_class.value self.num_class = out_num_class.value
self.num_total_iteration = self.current_iteration()
self.pandas_categorical = None
else:
raise TypeError('Need model_file or booster_handle to create a predictor')
pred_parameter = {} if pred_parameter is None else pred_parameter @classmethod
self.pred_parameter = _param_dict_to_str(pred_parameter) def from_booster(
cls,
booster: "Booster",
pred_parameter: Dict[str, Any]
) -> "_InnerPredictor":
"""Initialize an ``_InnerPredictor`` from a ``Booster``.
Parameters
----------
booster : Booster
Booster.
pred_parameter : dict
Other parameters for the prediction.
"""
out_cur_iter = ctypes.c_int(0)
_safe_call(
_LIB.LGBM_BoosterGetCurrentIteration(
booster._handle,
ctypes.byref(out_cur_iter)
)
)
return cls(
booster_handle=booster._handle,
pandas_categorical=booster.pandas_categorical,
pred_parameter=pred_parameter,
manage_handle=False
)
@classmethod
def from_model_file(
cls,
model_file: Union[str, Path],
pred_parameter: Dict[str, Any]
) -> "_InnerPredictor":
"""Initialize an ``_InnerPredictor`` from a text file containing a LightGBM model.
Parameters
----------
model_file : str or pathlib.Path
Path to the model file.
pred_parameter : dict
Other parameters for the prediction.
"""
booster_handle = ctypes.c_void_p()
out_num_iterations = ctypes.c_int(0)
_safe_call(
_LIB.LGBM_BoosterCreateFromModelfile(
_c_str(str(model_file)),
ctypes.byref(out_num_iterations),
ctypes.byref(booster_handle)
)
)
return cls(
booster_handle=booster_handle,
pandas_categorical=_load_pandas_categorical(file_name=model_file),
pred_parameter=pred_parameter,
manage_handle=True
)
def __del__(self) -> None: def __del__(self) -> None:
try: try:
...@@ -3046,7 +3093,7 @@ class Booster: ...@@ -3046,7 +3093,7 @@ class Booster:
model_str : str or None, optional (default=None) model_str : str or None, optional (default=None)
Model will be loaded from this string. Model will be loaded from this string.
""" """
self._handle = None self._handle = ctypes.c_void_p()
self._network = False self._network = False
self.__need_reload_eval_info = True self.__need_reload_eval_info = True
self._train_data_name = "training" self._train_data_name = "training"
...@@ -3097,7 +3144,6 @@ class Booster: ...@@ -3097,7 +3144,6 @@ class Booster:
# copy the parameters from train_set # copy the parameters from train_set
params.update(train_set.get_params()) params.update(train_set.get_params())
params_str = _param_dict_to_str(params) params_str = _param_dict_to_str(params)
self._handle = ctypes.c_void_p()
_safe_call(_LIB.LGBM_BoosterCreate( _safe_call(_LIB.LGBM_BoosterCreate(
train_set._handle, train_set._handle,
_c_str(params_str), _c_str(params_str),
...@@ -3126,7 +3172,6 @@ class Booster: ...@@ -3126,7 +3172,6 @@ class Booster:
elif model_file is not None: elif model_file is not None:
# Prediction task # Prediction task
out_num_iterations = ctypes.c_int(0) out_num_iterations = ctypes.c_int(0)
self._handle = ctypes.c_void_p()
_safe_call(_LIB.LGBM_BoosterCreateFromModelfile( _safe_call(_LIB.LGBM_BoosterCreateFromModelfile(
_c_str(str(model_file)), _c_str(str(model_file)),
ctypes.byref(out_num_iterations), ctypes.byref(out_num_iterations),
...@@ -3905,7 +3950,8 @@ class Booster: ...@@ -3905,7 +3950,8 @@ class Booster:
self : Booster self : Booster
Loaded Booster object. Loaded Booster object.
""" """
if self._handle is not None: # ensure that existing Booster is freed before replacing it
# with a new one createdfrom file
_safe_call(_LIB.LGBM_BoosterFree(self._handle)) _safe_call(_LIB.LGBM_BoosterFree(self._handle))
self._free_buffer() self._free_buffer()
self._handle = ctypes.c_void_p() self._handle = ctypes.c_void_p()
...@@ -4106,7 +4152,10 @@ class Booster: ...@@ -4106,7 +4152,10 @@ class Booster:
Prediction result. Prediction result.
Can be sparse or a list of sparse objects (each element represents predictions for one class) for feature contributions (when ``pred_contrib=True``). Can be sparse or a list of sparse objects (each element represents predictions for one class) for feature contributions (when ``pred_contrib=True``).
""" """
predictor = self._to_predictor(pred_parameter=deepcopy(kwargs)) predictor = _InnerPredictor.from_booster(
booster=self,
pred_parameter=deepcopy(kwargs),
)
if num_iteration is None: if num_iteration is None:
if start_iteration <= 0: if start_iteration <= 0:
num_iteration = self.best_iteration num_iteration = self.best_iteration
...@@ -4223,7 +4272,10 @@ class Booster: ...@@ -4223,7 +4272,10 @@ class Booster:
raise LightGBMError('Cannot refit due to null objective function.') raise LightGBMError('Cannot refit due to null objective function.')
if dataset_params is None: if dataset_params is None:
dataset_params = {} dataset_params = {}
predictor = self._to_predictor(pred_parameter=deepcopy(kwargs)) predictor = _InnerPredictor.from_booster(
booster=self,
pred_parameter=deepcopy(kwargs)
)
leaf_preds: np.ndarray = predictor.predict( # type: ignore[assignment] leaf_preds: np.ndarray = predictor.predict( # type: ignore[assignment]
data=data, data=data,
start_iteration=-1, start_iteration=-1,
...@@ -4327,15 +4379,6 @@ class Booster: ...@@ -4327,15 +4379,6 @@ class Booster:
) )
return self return self
def _to_predictor(
self,
pred_parameter: Dict[str, Any]
) -> _InnerPredictor:
"""Convert to predictor."""
predictor = _InnerPredictor(booster_handle=self._handle, pred_parameter=pred_parameter)
predictor.pandas_categorical = self.pandas_categorical
return predictor
def num_feature(self) -> int: def num_feature(self) -> int:
"""Get number of features. """Get number of features.
......
...@@ -183,10 +183,20 @@ def train( ...@@ -183,10 +183,20 @@ def train(
predictor: Optional[_InnerPredictor] = None predictor: Optional[_InnerPredictor] = None
if isinstance(init_model, (str, Path)): if isinstance(init_model, (str, Path)):
predictor = _InnerPredictor(model_file=init_model, pred_parameter=params) predictor = _InnerPredictor.from_model_file(
model_file=init_model,
pred_parameter=params
)
elif isinstance(init_model, Booster): elif isinstance(init_model, Booster):
predictor = init_model._to_predictor(pred_parameter=dict(init_model.params, **params)) predictor = _InnerPredictor.from_booster(
init_iteration = predictor.num_total_iteration if predictor is not None else 0 booster=init_model,
pred_parameter=dict(init_model.params, **params)
)
if predictor is not None:
init_iteration = predictor.current_iteration()
else:
init_iteration = 0
train_set._update_params(params) \ train_set._update_params(params) \
._set_predictor(predictor) \ ._set_predictor(predictor) \
...@@ -685,9 +695,15 @@ def cv( ...@@ -685,9 +695,15 @@ def cv(
first_metric_only = params.get('first_metric_only', False) first_metric_only = params.get('first_metric_only', False)
if isinstance(init_model, (str, Path)): if isinstance(init_model, (str, Path)):
predictor = _InnerPredictor(model_file=init_model, pred_parameter=params) predictor = _InnerPredictor.from_model_file(
model_file=init_model,
pred_parameter=params
)
elif isinstance(init_model, Booster): elif isinstance(init_model, Booster):
predictor = init_model._to_predictor(pred_parameter=dict(init_model.params, **params)) predictor = _InnerPredictor.from_booster(
booster=init_model,
pred_parameter=dict(init_model.params, **params)
)
else: else:
predictor = None predictor = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment