sklearn.py 48.8 KB
Newer Older
wxchan's avatar
wxchan committed
1
# coding: utf-8
2
"""Scikit-learn wrapper interface for LightGBM."""
3
import copy
4
5
import warnings

6
7
from inspect import signature

wxchan's avatar
wxchan committed
8
import numpy as np
9

10
from .basic import Dataset, LightGBMError, _ConfigAliases
11
from .compat import (SKLEARN_INSTALLED, _LGBMClassifierBase,
12
                     LGBMNotFittedError, _LGBMLabelEncoder, _LGBMModelBase,
13
                     _LGBMRegressorBase, _LGBMCheckXY, _LGBMCheckArray, _LGBMCheckSampleWeight,
14
                     _LGBMAssertAllFinite, _LGBMCheckClassificationTargets, _LGBMComputeSampleWeight,
15
                     DataFrame, DataTable)
wxchan's avatar
wxchan committed
16
from .engine import train
17

wxchan's avatar
wxchan committed
18

19
class _ObjectiveFunctionWrapper:
20
    """Proxy class for objective function."""
21

22
23
    def __init__(self, func):
        """Construct a proxy class.
24

25
26
        This class transforms objective function to match objective function with signature ``new_func(preds, dataset)``
        as expected by ``lightgbm.engine.train``.
27

28
29
30
31
32
33
34
35
36
37
38
        Parameters
        ----------
        func : callable
            Expects a callable with signature ``func(y_true, y_pred)`` or ``func(y_true, y_pred, group)
            and returns (grad, hess):

                y_true : array-like of shape = [n_samples]
                    The target values.
                y_pred : array-like of shape = [n_samples] or shape = [n_samples * n_classes] (for multi-class task)
                    The predicted values.
                group : array-like
39
40
41
42
                    Group/query data.
                    Only used in the learning-to-rank task.
                    sum(group) = n_samples.
                    For example, if you have a 100-document dataset with ``group = [10, 20, 40, 10, 10, 10]``, that means that you have 6 groups, where the first 10 records are in the first group, records 11-30 are in the second group, etc.
43
44
45
46
                grad : array-like of shape = [n_samples] or shape = [n_samples * n_classes] (for multi-class task)
                    The value of the first order derivative (gradient) for each sample point.
                hess : array-like of shape = [n_samples] or shape = [n_samples * n_classes] (for multi-class task)
                    The value of the second order derivative (Hessian) for each sample point.
wxchan's avatar
wxchan committed
47

Nikita Titov's avatar
Nikita Titov committed
48
49
        .. note::

50
            For binary task, the y_pred is margin.
Nikita Titov's avatar
Nikita Titov committed
51
52
53
            For multi-class task, the y_pred is group by class_id first, then group by row_id.
            If you want to get i-th row y_pred in j-th class, the access way is y_pred[j * num_data + i]
            and you should group grad and hess in this way as well.
54
55
        """
        self.func = func
wxchan's avatar
wxchan committed
56

57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
    def __call__(self, preds, dataset):
        """Call passed function with appropriate arguments.

        Parameters
        ----------
        preds : array-like of shape = [n_samples] or shape = [n_samples * n_classes] (for multi-class task)
            The predicted values.
        dataset : Dataset
            The training dataset.

        Returns
        -------
        grad : array-like of shape = [n_samples] or shape = [n_samples * n_classes] (for multi-class task)
            The value of the first order derivative (gradient) for each sample point.
        hess : array-like of shape = [n_samples] or shape = [n_samples * n_classes] (for multi-class task)
            The value of the second order derivative (Hessian) for each sample point.
        """
wxchan's avatar
wxchan committed
74
        labels = dataset.get_label()
75
        argc = len(signature(self.func).parameters)
76
        if argc == 2:
77
            grad, hess = self.func(labels, preds)
78
        elif argc == 3:
79
            grad, hess = self.func(labels, preds, dataset.get_group())
80
        else:
wxchan's avatar
wxchan committed
81
            raise TypeError("Self-defined objective function should have 2 or 3 arguments, got %d" % argc)
wxchan's avatar
wxchan committed
82
83
84
85
86
87
88
89
90
91
92
        """weighted for objective"""
        weight = dataset.get_weight()
        if weight is not None:
            """only one class"""
            if len(weight) == len(grad):
                grad = np.multiply(grad, weight)
                hess = np.multiply(hess, weight)
            else:
                num_data = len(weight)
                num_class = len(grad) // num_data
                if num_class * num_data != len(grad):
93
                    raise ValueError("Length of grad and hess should equal to num_class * num_data")
94
95
                for k in range(num_class):
                    for i in range(num_data):
wxchan's avatar
wxchan committed
96
97
98
99
100
                        idx = k * num_data + i
                        grad[idx] *= weight[i]
                        hess[idx] *= weight[i]
        return grad, hess

wxchan's avatar
wxchan committed
101

102
class _EvalFunctionWrapper:
103
    """Proxy class for evaluation function."""
104

105
106
    def __init__(self, func):
        """Construct a proxy class.
107

108
109
        This class transforms evaluation function to match evaluation function with signature ``new_func(preds, dataset)``
        as expected by ``lightgbm.engine.train``.
110

111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
        Parameters
        ----------
        func : callable
            Expects a callable with following signatures:
            ``func(y_true, y_pred)``,
            ``func(y_true, y_pred, weight)``
            or ``func(y_true, y_pred, weight, group)``
            and returns (eval_name, eval_result, is_higher_better) or
            list of (eval_name, eval_result, is_higher_better):

                y_true : array-like of shape = [n_samples]
                    The target values.
                y_pred : array-like of shape = [n_samples] or shape = [n_samples * n_classes] (for multi-class task)
                    The predicted values.
                weight : array-like of shape = [n_samples]
                    The weight of samples.
                group : array-like
128
129
130
131
                    Group/query data.
                    Only used in the learning-to-rank task.
                    sum(group) = n_samples.
                    For example, if you have a 100-document dataset with ``group = [10, 20, 40, 10, 10, 10]``, that means that you have 6 groups, where the first 10 records are in the first group, records 11-30 are in the second group, etc.
132
                eval_name : string
133
                    The name of evaluation function (without whitespaces).
134
135
136
137
138
                eval_result : float
                    The eval result.
                is_higher_better : bool
                    Is eval result higher better, e.g. AUC is ``is_higher_better``.

Nikita Titov's avatar
Nikita Titov committed
139
140
        .. note::

141
            For binary task, the y_pred is probability of positive class (or margin in case of custom ``objective``).
Nikita Titov's avatar
Nikita Titov committed
142
143
            For multi-class task, the y_pred is group by class_id first, then group by row_id.
            If you want to get i-th row y_pred in j-th class, the access way is y_pred[j * num_data + i].
144
145
        """
        self.func = func
