Commit 8f17b30c authored by Nikita Titov, committed by Tsukasa OMOTO
Browse files

[python] use None for best_iter and -1 for all iters (#1575)

* refined num_iteration argument in python

* hotfix
parent c78f26ed
...@@ -1421,7 +1421,7 @@ class Booster(object): ...@@ -1421,7 +1421,7 @@ class Booster(object):
return self.__deepcopy__(None) return self.__deepcopy__(None)
def __deepcopy__(self, _): def __deepcopy__(self, _):
model_str = self._save_model_to_string() model_str = self._save_model_to_string(num_iteration=-1)
booster = Booster({'model_str': model_str}) booster = Booster({'model_str': model_str})
booster.pandas_categorical = self.pandas_categorical booster.pandas_categorical = self.pandas_categorical
return booster return booster
...@@ -1432,7 +1432,7 @@ class Booster(object): ...@@ -1432,7 +1432,7 @@ class Booster(object):
this.pop('train_set', None) this.pop('train_set', None)
this.pop('valid_sets', None) this.pop('valid_sets', None)
if handle is not None: if handle is not None:
this["handle"] = self._save_model_to_string() this["handle"] = self._save_model_to_string(num_iteration=-1)
return this return this
def __setstate__(self, state): def __setstate__(self, state):
...@@ -1710,18 +1710,19 @@ class Booster(object): ...@@ -1710,18 +1710,19 @@ class Booster(object):
return [item for i in range_(1, self.__num_dataset) return [item for i in range_(1, self.__num_dataset)
for item in self.__inner_eval(self.name_valid_sets[i - 1], i, feval)] for item in self.__inner_eval(self.name_valid_sets[i - 1], i, feval)]
def save_model(self, filename, num_iteration=-1): def save_model(self, filename, num_iteration=None):
"""Save Booster to file. """Save Booster to file.
Parameters Parameters
---------- ----------
filename : string filename : string
Filename to save Booster. Filename to save Booster.
num_iteration: int, optional (default=-1) num_iteration : int or None, optional (default=None)
Index of the iteration that should to saved. Index of the iteration that should be saved.
If <0, the best iteration (if exists) is saved. If None, if the best iteration exists, it is saved; otherwise, all iterations are saved.
If <= 0, all iterations are saved.
""" """
if num_iteration <= 0: if num_iteration is None:
num_iteration = self.best_iteration num_iteration = self.best_iteration
_safe_call(_LIB.LGBM_BoosterSaveModel( _safe_call(_LIB.LGBM_BoosterSaveModel(
self.handle, self.handle,
...@@ -1748,9 +1749,9 @@ class Booster(object): ...@@ -1748,9 +1749,9 @@ class Booster(object):
print('Finished loading model, total used %d iterations' % (int(out_num_iterations.value))) print('Finished loading model, total used %d iterations' % (int(out_num_iterations.value)))
self.__num_class = out_num_class.value self.__num_class = out_num_class.value
def _save_model_to_string(self, num_iteration=-1): def _save_model_to_string(self, num_iteration=None):
"""[Private] Save model to string""" """[Private] Save model to string"""
if num_iteration <= 0: if num_iteration is None:
num_iteration = self.best_iteration num_iteration = self.best_iteration
buffer_len = 1 << 20 buffer_len = 1 << 20
tmp_out_len = ctypes.c_int64(0) tmp_out_len = ctypes.c_int64(0)
...@@ -1775,21 +1776,22 @@ class Booster(object): ...@@ -1775,21 +1776,22 @@ class Booster(object):
ptr_string_buffer)) ptr_string_buffer))
return string_buffer.value.decode() return string_buffer.value.decode()
def dump_model(self, num_iteration=-1): def dump_model(self, num_iteration=None):
"""Dump Booster to json format. """Dump Booster to json format.
Parameters Parameters
---------- ----------
num_iteration: int, optional (default=-1) num_iteration : int or None, optional (default=None)
Index of the iteration that should to dumped. Index of the iteration that should be dumped.
If <0, the best iteration (if exists) is dumped. If None, if the best iteration exists, it is dumped; otherwise, all iterations are dumped.
If <= 0, all iterations are dumped.
Returns Returns
------- -------
json_repr : dict json_repr : dict
Json format of Booster. Json format of Booster.
""" """
if num_iteration <= 0: if num_iteration is None:
num_iteration = self.best_iteration num_iteration = self.best_iteration
buffer_len = 1 << 20 buffer_len = 1 << 20
tmp_out_len = ctypes.c_int64(0) tmp_out_len = ctypes.c_int64(0)
...@@ -1814,7 +1816,7 @@ class Booster(object): ...@@ -1814,7 +1816,7 @@ class Booster(object):
ptr_string_buffer)) ptr_string_buffer))
return json.loads(string_buffer.value.decode()) return json.loads(string_buffer.value.decode())
def predict(self, data, num_iteration=-1, raw_score=False, pred_leaf=False, pred_contrib=False, def predict(self, data, num_iteration=None, raw_score=False, pred_leaf=False, pred_contrib=False,
data_has_header=False, is_reshape=True, pred_parameter=None, **kwargs): data_has_header=False, is_reshape=True, pred_parameter=None, **kwargs):
"""Make a prediction. """Make a prediction.
...@@ -1823,9 +1825,11 @@ class Booster(object): ...@@ -1823,9 +1825,11 @@ class Booster(object):
data : string, numpy array or scipy.sparse data : string, numpy array or scipy.sparse
Data source for prediction. Data source for prediction.
If string, it represents the path to txt file. If string, it represents the path to txt file.
num_iteration : int, optional (default=-1) num_iteration : int or None, optional (default=None)
Iteration used for prediction. Limit number of iterations in the prediction.
If <0, the best iteration (if exists) is used for prediction. If None, if the best iteration exists, it is used; otherwise, all iterations are used.
If <= 0, all iterations are used (no limits).
raw_score : bool, optional (default=False) raw_score : bool, optional (default=False)
Whether to predict raw scores. Whether to predict raw scores.
pred_leaf : bool, optional (default=False) pred_leaf : bool, optional (default=False)
...@@ -1854,7 +1858,7 @@ class Booster(object): ...@@ -1854,7 +1858,7 @@ class Booster(object):
else: else:
pred_parameter = kwargs pred_parameter = kwargs
predictor = self._to_predictor(pred_parameter) predictor = self._to_predictor(pred_parameter)
if num_iteration <= 0: if num_iteration is None:
num_iteration = self.best_iteration num_iteration = self.best_iteration
return predictor.predict(data, num_iteration, raw_score, pred_leaf, pred_contrib, data_has_header, is_reshape) return predictor.predict(data, num_iteration, raw_score, pred_leaf, pred_contrib, data_has_header, is_reshape)
......
...@@ -492,7 +492,7 @@ class LGBMModel(_LGBMModelBase): ...@@ -492,7 +492,7 @@ class LGBMModel(_LGBMModelBase):
del train_set, valid_sets del train_set, valid_sets
return self return self
def predict(self, X, raw_score=False, num_iteration=-1, def predict(self, X, raw_score=False, num_iteration=None,
pred_leaf=False, pred_contrib=False, **kwargs): pred_leaf=False, pred_contrib=False, **kwargs):
"""Return the predicted value for each sample. """Return the predicted value for each sample.
...@@ -502,9 +502,10 @@ class LGBMModel(_LGBMModelBase): ...@@ -502,9 +502,10 @@ class LGBMModel(_LGBMModelBase):
Input features matrix. Input features matrix.
raw_score : bool, optional (default=False) raw_score : bool, optional (default=False)
Whether to predict raw scores. Whether to predict raw scores.
num_iteration : int, optional (default=-1) num_iteration : int or None, optional (default=None)
Limit number of iterations in the prediction. Limit number of iterations in the prediction.
If <= 0, uses all trees (no limits). If None, if the best iteration exists, it is used; otherwise, all trees are used.
If <= 0, all trees are used (no limits).
pred_leaf : bool, optional (default=False) pred_leaf : bool, optional (default=False)
Whether to predict leaf index. Whether to predict leaf index.
pred_contrib : bool, optional (default=False) pred_contrib : bool, optional (default=False)
...@@ -708,7 +709,7 @@ class LGBMClassifier(LGBMModel, _LGBMClassifierBase): ...@@ -708,7 +709,7 @@ class LGBMClassifier(LGBMModel, _LGBMClassifierBase):
+ 'eval_metric : string, list of strings, callable or None, optional (default="logloss")\n' + 'eval_metric : string, list of strings, callable or None, optional (default="logloss")\n'
+ _base_doc[_base_doc.find(' If string, it should be a built-in evaluation metric to use.'):]) + _base_doc[_base_doc.find(' If string, it should be a built-in evaluation metric to use.'):])
def predict(self, X, raw_score=False, num_iteration=-1, def predict(self, X, raw_score=False, num_iteration=None,
pred_leaf=False, pred_contrib=False, **kwargs): pred_leaf=False, pred_contrib=False, **kwargs):
result = self.predict_proba(X, raw_score, num_iteration, result = self.predict_proba(X, raw_score, num_iteration,
pred_leaf, pred_contrib, **kwargs) pred_leaf, pred_contrib, **kwargs)
...@@ -718,7 +719,7 @@ class LGBMClassifier(LGBMModel, _LGBMClassifierBase): ...@@ -718,7 +719,7 @@ class LGBMClassifier(LGBMModel, _LGBMClassifierBase):
class_index = np.argmax(result, axis=1) class_index = np.argmax(result, axis=1)
return self._le.inverse_transform(class_index) return self._le.inverse_transform(class_index)
def predict_proba(self, X, raw_score=False, num_iteration=-1, def predict_proba(self, X, raw_score=False, num_iteration=None,
pred_leaf=False, pred_contrib=False, **kwargs): pred_leaf=False, pred_contrib=False, **kwargs):
"""Return the predicted probability for each class for each sample. """Return the predicted probability for each class for each sample.
...@@ -728,9 +729,10 @@ class LGBMClassifier(LGBMModel, _LGBMClassifierBase): ...@@ -728,9 +729,10 @@ class LGBMClassifier(LGBMModel, _LGBMClassifierBase):
Input features matrix. Input features matrix.
raw_score : bool, optional (default=False) raw_score : bool, optional (default=False)
Whether to predict raw scores. Whether to predict raw scores.
num_iteration : int, optional (default=-1) num_iteration : int or None, optional (default=None)
Limit number of iterations in the prediction. Limit number of iterations in the prediction.
If <= 0, uses all trees (no limits). If None, if the best iteration exists, it is used; otherwise, all trees are used.
If <= 0, all trees are used (no limits).
pred_leaf : bool, optional (default=False) pred_leaf : bool, optional (default=False)
Whether to predict leaf index. Whether to predict leaf index.
pred_contrib : bool, optional (default=False) pred_contrib : bool, optional (default=False)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment