Commit 41152eab authored by Nikita Titov, committed by Guolin Ke

[python][docs] reworked predict method in sklearn wrapper and docs improvements (#1351)

* fixed docs

* reworked predict method of sklearn wrapper

* fixed encapsulation

* added test

* fixed consistency between docstring and params docs

* fixed verbose

* replaced predict_proba with predict in test

* fixed verbose again

* fixed fraction params descriptions

* added description of skip_drop and drop_rate constraints

* fixed subsample_freq consistency with C++ default value

* fixed nice look of params list

* made force splits json file example clickable

* fixed nice look of metrics list and added comma

* reduced warning in test about same param specified twice

* replaced pred_parameter with **kwargs in predict method

* added test for **kwargs in predict method

* fixed warnings

* fixed pylint
parent 64066805
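In short, this PR deprecates the `pred_parameter` dict (and the sklearn `apply` method) in favor of plain keyword arguments and the new `pred_leaf`/`pred_contrib` flags on the sklearn `predict` methods. A minimal usage sketch of the reworked API on synthetic data (data and parameter values are illustrative only):

```python
import numpy as np
import lightgbm as lgb

X = np.random.rand(100, 4)
y = np.random.randint(0, 2, 100)
clf = lgb.LGBMClassifier(n_estimators=10).fit(X, y)

# Old style, now deprecated with a warning:
#   clf.booster_.predict(X, pred_parameter={"pred_early_stop": True})
# New style: prediction parameters go straight into **kwargs.
probs = clf.predict_proba(X, pred_early_stop=True, pred_early_stop_margin=1.5)

# apply() is deprecated too; leaf indices now come from pred_leaf=True.
leaves = clf.predict(X, pred_leaf=True)  # shape [n_samples, n_trees]
```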
@@ -200,7 +200,7 @@ Learning Control Parameters

   - minimal sum hessian in one leaf. Like ``min_data_in_leaf``, it can be used to deal with over-fitting

-- ``feature_fraction``, default=\ ``1.0``, type=double, ``0.0 < feature_fraction < 1.0``, alias=\ ``sub_feature``, ``colsample_bytree``
+- ``feature_fraction``, default=\ ``1.0``, type=double, ``0.0 < feature_fraction <= 1.0``, alias=\ ``sub_feature``, ``colsample_bytree``

   - LightGBM will randomly select part of features on each iteration if ``feature_fraction`` smaller than ``1.0``.
     For example, if set to ``0.8``, will select 80% features before training each tree

@@ -213,7 +213,7 @@ Learning Control Parameters

   - random seed for ``feature_fraction``

-- ``bagging_fraction``, default=\ ``1.0``, type=double, ``0.0 < bagging_fraction < 1.0``, alias=\ ``sub_row``, ``subsample``
+- ``bagging_fraction``, default=\ ``1.0``, type=double, ``0.0 < bagging_fraction <= 1.0``, alias=\ ``sub_row``, ``subsample``

   - like ``feature_fraction``, but this will randomly select part of data without resampling

@@ -257,11 +257,11 @@ Learning Control Parameters

   - the minimal gain to perform split

-- ``drop_rate``, default=\ ``0.1``, type=double
+- ``drop_rate``, default=\ ``0.1``, type=double, ``0.0 <= drop_rate <= 1.0``

   - only used in ``dart``

-- ``skip_drop``, default=\ ``0.5``, type=double
+- ``skip_drop``, default=\ ``0.5``, type=double, ``0.0 <= skip_drop <= 1.0``

   - only used in ``dart``, probability of skipping drop

@@ -321,13 +321,13 @@ Learning Control Parameters

   - set this to larger value for more accurate result, but it will slow down the training speed

-- ``monotone_constraint``, default=``None``, type=multi-int, alias=\ ``mc``
+- ``monotone_constraint``, default=\ ``None``, type=multi-int, alias=\ ``mc``

   - used for constraints of monotonic features

   - ``1`` means increasing, ``-1`` means decreasing, ``0`` means non-constraint

-  - need to specific all features in order. For example, ``mc=-1,0,1`` means the decreasing for 1st feature, non-constraint for 2nd feature and increasing for the 3rd feature.
+  - you need to specify all features in order. For example, ``mc=-1,0,1`` means the decreasing for 1st feature, non-constraint for 2nd feature and increasing for the 3rd feature

 IO Parameters
 -------------

@@ -528,7 +528,7 @@ IO Parameters

     fields representing subsplits. Categorical splits are forced in a one-hot fashion, with ``left`` representing the split containing
     the feature value and ``right`` representing other values.

-  - see ``examples/binary_classification/forced_splits.json`` as an example.
+  - see `this file <https://github.com/Microsoft/LightGBM/tree/master/examples/binary_classification/forced_splits.json>`__ as an example.

 Objective Parameters
 --------------------

@@ -609,7 +609,7 @@ Metric Parameters

   - ``''`` (empty string or not specific), metric corresponding to specified objective will be used
     (this is possible only for pre-defined objective functions, otherwise no evaluation metric will be added)

-  - ``'None'`` (string **not** a ``None`` value), no metric registered, alias=\ ``na``
+  - ``'None'`` (string, **not** a ``None`` value), no metric registered, alias=\ ``na``

   - ``l1``, absolute loss, alias=\ ``mean_absolute_error``, ``mae``, ``regression_l1``
......
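The constraints documented above map directly onto training parameters. A hedged sketch of passing them through the Python package (values are illustrative; `monotone_constraint` assumes a build of this version where the feature is available, and multi-int parameters may be given as Python lists, which the wrapper joins with commas):

```python
import numpy as np
import lightgbm as lgb

X = np.random.rand(500, 3)
y = X[:, 0] - X[:, 1] + np.random.rand(500)  # roughly monotone in the first two features

params = {
    'objective': 'regression',
    'boosting_type': 'dart',
    'feature_fraction': 0.8,          # 0.0 < feature_fraction <= 1.0
    'bagging_fraction': 0.8,          # 0.0 < bagging_fraction <= 1.0
    'bagging_freq': 1,                # bagging takes effect only with a positive frequency
    'drop_rate': 0.1,                 # 0.0 <= drop_rate <= 1.0, dart only
    'skip_drop': 0.5,                 # 0.0 <= skip_drop <= 1.0, dart only
    'monotone_constraint': [1, -1, 0],  # increasing, decreasing, non-constraint
}
bst = lgb.train(params, lgb.Dataset(X, y), num_boost_round=20)
```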
@@ -13,8 +13,8 @@ from tempfile import NamedTemporaryFile
 import numpy as np
 import scipy.sparse

-from .compat import (DataFrame, Series, integer_types, json,
-                     json_default_with_numpy, numeric_types, range_,
+from .compat import (DataFrame, LGBMDeprecationWarning, Series, integer_types,
+                     json, json_default_with_numpy, numeric_types, range_,
                      string_type)
 from .libpath import find_lib_path

@@ -1754,7 +1754,7 @@ class Booster(object):
         return json.loads(string_buffer.value.decode())

     def predict(self, data, num_iteration=-1, raw_score=False, pred_leaf=False, pred_contrib=False,
-                data_has_header=False, is_reshape=True, pred_parameter=None):
+                data_has_header=False, is_reshape=True, pred_parameter=None, **kwargs):
         """Make a prediction.

         Parameters

@@ -1776,14 +1776,22 @@
             Used only if data is string.
         is_reshape : bool, optional (default=True)
             If True, result is reshaped to [nrow, ncol].
-        pred_parameter: dict or None, optional (default=None)
+        pred_parameter : dict or None, optional (default=None)
+            Deprecated.
             Other parameters for the prediction.
+        **kwargs : other parameters for the prediction

         Returns
         -------
         result : numpy array
             Prediction result.
         """
+        if pred_parameter:
+            warnings.warn("pred_parameter is deprecated and will be removed in 2.2 version.\n"
+                          "Please use kwargs instead.", LGBMDeprecationWarning)
+            pred_parameter.update(kwargs)
+        else:
+            pred_parameter = kwargs
         predictor = self._to_predictor(pred_parameter)
         if num_iteration <= 0:
             num_iteration = self.best_iteration
......
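The shim added to `Booster.predict` above merges a legacy `pred_parameter` dict with new-style keyword arguments; `kwargs` win on key conflicts because `dict.update` is applied last (note it also mutates the caller's dict). A standalone sketch of the same merge logic, where `_merge_pred_args` is a hypothetical helper, not part of the LightGBM API:

```python
import warnings

def _merge_pred_args(pred_parameter=None, **kwargs):
    # Mirrors the deprecation shim in Booster.predict.
    if pred_parameter:
        warnings.warn("pred_parameter is deprecated; pass keyword arguments instead.",
                      DeprecationWarning)
        pred_parameter.update(kwargs)  # new-style kwargs win on key conflicts
    else:
        pred_parameter = kwargs
    return pred_parameter

# Legacy call style and new call style produce the same parameter dict:
assert _merge_pred_args({"pred_early_stop": True}) == \
       _merge_pred_args(pred_early_stop=True)
```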
@@ -4,13 +4,14 @@
 from __future__ import absolute_import

 import numpy as np
+import warnings

 from .basic import Dataset, LightGBMError
 from .compat import (SKLEARN_INSTALLED, _LGBMClassifierBase,
                      LGBMNotFittedError, _LGBMLabelEncoder, _LGBMModelBase,
                      _LGBMRegressorBase, _LGBMCheckXY, _LGBMCheckArray, _LGBMCheckConsistentLength,
                      _LGBMCheckClassificationTargets, _LGBMComputeSampleWeight,
-                     argc_, range_, DataFrame)
+                     argc_, range_, DataFrame, LGBMDeprecationWarning)
 from .engine import train

@@ -131,7 +132,7 @@ class LGBMModel(_LGBMModelBase):
                  learning_rate=0.1, n_estimators=100,
                  subsample_for_bin=200000, objective=None, class_weight=None,
                  min_split_gain=0., min_child_weight=1e-3, min_child_samples=20,
-                 subsample=1., subsample_freq=1, colsample_bytree=1.,
+                 subsample=1., subsample_freq=0, colsample_bytree=1.,
                  reg_alpha=0., reg_lambda=0., random_state=None,
                  n_jobs=-1, silent=True, **kwargs):
         """Construct a gradient boosting model.

@@ -177,7 +178,7 @@ class LGBMModel(_LGBMModelBase):
             Minimum number of data need in a child(leaf).
         subsample : float, optional (default=1.)
             Subsample ratio of the training instance.
-        subsample_freq : int, optional (default=1)
+        subsample_freq : int, optional (default=0)
             Frequence of subsample, <=0 means no enable.
         colsample_bytree : float, optional (default=1.)
             Subsample ratio of columns when constructing each tree.

@@ -484,7 +485,8 @@ class LGBMModel(_LGBMModelBase):
         del train_set, valid_sets
         return self

-    def predict(self, X, raw_score=False, num_iteration=0):
+    def predict(self, X, raw_score=False, num_iteration=-1,
+                pred_leaf=False, pred_contrib=False, **kwargs):
         """Return the predicted value for each sample.

         Parameters

@@ -493,13 +495,23 @@ class LGBMModel(_LGBMModelBase):
             Input features matrix.
         raw_score : bool, optional (default=False)
             Whether to predict raw scores.
-        num_iteration : int, optional (default=0)
-            Limit number of iterations in the prediction; defaults to 0 (use all trees).
+        num_iteration : int, optional (default=-1)
+            Limit number of iterations in the prediction.
+            If <= 0, uses all trees (no limits).
+        pred_leaf : bool, optional (default=False)
+            Whether to predict leaf index.
+        pred_contrib : bool, optional (default=False)
+            Whether to predict feature contributions.
+        **kwargs : other parameters for the prediction

         Returns
         -------
         predicted_result : array-like of shape = [n_samples] or shape = [n_samples, n_classes]
             The predicted values.
+        X_leaves : array-like of shape = [n_samples, n_trees] or shape [n_samples, n_trees * n_classes]
+            If ``pred_leaf=True``, the predicted leaf every tree for each sample.
+        X_SHAP_values : array-like of shape = [n_samples, n_features + 1] or shape [n_samples, (n_features + 1) * n_classes]
+            If ``pred_contrib=True``, the each feature contributions for each sample.
         """
         if self._n_features is None:
             raise LGBMNotFittedError("Estimator not fitted, call `fit` before exploiting the model.")

@@ -511,7 +523,8 @@ class LGBMModel(_LGBMModelBase):
                              "match the input. Model n_features_ is %s and "
                              "input n_features is %s "
                              % (self._n_features, n_features))
-        return self.booster_.predict(X, raw_score=raw_score, num_iteration=num_iteration)
+        return self.booster_.predict(X, raw_score=raw_score, num_iteration=num_iteration,
+                                     pred_leaf=pred_leaf, pred_contrib=pred_contrib, **kwargs)

     def apply(self, X, num_iteration=0):
         """Return the predicted leaf every tree for each sample.

@@ -528,6 +541,9 @@ class LGBMModel(_LGBMModelBase):
         X_leaves : array-like of shape = [n_samples, n_trees]
             The predicted leaf every tree for each sample.
         """
+        warnings.warn('apply method is deprecated and will be removed in 2.2 version.\n'
+                      'Please use pred_leaf parameter of predict method instead.',
+                      LGBMDeprecationWarning)
         if self._n_features is None:
             raise LGBMNotFittedError("Estimator not fitted, call `fit` before exploiting the model.")
         if not isinstance(X, DataFrame):

@@ -617,13 +633,13 @@ class LGBMRegressor(LGBMModel, _LGBMRegressorBase):
                         callbacks=callbacks)
         return self

-    base_doc = LGBMModel.fit.__doc__
-    fit.__doc__ = (base_doc[:base_doc.find('eval_class_weight :')]
-                   + base_doc[base_doc.find('eval_init_score :'):])
+    _base_doc = LGBMModel.fit.__doc__
+    fit.__doc__ = (_base_doc[:_base_doc.find('eval_class_weight :')]
+                   + _base_doc[_base_doc.find('eval_init_score :'):])

-    base_doc = fit.__doc__
-    fit.__doc__ = (base_doc[:base_doc.find('eval_metric :')]
+    _base_doc = fit.__doc__
+    fit.__doc__ = (_base_doc[:_base_doc.find('eval_metric :')]
                    + 'eval_metric : string, list of strings, callable or None, optional (default="l2")\n'
-                   + base_doc[base_doc.find(' If string, it should be a built-in evaluation metric to use.'):])
+                   + _base_doc[_base_doc.find(' If string, it should be a built-in evaluation metric to use.'):])


 class LGBMClassifier(LGBMModel, _LGBMClassifierBase):

@@ -678,17 +694,23 @@ class LGBMClassifier(LGBMModel, _LGBMClassifierBase):
                         callbacks=callbacks)
         return self

-    base_doc = LGBMModel.fit.__doc__
-    fit.__doc__ = (base_doc[:base_doc.find('eval_metric :')]
+    _base_doc = LGBMModel.fit.__doc__
+    fit.__doc__ = (_base_doc[:_base_doc.find('eval_metric :')]
                    + 'eval_metric : string, list of strings, callable or None, optional (default="logloss")\n'
-                   + base_doc[base_doc.find(' If string, it should be a built-in evaluation metric to use.'):])
+                   + _base_doc[_base_doc.find(' If string, it should be a built-in evaluation metric to use.'):])

-    def predict(self, X, raw_score=False, num_iteration=0):
-        class_probs = self.predict_proba(X, raw_score, num_iteration)
-        class_index = np.argmax(class_probs, axis=1)
-        return self._le.inverse_transform(class_index)
+    def predict(self, X, raw_score=False, num_iteration=-1,
+                pred_leaf=False, pred_contrib=False, **kwargs):
+        result = self.predict_proba(X, raw_score, num_iteration,
+                                    pred_leaf, pred_contrib, **kwargs)
+        if raw_score or pred_leaf or pred_contrib:
+            return result
+        else:
+            class_index = np.argmax(result, axis=1)
+            return self._le.inverse_transform(class_index)

-    def predict_proba(self, X, raw_score=False, num_iteration=0):
+    def predict_proba(self, X, raw_score=False, num_iteration=-1,
+                      pred_leaf=False, pred_contrib=False, **kwargs):
         """Return the predicted probability for each class for each sample.

         Parameters

@@ -697,29 +719,30 @@ class LGBMClassifier(LGBMModel, _LGBMClassifierBase):
             Input features matrix.
         raw_score : bool, optional (default=False)
             Whether to predict raw scores.
-        num_iteration : int, optional (default=0)
-            Limit number of iterations in the prediction; defaults to 0 (use all trees).
+        num_iteration : int, optional (default=-1)
+            Limit number of iterations in the prediction.
+            If <= 0, uses all trees (no limits).
+        pred_leaf : bool, optional (default=False)
+            Whether to predict leaf index.
+        pred_contrib : bool, optional (default=False)
+            Whether to predict feature contributions.
+        **kwargs : other parameters for the prediction

         Returns
         -------
         predicted_probability : array-like of shape = [n_samples, n_classes]
             The predicted probability for each class for each sample.
+        X_leaves : array-like of shape = [n_samples, n_trees * n_classes]
+            If ``pred_leaf=True``, the predicted leaf every tree for each sample.
+        X_SHAP_values : array-like of shape = [n_samples, (n_features + 1) * n_classes]
+            If ``pred_contrib=True``, the each feature contributions for each sample.
         """
-        if self._n_features is None:
-            raise LGBMNotFittedError("Estimator not fitted, call `fit` before exploiting the model.")
-        if not isinstance(X, DataFrame):
-            X = _LGBMCheckArray(X, accept_sparse=True, force_all_finite=False)
-        n_features = X.shape[1]
-        if self._n_features != n_features:
-            raise ValueError("Number of features of the model must "
-                             "match the input. Model n_features_ is %s and "
-                             "input n_features is %s "
-                             % (self._n_features, n_features))
-        class_probs = self.booster_.predict(X, raw_score=raw_score, num_iteration=num_iteration)
-        if self._n_classes > 2:
-            return class_probs
+        result = super(LGBMClassifier, self).predict(X, raw_score, num_iteration,
+                                                     pred_leaf, pred_contrib, **kwargs)
+        if self._n_classes > 2 or pred_leaf or pred_contrib:
+            return result
         else:
-            return np.vstack((1. - class_probs, class_probs)).transpose()
+            return np.vstack((1. - result, result)).transpose()

     @property
     def classes_(self):

@@ -772,13 +795,13 @@ class LGBMRanker(LGBMModel):
                         callbacks=callbacks)
         return self

-    base_doc = LGBMModel.fit.__doc__
-    fit.__doc__ = (base_doc[:base_doc.find('eval_class_weight :')]
-                   + base_doc[base_doc.find('eval_init_score :'):])
+    _base_doc = LGBMModel.fit.__doc__
+    fit.__doc__ = (_base_doc[:_base_doc.find('eval_class_weight :')]
+                   + _base_doc[_base_doc.find('eval_init_score :'):])

-    base_doc = fit.__doc__
-    fit.__doc__ = (base_doc[:base_doc.find('eval_metric :')]
+    _base_doc = fit.__doc__
+    fit.__doc__ = (_base_doc[:_base_doc.find('eval_metric :')]
                    + 'eval_metric : string, list of strings, callable or None, optional (default="ndcg")\n'
-                   + base_doc[base_doc.find(' If string, it should be a built-in evaluation metric to use.'):base_doc.find('early_stopping_rounds :')]
+                   + _base_doc[_base_doc.find(' If string, it should be a built-in evaluation metric to use.'):_base_doc.find('early_stopping_rounds :')]
                    + 'eval_at : list of int, optional (default=[1])\n'
                      ' The evaluation positions of NDCG.\n'
-                   + base_doc[base_doc.find(' early_stopping_rounds :'):])
+                   + _base_doc[_base_doc.find(' early_stopping_rounds :'):])
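With this rework, `LGBMClassifier.predict` delegates to `predict_proba` and only applies `argmax` plus label decoding when an ordinary class prediction is requested; `raw_score`, `pred_leaf`, and `pred_contrib` results pass through unchanged. A short sketch against the new sklearn API (iris data for illustration):

```python
import numpy as np
import lightgbm as lgb
from sklearn.datasets import load_iris

X, y = load_iris(return_X_y=True)
clf = lgb.LGBMClassifier(n_estimators=10).fit(X, y)

proba = clf.predict_proba(X)              # shape [n_samples, n_classes]
labels = clf.predict(X)                   # argmax of proba, decoded to original labels
raw = clf.predict(X, raw_score=True)      # returned as-is, no argmax
leaves = clf.predict(X, pred_leaf=True)   # shape [n_samples, n_trees * n_classes]
shap = clf.predict(X, pred_contrib=True)  # shape [n_samples, (n_features + 1) * n_classes]

assert np.array_equal(labels, np.argmax(proba, axis=1))
```

For binary tasks, `predict_proba` builds the two-column output as `np.vstack((1. - result, result)).transpose()`, since the booster returns only the positive-class probability.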
@@ -56,7 +56,7 @@ class TestBasic(unittest.TestCase):
         # check early stopping is working. Make it stop very early, so the scores should be very close to zero
         pred_parameter = {"pred_early_stop": True, "pred_early_stop_freq": 5, "pred_early_stop_margin": 1.5}
-        pred_early_stopping = bst.predict(X_test, pred_parameter=pred_parameter)
+        pred_early_stopping = bst.predict(X_test, **pred_parameter)
         self.assertEqual(len(pred_from_matr), len(pred_early_stopping))
         for preds in zip(pred_early_stopping, pred_from_matr):
             # scores likely to be different, but prediction should still be the same
......
@@ -319,12 +319,12 @@ class TestEngine(unittest.TestCase):
                         evals_result=evals_result)
         pred_parameter = {"pred_early_stop": True, "pred_early_stop_freq": 5, "pred_early_stop_margin": 1.5}
-        ret = multi_logloss(y_test, gbm.predict(X_test, pred_parameter=pred_parameter))
+        ret = multi_logloss(y_test, gbm.predict(X_test, **pred_parameter))
         self.assertLess(ret, 0.8)
         self.assertGreater(ret, 0.5)  # loss will be higher than when evaluating the full model

         pred_parameter = {"pred_early_stop": True, "pred_early_stop_freq": 5, "pred_early_stop_margin": 5.5}
-        ret = multi_logloss(y_test, gbm.predict(X_test, pred_parameter=pred_parameter))
+        ret = multi_logloss(y_test, gbm.predict(X_test, **pred_parameter))
         self.assertLess(ret, 0.2)

     def test_early_stopping(self):

@@ -537,7 +537,6 @@ class TestEngine(unittest.TestCase):
             'objective': 'binary',
             'metric': 'binary_logloss',
             'verbose': -1,
-            'num_iteration': 50  # test num_iteration in dict here
         }
         lgb_train = lgb.Dataset(X_train, y_train)
         lgb_eval = lgb.Dataset(X_test, y_test, reference=lgb_train)
......
@@ -230,3 +230,48 @@ class TestSklearn(unittest.TestCase):
         np.testing.assert_almost_equal(pred0, pred2)
         np.testing.assert_almost_equal(pred0, pred3)
         np.testing.assert_almost_equal(pred_prob, pred4)
+
+    def test_predict(self):
+        iris = load_iris()
+        X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target,
+                                                            test_size=0.2, random_state=42)
+
+        gbm = lgb.train({'objective': 'multiclass',
+                         'num_class': 3,
+                         'verbose': -1},
+                        lgb.Dataset(X_train, y_train))
+        clf = lgb.LGBMClassifier(verbose=-1).fit(X_train, y_train)
+
+        # Tests same probabilities
+        res_engine = gbm.predict(X_test)
+        res_sklearn = clf.predict_proba(X_test)
+        np.testing.assert_allclose(res_engine, res_sklearn)
+
+        # Tests same predictions
+        res_engine = np.argmax(gbm.predict(X_test), axis=1)
+        res_sklearn = clf.predict(X_test)
+        np.testing.assert_equal(res_engine, res_sklearn)
+
+        # Tests same raw scores
+        res_engine = gbm.predict(X_test, raw_score=True)
+        res_sklearn = clf.predict(X_test, raw_score=True)
+        np.testing.assert_allclose(res_engine, res_sklearn)
+
+        # Tests same leaf indices
+        res_engine = gbm.predict(X_test, pred_leaf=True)
+        res_sklearn = clf.predict(X_test, pred_leaf=True)
+        np.testing.assert_equal(res_engine, res_sklearn)
+
+        # Tests same feature contributions
+        res_engine = gbm.predict(X_test, pred_contrib=True)
+        res_sklearn = clf.predict(X_test, pred_contrib=True)
+        np.testing.assert_allclose(res_engine, res_sklearn)
+
+        # Tests other parameters for the prediction works
+        res_engine = gbm.predict(X_test)
+        res_sklearn_params = clf.predict_proba(X_test,
+                                               pred_early_stop=True,
+                                               pred_early_stop_margin=1.0)
+        self.assertRaises(AssertionError,
+                          np.testing.assert_allclose,
+                          res_engine, res_sklearn_params)