146

147
148
    def __call__(self, preds, dataset):
        """Call passed function with appropriate arguments.
149

150
151
152
153
154
155
156
157
158
159
        Parameters
        ----------
        preds : array-like of shape = [n_samples] or shape = [n_samples * n_classes] (for multi-class task)
            The predicted values.
        dataset : Dataset
            The training dataset.

        Returns
        -------
        eval_name : string
160
            The name of evaluation function (without whitespaces).
161
162
163
164
165
        eval_result : float
            The eval result.
        is_higher_better : bool
            Is eval result higher better, e.g. AUC is ``is_higher_better``.
        """
166
        labels = dataset.get_label()
167
        argc = len(signature(self.func).parameters)
168
        if argc == 2:
169
            return self.func(labels, preds)
170
        elif argc == 3:
171
            return self.func(labels, preds, dataset.get_weight())
172
        elif argc == 4:
173
            return self.func(labels, preds, dataset.get_weight(), dataset.get_group())
174
        else:
wxchan's avatar
wxchan committed
175
            raise TypeError("Self-defined eval function should have 2, 3 or 4 arguments, got %d" % argc)
176

wxchan's avatar
wxchan committed
177

178
179
class LGBMModel(_LGBMModelBase):
    """Implementation of the scikit-learn API for LightGBM."""
wxchan's avatar
wxchan committed
180

181
    def __init__(self, boosting_type='gbdt', num_leaves=31, max_depth=-1,
                 learning_rate=0.1, n_estimators=100,
                 subsample_for_bin=200000, objective=None, class_weight=None,
                 min_split_gain=0., min_child_weight=1e-3, min_child_samples=20,
                 subsample=1., subsample_freq=0, colsample_bytree=1.,
                 reg_alpha=0., reg_lambda=0., random_state=None,
                 n_jobs=-1, silent=True, importance_type='split', **kwargs):
        r"""Construct a gradient boosting model.

        Parameters
        ----------
        boosting_type : string, optional (default='gbdt')
            'gbdt', traditional Gradient Boosting Decision Tree.
            'dart', Dropouts meet Multiple Additive Regression Trees.
            'goss', Gradient-based One-Side Sampling.
            'rf', Random Forest.
        num_leaves : int, optional (default=31)
            Maximum tree leaves for base learners.
        max_depth : int, optional (default=-1)
            Maximum tree depth for base learners, <=0 means no limit.
        learning_rate : float, optional (default=0.1)
            Boosting learning rate.
            A per-iteration rate can be set through the ``reset_parameter``
            callback passed to ``fit``; training then ignores this argument.
        n_estimators : int, optional (default=100)
            Number of boosted trees to fit.
        subsample_for_bin : int, optional (default=200000)
            Number of samples for constructing bins.
        objective : string, callable or None, optional (default=None)
            Specify the learning task and the corresponding learning objective or
            a custom objective function to be used (see note below).
            Default: 'regression' for LGBMRegressor, 'binary' or 'multiclass'
            for LGBMClassifier, 'lambdarank' for LGBMRanker.
        class_weight : dict, 'balanced' or None, optional (default=None)
            Weights associated with classes in the form ``{class_label: weight}``.
            Use this parameter only for multi-class classification task;
            for binary classification task you may use ``is_unbalance`` or
            ``scale_pos_weight`` parameters.
            Note, that the usage of all these parameters will result in poor
            estimates of the individual class probabilities; consider
            probability calibration
            (https://scikit-learn.org/stable/modules/calibration.html).
            The 'balanced' mode uses the values of y to automatically adjust
            weights inversely proportional to class frequencies as
            ``n_samples / (n_classes * np.bincount(y))``.
            If None, all classes are supposed to have weight one.
            These weights will be multiplied with ``sample_weight`` if
            ``sample_weight`` is passed through the ``fit`` method.
        min_split_gain : float, optional (default=0.)
            Minimum loss reduction required to make a further partition on a leaf node of the tree.
        min_child_weight : float, optional (default=1e-3)
            Minimum sum of instance weight (hessian) needed in a child (leaf).
        min_child_samples : int, optional (default=20)
            Minimum number of data needed in a child (leaf).
        subsample : float, optional (default=1.)
            Subsample ratio of the training instance.
        subsample_freq : int, optional (default=0)
            Frequence of subsample, <=0 means no enable.
        colsample_bytree : float, optional (default=1.)
            Subsample ratio of columns when constructing each tree.
        reg_alpha : float, optional (default=0.)
            L1 regularization term on weights.
        reg_lambda : float, optional (default=0.)
            L2 regularization term on weights.
        random_state : int, RandomState object or None, optional (default=None)
            Random number seed.
            If int, this number is used to seed the C++ code.
            If RandomState object (numpy), a random integer is picked based on
            its state to seed the C++ code.
            If None, default seeds in C++ code are used.
        n_jobs : int, optional (default=-1)
            Number of parallel threads.
        silent : bool, optional (default=True)
            Whether to print messages while running boosting.
        importance_type : string, optional (default='split')
            The type of feature importance to be filled into ``feature_importances_``.
            If 'split', result contains numbers of times the feature is used in a model.
            If 'gain', result contains total gains of splits which use the feature.
        **kwargs
            Other parameters for the model
            (http://lightgbm.readthedocs.io/en/latest/Parameters.html).

            .. warning::

                \*\*kwargs is not supported in sklearn, it may cause unexpected issues.

        Note
        ----
        A custom objective function can be provided for the ``objective``
        parameter with signature ``objective(y_true, y_pred) -> grad, hess`` or
        ``objective(y_true, y_pred, group) -> grad, hess``, where ``grad`` and
        ``hess`` are the first and second order derivatives for each sample
        point (arrays of shape [n_samples], or [n_samples * n_classes] for the
        multi-class task) and ``group`` is the learning-to-rank query data.
        For binary task, the y_pred is margin.
        For multi-class task, the y_pred is group by class_id first, then group
        by row_id: the i-th row y_pred in the j-th class is
        ``y_pred[j * num_data + i]``, and grad and hess must be grouped in this
        way as well.
        """
        if not SKLEARN_INSTALLED:
            raise LightGBMError('Scikit-learn is required for this module')

        # Public, sklearn-visible hyper-parameters: one attribute per
        # constructor argument, as the sklearn estimator API requires.
        self.boosting_type = boosting_type
        self.objective = objective
        self.num_leaves = num_leaves
        self.max_depth = max_depth
        self.learning_rate = learning_rate
        self.n_estimators = n_estimators
        self.subsample_for_bin = subsample_for_bin
        self.min_split_gain = min_split_gain
        self.min_child_weight = min_child_weight
        self.min_child_samples = min_child_samples
        self.subsample = subsample
        self.subsample_freq = subsample_freq
        self.colsample_bytree = colsample_bytree
        self.reg_alpha = reg_alpha
        self.reg_lambda = reg_lambda
        self.random_state = random_state
        self.n_jobs = n_jobs
        self.silent = silent
        self.importance_type = importance_type
        self.class_weight = class_weight

        # Private state populated during fit().
        self._Booster = None
        self._evals_result = None
        self._best_score = None
        self._best_iteration = None
        self._other_params = {}
        self._objective = objective
        self._class_weight = None
        self._class_map = None
        self._n_features = None
        self._n_features_in = None
        self._classes = None
        self._n_classes = None

        # Record extra keyword arguments so get_params()/set_params()
        # round-trip them (must come after _other_params is initialized).
        self.set_params(**kwargs)
wxchan's avatar
wxchan committed
325

Nikita Titov's avatar
Nikita Titov committed
326
    def _more_tags(self):
        """Return sklearn estimator tags for this model."""
        # LightGBM handles NaN natively and accepts dense 2-D arrays,
        # sparse matrices and 1-D label vectors.
        tags = {
            'allow_nan': True,
            'X_types': ['2darray', 'sparse', '1dlabels'],
        }
        # check_no_attributes_set_in_init is expected to fail because this
        # estimator sets private attributes in __init__.
        tags['_xfail_checks'] = {
            'check_no_attributes_set_in_init':
            'scikit-learn incorrectly asserts that private attributes '
            'cannot be set in __init__: '
            '(see https://github.com/microsoft/LightGBM/issues/2628)'
        }
        return tags
Nikita Titov's avatar
Nikita Titov committed
337

wxchan's avatar
wxchan committed
338
    def get_params(self, deep=True):
        """Get parameters for this estimator.

        Parameters
        ----------
        deep : bool, optional (default=True)
            If True, will return the parameters for this estimator and
            contained subobjects that are estimators.

        Returns
        -------
        params : dict
            Parameter names mapped to their values.
        """
        # Start from the sklearn-visible constructor parameters, then overlay
        # any extra **kwargs recorded via set_params (extras win on conflict).
        merged = super().get_params(deep=deep)
        merged.update(self._other_params)
        return merged

    def set_params(self, **params):
        """Set the parameters of this estimator.

        Parameters
        ----------
        **params
            Parameter names with their new values.

        Returns
        -------
        self : object
            Returns self.
        """
        for name, value in params.items():
            setattr(self, name, value)
            # Keep the private shadow attribute (e.g. _objective) in sync
            # when one exists for this parameter.
            shadow = '_' + name
            if hasattr(self, shadow):
                setattr(self, shadow, value)
            # Remember extras so get_params() can report them back.
            self._other_params[name] = value
        return self
wxchan's avatar
wxchan committed
375

Guolin Ke's avatar
Guolin Ke committed
376
    def fit(self, X, y,
377
            sample_weight=None, init_score=None, group=None,
378
            eval_set=None, eval_names=None, eval_sample_weight=None,
379
380
            eval_class_weight=None, eval_init_score=None, eval_group=None,
            eval_metric=None, early_stopping_rounds=None, verbose=True,
381
382
            feature_name='auto', categorical_feature='auto',
            callbacks=None, init_model=None):
383
        """Build a gradient boosting model from the training set (X, y).
wxchan's avatar
wxchan committed
384
385
386

        Parameters
        ----------
387
388
389
390
391
392
393
394
        X : array-like or sparse matrix of shape = [n_samples, n_features]
            Input feature matrix.
        y : array-like of shape = [n_samples]
            The target values (class labels in classification, real numbers in regression).
        sample_weight : array-like of shape = [n_samples] or None, optional (default=None)
            Weights of training data.
        init_score : array-like of shape = [n_samples] or None, optional (default=None)
            Init score of training data.
395
        group : array-like or None, optional (default=None)
396
397
398
399
            Group/query data.
            Only used in the learning-to-rank task.
            sum(group) = n_samples.
            For example, if you have a 100-document dataset with ``group = [10, 20, 40, 10, 10, 10]``, that means that you have 6 groups, where the first 10 records are in the first group, records 11-30 are in the second group, etc.
400
        eval_set : list or None, optional (default=None)
401
            A list of (X, y) tuple pairs to use as validation sets.
402
        eval_names : list of strings or None, optional (default=None)
403
404
405
            Names of eval_set.
        eval_sample_weight : list of arrays or None, optional (default=None)
            Weights of eval data.
406
407
        eval_class_weight : list or None, optional (default=None)
            Class weights of eval data.
408
409
410
411
        eval_init_score : list of arrays or None, optional (default=None)
            Init score of eval data.
        eval_group : list of arrays or None, optional (default=None)
            Group data of eval data.
412
        eval_metric : string, callable, list or None, optional (default=None)
413
            If string, it should be a built-in evaluation metric to use.
414
            If callable, it should be a custom evaluation metric, see note below for more details.
415
            If list, it can be a list of built-in metrics, a list of custom evaluation metrics, or a mix of both.
Misha Lisovyi's avatar
Misha Lisovyi committed
416
            In either case, the ``metric`` from the model parameters will be evaluated and used as well.
417
            Default: 'l2' for LGBMRegressor, 'logloss' for LGBMClassifier, 'ndcg' for LGBMRanker.
418
419
        early_stopping_rounds : int or None, optional (default=None)
            Activates early stopping. The model will train until the validation score stops improving.
420
            Validation score needs to improve at least every ``early_stopping_rounds`` round(s)
421
            to continue training.
422
423
            Requires at least one validation data and one metric.
            If there's more than one, will check all of them. But the training data is ignored anyway.
424
425
            To check only the first metric, set the ``first_metric_only`` parameter to ``True``
            in additional parameters ``**kwargs`` of the model constructor.
426
427
428
429
430
431
        verbose : bool or int, optional (default=True)
            Requires at least one evaluation data.
            If True, the eval metric on the eval set is printed at each boosting stage.
            If int, the eval metric on the eval set is printed at every ``verbose`` boosting stage.
            The last boosting stage or the boosting stage found by using ``early_stopping_rounds`` is also printed.

Nikita Titov's avatar
Nikita Titov committed
432
433
            .. rubric:: Example

434
435
436
            With ``verbose`` = 4 and at least one item in ``eval_set``,
            an evaluation metric is printed every 4 (instead of 1) boosting stages.

437
        feature_name : list of strings or 'auto', optional (default='auto')
438
439
            Feature names.
            If 'auto' and data is pandas DataFrame, data columns names are used.
440
        categorical_feature : list of strings or int, or 'auto', optional (default='auto')
441
442
            Categorical features.
            If list of int, interpreted as indices.
443
            If list of strings, interpreted as feature names (need to specify ``feature_name`` as well).
444
            If 'auto' and data is pandas DataFrame, pandas unordered categorical columns are used.
445
            All values in categorical features should be less than int32 max value (2147483647).
446
            Large values could be memory consuming. Consider using consecutive integers starting from zero.
447
            All negative values in categorical features will be treated as missing values.
448
            The output cannot be monotonically constrained with respect to a categorical feature.
449
        callbacks : list of callback functions or None, optional (default=None)
450
            List of callback functions that are applied at each iteration.
451
            See Callbacks in Python API for more information.
452
453
        init_model : string, Booster, LGBMModel or None, optional (default=None)
            Filename of LightGBM model, Booster instance or LGBMModel instance used for continue training.
454

455
456
457
458
459
        Returns
        -------
        self : object
            Returns self.

460
461
        Note
        ----
462
        Custom eval function expects a callable with following signatures:
463
        ``func(y_true, y_pred)``, ``func(y_true, y_pred, weight)`` or
464
        ``func(y_true, y_pred, weight, group)``
465
466
        and returns (eval_name, eval_result, is_higher_better) or
        list of (eval_name, eval_result, is_higher_better):
467

Nikita Titov's avatar
Nikita Titov committed
468
            y_true : array-like of shape = [n_samples]
469
                The target values.
470
            y_pred : array-like of shape = [n_samples] or shape = [n_samples * n_classes] (for multi-class task)
471
                The predicted values.
Nikita Titov's avatar
Nikita Titov committed
472
            weight : array-like of shape = [n_samples]
473
                The weight of samples.
Nikita Titov's avatar
Nikita Titov committed
474
            group : array-like
475
476
477
478
                Group/query data.
                Only used in the learning-to-rank task.
                sum(group) = n_samples.
                For example, if you have a 100-document dataset with ``group = [10, 20, 40, 10, 10, 10]``, that means that you have 6 groups, where the first 10 records are in the first group, records 11-30 are in the second group, etc.
Nikita Titov's avatar
Nikita Titov committed
479
            eval_name : string
480
                The name of evaluation function (without whitespaces).
Nikita Titov's avatar
Nikita Titov committed
481
            eval_result : float
482
                The eval result.
483
484
            is_higher_better : bool
                Is eval result higher better, e.g. AUC is ``is_higher_better``.
485

486
        For binary task, the y_pred is probability of positive class (or margin in case of custom ``objective``).
487
488
        For multi-class task, the y_pred is group by class_id first, then group by row_id.
        If you want to get i-th row y_pred in j-th class, the access way is y_pred[j * num_data + i].
wxchan's avatar
wxchan committed
489
        """
490
491
492
493
494
495
496
497
498
499
        if self._objective is None:
            if isinstance(self, LGBMRegressor):
                self._objective = "regression"
            elif isinstance(self, LGBMClassifier):
                self._objective = "binary"
            elif isinstance(self, LGBMRanker):
                self._objective = "lambdarank"
            else:
                raise ValueError("Unknown LGBMModel type.")
        if callable(self._objective):
500
            self._fobj = _ObjectiveFunctionWrapper(self._objective)
501
502
        else:
            self._fobj = None
wxchan's avatar
wxchan committed
503
504
        evals_result = {}
        params = self.get_params()
wxchan's avatar
wxchan committed
505
        # user can set verbose with kwargs, it has higher priority
506
        if not any(verbose_alias in params for verbose_alias in _ConfigAliases.get("verbosity")) and self.silent:
507
            params['verbose'] = -1
wxchan's avatar
wxchan committed
508
        params.pop('silent', None)
509
        params.pop('importance_type', None)
wxchan's avatar
wxchan committed
510
        params.pop('n_estimators', None)
511
        params.pop('class_weight', None)
512
513
        if isinstance(params['random_state'], np.random.RandomState):
            params['random_state'] = params['random_state'].randint(np.iinfo(np.int32).max)
514
515
        for alias in _ConfigAliases.get('objective'):
            params.pop(alias, None)
516
        if self._n_classes is not None and self._n_classes > 2:
517
518
            for alias in _ConfigAliases.get('num_class'):
                params.pop(alias, None)
519
520
            params['num_class'] = self._n_classes
        if hasattr(self, '_eval_at'):
521
522
            for alias in _ConfigAliases.get('eval_at'):
                params.pop(alias, None)
523
            params['eval_at'] = self._eval_at
524
525
        params['objective'] = self._objective
        if self._fobj:
wxchan's avatar
wxchan committed
526
            params['objective'] = 'None'  # objective = nullptr for unknown objective
wxchan's avatar
wxchan committed
527

528
529
530
531
532
533
534
535
        # Do not modify original args in fit function
        # Refer to https://github.com/microsoft/LightGBM/pull/2619
        eval_metric_list = copy.deepcopy(eval_metric)
        if not isinstance(eval_metric_list, list):
            eval_metric_list = [eval_metric_list]

        # Separate built-in from callable evaluation metrics
        eval_metrics_callable = [_EvalFunctionWrapper(f) for f in eval_metric_list if callable(f)]
536
        eval_metrics_builtin = [m for m in eval_metric_list if isinstance(m, str)]
537
538

        # register default metric for consistency with callable eval_metric case
539
        original_metric = self._objective if isinstance(self._objective, str) else None
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
        if original_metric is None:
            # try to deduce from class instance
            if isinstance(self, LGBMRegressor):
                original_metric = "l2"
            elif isinstance(self, LGBMClassifier):
                original_metric = "multi_logloss" if self._n_classes > 2 else "binary_logloss"
            elif isinstance(self, LGBMRanker):
                original_metric = "ndcg"

        # overwrite default metric by explicitly set metric
        for metric_alias in _ConfigAliases.get("metric"):
            if metric_alias in params:
                original_metric = params.pop(metric_alias)

        # concatenate metric from params (or default if not provided in params) and eval_metric
555
        original_metric = [original_metric] if isinstance(original_metric, (str, type(None))) else original_metric
556
557
        params['metric'] = [e for e in eval_metrics_builtin if e not in original_metric] + original_metric
        params['metric'] = [metric for metric in params['metric'] if metric is not None]
wxchan's avatar
wxchan committed
558

559
        if not isinstance(X, (DataFrame, DataTable)):
560
            _X, _y = _LGBMCheckXY(X, y, accept_sparse=True, force_all_finite=False, ensure_min_samples=2)
561
562
            if sample_weight is not None:
                sample_weight = _LGBMCheckSampleWeight(sample_weight, _X)
563
564
        else:
            _X, _y = X, y
565

566
567
568
569
        if self._class_weight is None:
            self._class_weight = self.class_weight
        if self._class_weight is not None:
            class_sample_weight = _LGBMComputeSampleWeight(self._class_weight, y)
570
571
572
573
            if sample_weight is None or len(sample_weight) == 0:
                sample_weight = class_sample_weight
            else:
                sample_weight = np.multiply(sample_weight, class_sample_weight)
574

575
        self._n_features = _X.shape[1]
576
577
        # copy for consistency
        self._n_features_in = self._n_features
578

579
580
        def _construct_dataset(X, y, sample_weight, init_score, group, params,
                               categorical_feature='auto'):
581
            return Dataset(X, label=y, weight=sample_weight, group=group,
582
583
                           init_score=init_score, params=params,
                           categorical_feature=categorical_feature)
Guolin Ke's avatar
Guolin Ke committed
584

585
586
        train_set = _construct_dataset(_X, _y, sample_weight, init_score, group, params,
                                       categorical_feature=categorical_feature)
Guolin Ke's avatar
Guolin Ke committed
587
588
589

        valid_sets = []
        if eval_set is not None:
590

591
            def _get_meta_data(collection, name, i):
592
593
594
595
596
597
598
                if collection is None:
                    return None
                elif isinstance(collection, list):
                    return collection[i] if len(collection) > i else None
                elif isinstance(collection, dict):
                    return collection.get(i, None)
                else:
599
                    raise TypeError('{} should be dict or list'.format(name))
600

Guolin Ke's avatar
Guolin Ke committed
601
602
603
            if isinstance(eval_set, tuple):
                eval_set = [eval_set]
            for i, valid_data in enumerate(eval_set):
604
                # reduce cost for prediction training data
Guolin Ke's avatar
Guolin Ke committed
605
606
607
                if valid_data[0] is X and valid_data[1] is y:
                    valid_set = train_set
                else:
608
609
610
611
612
613
                    valid_weight = _get_meta_data(eval_sample_weight, 'eval_sample_weight', i)
                    valid_class_weight = _get_meta_data(eval_class_weight, 'eval_class_weight', i)
                    if valid_class_weight is not None:
                        if isinstance(valid_class_weight, dict) and self._class_map is not None:
                            valid_class_weight = {self._class_map[k]: v for k, v in valid_class_weight.items()}
                        valid_class_sample_weight = _LGBMComputeSampleWeight(valid_class_weight, valid_data[1])
614
615
616
617
                        if valid_weight is None or len(valid_weight) == 0:
                            valid_weight = valid_class_sample_weight
                        else:
                            valid_weight = np.multiply(valid_weight, valid_class_sample_weight)
618
619
                    valid_init_score = _get_meta_data(eval_init_score, 'eval_init_score', i)
                    valid_group = _get_meta_data(eval_group, 'eval_group', i)
620
621
                    valid_set = _construct_dataset(valid_data[0], valid_data[1],
                                                   valid_weight, valid_init_score, valid_group, params)
Guolin Ke's avatar
Guolin Ke committed
622
623
                valid_sets.append(valid_set)

624
625
626
        if isinstance(init_model, LGBMModel):
            init_model = init_model.booster_

Guolin Ke's avatar
Guolin Ke committed
627
        self._Booster = train(params, train_set,
628
                              self.n_estimators, valid_sets=valid_sets, valid_names=eval_names,
wxchan's avatar
wxchan committed
629
                              early_stopping_rounds=early_stopping_rounds,
630
                              evals_result=evals_result, fobj=self._fobj, feval=eval_metrics_callable,
Guolin Ke's avatar
Guolin Ke committed
631
                              verbose_eval=verbose, feature_name=feature_name,
632
                              callbacks=callbacks, init_model=init_model)
wxchan's avatar
wxchan committed
633
634

        if evals_result:
635
            self._evals_result = evals_result
wxchan's avatar
wxchan committed
636

637
        if early_stopping_rounds is not None and early_stopping_rounds > 0:
638
            self._best_iteration = self._Booster.best_iteration
639
640

        self._best_score = self._Booster.best_score
wxchan's avatar
wxchan committed
641

642
643
        self.fitted_ = True

wxchan's avatar
wxchan committed
644
        # free dataset
645
        self._Booster.free_dataset()
wxchan's avatar
wxchan committed
646
        del train_set, valid_sets
wxchan's avatar
wxchan committed
647
648
        return self

649
    def predict(self, X, raw_score=False, start_iteration=0, num_iteration=None,
                pred_leaf=False, pred_contrib=False, **kwargs):
        """Return the predicted value for each sample.

        Parameters
        ----------
        X : array-like or sparse matrix of shape = [n_samples, n_features]
            Input features matrix.
        raw_score : bool, optional (default=False)
            Whether to predict raw scores.
        start_iteration : int, optional (default=0)
            Start index of the iteration to predict.
            If <= 0, starts from the first iteration.
        num_iteration : int or None, optional (default=None)
            Total number of iterations used in the prediction.
            If None, if the best iteration exists and start_iteration <= 0, the best iteration is used;
            otherwise, all iterations from ``start_iteration`` are used (no limits).
            If <= 0, all iterations from ``start_iteration`` are used (no limits).
        pred_leaf : bool, optional (default=False)
            Whether to predict leaf index.
        pred_contrib : bool, optional (default=False)
            Whether to predict feature contributions.

            .. note::

                If you want to get more explanations for your model's predictions using SHAP values,
                like SHAP interaction values,
                you can install the shap package (https://github.com/slundberg/shap).
                Note that unlike the shap package, with ``pred_contrib`` we return a matrix with an extra
                column, where the last column is the expected value.

        **kwargs
            Other parameters for the prediction.

        Returns
        -------
        predicted_result : array-like of shape = [n_samples] or shape = [n_samples, n_classes]
            The predicted values.
        X_leaves : array-like of shape = [n_samples, n_trees] or shape = [n_samples, n_trees * n_classes]
            If ``pred_leaf=True``, the predicted leaf of every tree for each sample.
        X_SHAP_values : array-like of shape = [n_samples, n_features + 1] or shape = [n_samples, (n_features + 1) * n_classes] or list with n_classes length of such objects
            If ``pred_contrib=True``, the feature contributions for each sample.
        """
        if self._n_features is None:
            raise LGBMNotFittedError("Estimator not fitted, call `fit` before exploiting the model.")
        # pandas/datatable frames are handed to the Booster as-is; anything else
        # goes through scikit-learn's array validation first.
        if not isinstance(X, (DataFrame, DataTable)):
            X = _LGBMCheckArray(X, accept_sparse=True, force_all_finite=False)
        n_input_features = X.shape[1]
        if n_input_features != self._n_features:
            raise ValueError("Number of features of the model must "
                             "match the input. Model n_features_ is %s and "
                             "input n_features is %s "
                             % (self._n_features, n_input_features))
        return self._Booster.predict(X, raw_score=raw_score, start_iteration=start_iteration,
                                     num_iteration=num_iteration, pred_leaf=pred_leaf,
                                     pred_contrib=pred_contrib, **kwargs)
wxchan's avatar
wxchan committed
704

705
706
    @property
    def n_features_(self):
        """:obj:`int`: The number of features of fitted model."""
        # Set by fit(); None means the estimator has never been fitted.
        if self._n_features is not None:
            return self._n_features
        raise LGBMNotFittedError('No n_features found. Need to call fit beforehand.')

712
713
714
715
716
717
718
    @property
    def n_features_in_(self):
        """:obj:`int`: The number of features of fitted model."""
        # Mirror of n_features_, kept for scikit-learn attribute-name consistency.
        if self._n_features_in is not None:
            return self._n_features_in
        raise LGBMNotFittedError('No n_features_in found. Need to call fit beforehand.')

719
720
    @property
    def best_score_(self):
        """:obj:`dict` or :obj:`None`: The best score of fitted model."""
        # _n_features doubles as the "has fit() been called" flag.
        if self._n_features is not None:
            return self._best_score
        raise LGBMNotFittedError('No best_score found. Need to call fit beforehand.')

    @property
    def best_iteration_(self):
        """:obj:`int` or :obj:`None`: The best iteration of fitted model if ``early_stopping_rounds`` has been specified."""
        # _n_features doubles as the "has fit() been called" flag.
        if self._n_features is not None:
            return self._best_iteration
        raise LGBMNotFittedError('No best_iteration found. Need to call fit with early_stopping_rounds beforehand.')

    @property
    def objective_(self):
        """:obj:`string` or :obj:`callable`: The concrete objective used while fitting this model."""
        # _n_features doubles as the "has fit() been called" flag.
        if self._n_features is not None:
            return self._objective
        raise LGBMNotFittedError('No objective found. Need to call fit beforehand.')

740
741
    @property
    def booster_(self):
        """Booster: The underlying Booster of this model."""
        # The Booster is created by fit(); absence means the model is unfitted.
        if self._Booster is not None:
            return self._Booster
        raise LGBMNotFittedError('No booster found. Need to call fit beforehand.')
wxchan's avatar
wxchan committed
746

747
748
    @property
    def evals_result_(self):
        """:obj:`dict` or :obj:`None`: The evaluation results if validation sets have been specified."""
        # NOTE: _evals_result is populated by fit() whenever eval_set produced
        # results; early stopping is not required (the error message below is
        # accurate: pass eval_set to fit()).
        if self._n_features is None:
            raise LGBMNotFittedError('No results found. Need to call fit with eval_set beforehand.')
        return self._evals_result
753
754

    @property
    def feature_importances_(self):
        """:obj:`array` of shape = [n_features]: The feature importances (the higher, the more important).

        .. note::

            ``importance_type`` attribute is passed to the function
            to configure the type of importance values to be extracted.
        """
        # _n_features doubles as the "has fit() been called" flag.
        if self._n_features is not None:
            return self._Booster.feature_importance(importance_type=self.importance_type)
        raise LGBMNotFittedError('No feature_importances found. Need to call fit beforehand.')
wxchan's avatar
wxchan committed
766

767
768
    @property
    def feature_name_(self):
        """:obj:`array` of shape = [n_features]: The names of features."""
        # _n_features doubles as the "has fit() been called" flag.
        if self._n_features is not None:
            return self._Booster.feature_name()
        raise LGBMNotFittedError('No feature_name found. Need to call fit beforehand.')

wxchan's avatar
wxchan committed
774

775
776
class LGBMRegressor(LGBMModel, _LGBMRegressorBase):
    """LightGBM regressor."""

    def fit(self, X, y,
            sample_weight=None, init_score=None,
            eval_set=None, eval_names=None, eval_sample_weight=None,
            eval_init_score=None, eval_metric=None, early_stopping_rounds=None,
            verbose=True, feature_name='auto', categorical_feature='auto',
            callbacks=None, init_model=None):
        """Docstring is inherited from the LGBMModel."""
        # Regression has no group or class-weight handling, so delegate
        # straight to the shared LGBMModel implementation.
        super().fit(
            X, y,
            sample_weight=sample_weight,
            init_score=init_score,
            eval_set=eval_set,
            eval_names=eval_names,
            eval_sample_weight=eval_sample_weight,
            eval_init_score=eval_init_score,
            eval_metric=eval_metric,
            early_stopping_rounds=early_stopping_rounds,
            verbose=verbose,
            feature_name=feature_name,
            categorical_feature=categorical_feature,
            callbacks=callbacks,
            init_model=init_model,
        )
        return self

    # Build fit's docstring from LGBMModel's, dropping the parameters that do
    # not apply to regression: group/eval_group (ranking only) and
    # eval_class_weight (classification only).
    _base_doc = LGBMModel.fit.__doc__
    _base_doc = (_base_doc[:_base_doc.find('group :')]
                 + _base_doc[_base_doc.find('eval_set :'):])
    _base_doc = (_base_doc[:_base_doc.find('eval_class_weight :')]
                 + _base_doc[_base_doc.find('eval_init_score :'):])
    fit.__doc__ = (_base_doc[:_base_doc.find('eval_group :')]
                   + _base_doc[_base_doc.find('eval_metric :'):])
wxchan's avatar
wxchan committed
799

800
801
802

class LGBMClassifier(LGBMModel, _LGBMClassifierBase):
    """LightGBM classifier."""

    def fit(self, X, y,
            sample_weight=None, init_score=None,
            eval_set=None, eval_names=None, eval_sample_weight=None,
            eval_class_weight=None, eval_init_score=None, eval_metric=None,
            early_stopping_rounds=None, verbose=True,
            feature_name='auto', categorical_feature='auto',
            callbacks=None, init_model=None):
        """Docstring is inherited from the LGBMModel."""
        _LGBMAssertAllFinite(y)
        _LGBMCheckClassificationTargets(y)
        # Encode arbitrary labels to consecutive integers for the Booster.
        self._le = _LGBMLabelEncoder().fit(y)
        _y = self._le.transform(y)
        self._class_map = dict(zip(self._le.classes_, self._le.transform(self._le.classes_)))
        if isinstance(self.class_weight, dict):
            # Re-key user-supplied class weights by the encoded labels.
            self._class_weight = {self._class_map[k]: v for k, v in self.class_weight.items()}

        self._classes = self._le.classes_
        self._n_classes = len(self._classes)

        if self._n_classes > 2:
            # Switch to using a multiclass objective in the underlying LGBM instance
            ova_aliases = {"multiclassova", "multiclass_ova", "ova", "ovr"}
            if self._objective not in ova_aliases and not callable(self._objective):
                self._objective = "multiclass"

        if not callable(eval_metric):
            if isinstance(eval_metric, (str, type(None))):
                eval_metric = [eval_metric]
            # Remap binary <-> multiclass metric aliases to match the actual
            # number of classes. Build a NEW list instead of assigning into
            # the caller's list: fit must not modify its arguments
            # (see https://github.com/microsoft/LightGBM/pull/2619).
            if self._n_classes > 2:
                alias_map = {'logloss': 'multi_logloss', 'binary_logloss': 'multi_logloss',
                             'error': 'multi_error', 'binary_error': 'multi_error'}
            else:
                alias_map = {'logloss': 'binary_logloss', 'multi_logloss': 'binary_logloss',
                             'error': 'binary_error', 'multi_error': 'binary_error'}
            eval_metric = [alias_map.get(metric, metric) for metric in eval_metric]

        # do not modify args, as it causes errors in model selection tools
        valid_sets = None
        if eval_set is not None:
            if isinstance(eval_set, tuple):
                eval_set = [eval_set]
            valid_sets = [None] * len(eval_set)
            for i, (valid_x, valid_y) in enumerate(eval_set):
                if valid_x is X and valid_y is y:
                    # Same data as the training set: reuse the already-encoded labels.
                    valid_sets[i] = (valid_x, _y)
                else:
                    valid_sets[i] = (valid_x, self._le.transform(valid_y))

        super().fit(X, _y, sample_weight=sample_weight, init_score=init_score, eval_set=valid_sets,
                    eval_names=eval_names, eval_sample_weight=eval_sample_weight,
                    eval_class_weight=eval_class_weight, eval_init_score=eval_init_score,
                    eval_metric=eval_metric, early_stopping_rounds=early_stopping_rounds,
                    verbose=verbose, feature_name=feature_name, categorical_feature=categorical_feature,
                    callbacks=callbacks, init_model=init_model)
        return self

    # Build fit's docstring from LGBMModel's, dropping ranking-only parameters.
    _base_doc = LGBMModel.fit.__doc__
    _base_doc = (_base_doc[:_base_doc.find('group :')]
                 + _base_doc[_base_doc.find('eval_set :'):])
    fit.__doc__ = (_base_doc[:_base_doc.find('eval_group :')]
                   + _base_doc[_base_doc.find('eval_metric :'):])

    def predict(self, X, raw_score=False, start_iteration=0, num_iteration=None,
                pred_leaf=False, pred_contrib=False, **kwargs):
        """Docstring is inherited from the LGBMModel."""
        result = self.predict_proba(X, raw_score, start_iteration, num_iteration,
                                    pred_leaf, pred_contrib, **kwargs)
        if callable(self._objective) or raw_score or pred_leaf or pred_contrib:
            # Raw scores / leaves / contributions cannot be mapped to labels.
            return result
        else:
            class_index = np.argmax(result, axis=1)
            return self._le.inverse_transform(class_index)

    predict.__doc__ = LGBMModel.predict.__doc__

    def predict_proba(self, X, raw_score=False, start_iteration=0, num_iteration=None,
                      pred_leaf=False, pred_contrib=False, **kwargs):
        """Return the predicted probability for each class for each sample.

        Parameters
        ----------
        X : array-like or sparse matrix of shape = [n_samples, n_features]
            Input features matrix.
        raw_score : bool, optional (default=False)
            Whether to predict raw scores.
        start_iteration : int, optional (default=0)
            Start index of the iteration to predict.
            If <= 0, starts from the first iteration.
        num_iteration : int or None, optional (default=None)
            Total number of iterations used in the prediction.
            If None, if the best iteration exists and start_iteration <= 0, the best iteration is used;
            otherwise, all iterations from ``start_iteration`` are used (no limits).
            If <= 0, all iterations from ``start_iteration`` are used (no limits).
        pred_leaf : bool, optional (default=False)
            Whether to predict leaf index.
        pred_contrib : bool, optional (default=False)
            Whether to predict feature contributions.

            .. note::

                If you want to get more explanations for your model's predictions using SHAP values,
                like SHAP interaction values,
                you can install the shap package (https://github.com/slundberg/shap).
                Note that unlike the shap package, with ``pred_contrib`` we return a matrix with an extra
                column, where the last column is the expected value.

        **kwargs
            Other parameters for the prediction.

        Returns
        -------
        predicted_probability : array-like of shape = [n_samples, n_classes]
            The predicted probability for each class for each sample.
        X_leaves : array-like of shape = [n_samples, n_trees * n_classes]
            If ``pred_leaf=True``, the predicted leaf of every tree for each sample.
        X_SHAP_values : array-like of shape = [n_samples, (n_features + 1) * n_classes] or list with n_classes length of such objects
            If ``pred_contrib=True``, the feature contributions for each sample.
        """
        result = super().predict(X, raw_score, start_iteration, num_iteration, pred_leaf, pred_contrib, **kwargs)
        if callable(self._objective) and not (raw_score or pred_leaf or pred_contrib):
            warnings.warn("Cannot compute class probabilities or labels "
                          "due to the usage of customized objective function.\n"
                          "Returning raw scores instead.")
            return result
        elif self._n_classes > 2 or raw_score or pred_leaf or pred_contrib:
            return result
        else:
            # Binary case: the Booster returns P(class 1); build both columns.
            return np.vstack((1. - result, result)).transpose()

    @property
    def classes_(self):
        """:obj:`array` of shape = [n_classes]: The class label array."""
        if self._classes is None:
            raise LGBMNotFittedError('No classes found. Need to call fit beforehand.')
        return self._classes

    @property
    def n_classes_(self):
        """:obj:`int`: The number of classes."""
        if self._n_classes is None:
            raise LGBMNotFittedError('No classes found. Need to call fit beforehand.')
        return self._n_classes
wxchan's avatar
wxchan committed
951

wxchan's avatar
wxchan committed
952

wxchan's avatar
wxchan committed
953
class LGBMRanker(LGBMModel):
    """LightGBM ranker."""

    def fit(self, X, y,
            sample_weight=None, init_score=None, group=None,
            eval_set=None, eval_names=None, eval_sample_weight=None,
            eval_init_score=None, eval_group=None, eval_metric=None,
            eval_at=(1, 2, 3, 4, 5), early_stopping_rounds=None, verbose=True,
            feature_name='auto', categorical_feature='auto',
            callbacks=None, init_model=None):
        """Docstring is inherited from the LGBMModel."""
        # Group/query information is mandatory for the ranking task.
        if group is None:
            raise ValueError("Should set group for ranking task")

        if eval_set is not None:
            # Every validation set must come with its own group information.
            if eval_group is None:
                raise ValueError("Eval_group cannot be None when eval_set is not None")
            if len(eval_group) != len(eval_set):
                raise ValueError("Length of eval_group should be equal to eval_set")
            if isinstance(eval_group, dict):
                missing_group = any(i not in eval_group or eval_group[i] is None
                                    for i in range(len(eval_group)))
            elif isinstance(eval_group, list):
                missing_group = any(g is None for g in eval_group)
            else:
                missing_group = False
            if missing_group:
                raise ValueError("Should set group for all eval datasets for ranking task; "
                                 "if you use dict, the index should start from 0")

        # Remembered so that LGBMModel.fit can forward it as the 'eval_at' param.
        self._eval_at = eval_at
        super().fit(X, y, sample_weight=sample_weight, init_score=init_score, group=group,
                    eval_set=eval_set, eval_names=eval_names, eval_sample_weight=eval_sample_weight,
                    eval_init_score=eval_init_score, eval_group=eval_group, eval_metric=eval_metric,
                    early_stopping_rounds=early_stopping_rounds, verbose=verbose, feature_name=feature_name,
                    categorical_feature=categorical_feature, callbacks=callbacks, init_model=init_model)
        return self

    # Build fit's docstring from LGBMModel's: drop the classification-only
    # eval_class_weight entry, then document eval_at just before
    # early_stopping_rounds.
    _base_doc = LGBMModel.fit.__doc__
    fit.__doc__ = (_base_doc[:_base_doc.find('eval_class_weight :')]
                   + _base_doc[_base_doc.find('eval_init_score :'):])
    _base_doc = fit.__doc__
    _before_early_stop, _early_stop, _after_early_stop = _base_doc.partition('early_stopping_rounds :')
    fit.__doc__ = (_before_early_stop
                   + 'eval_at : iterable of int, optional (default=(1, 2, 3, 4, 5))\n'
                   + ' ' * 12 + 'The evaluation positions of the specified metric.\n'
                   + ' ' * 8 + _early_stop + _after_early_stop)