dask.py 37.1 KB
Newer Older
1
# coding: utf-8
2
"""Distributed training with LightGBM and dask.distributed.
3

4
This module enables you to perform distributed training with LightGBM on
5
dask.Array and dask.DataFrame collections.
6
7

It is based on dask-lightgbm, which was based on dask-xgboost.
8
"""
9
import socket
10
from collections import defaultdict
11
from copy import deepcopy
12
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Type, Union
13
14
15
from urllib.parse import urlparse

import numpy as np
16
17
import scipy.sparse as ss

18
from .basic import _LIB, LightGBMError, _choose_param_value, _ConfigAliases, _log_info, _log_warning, _safe_call
19
from .compat import (DASK_INSTALLED, PANDAS_INSTALLED, SKLEARN_INSTALLED, Client, LGBMNotFittedError, concat,
20
                     dask_Array, dask_DataFrame, dask_Series, default_client, delayed, pd_DataFrame, pd_Series, wait)
21
from .sklearn import LGBMClassifier, LGBMModel, LGBMRanker, LGBMRegressor, _lgbmmodel_doc_fit, _lgbmmodel_doc_predict
22
23
24
25
26

# Type aliases shared by the training and prediction helpers in this module.
# Feature matrices may be given as a Dask Array or a Dask DataFrame.
_DaskMatrixLike = Union[dask_Array, dask_DataFrame]
# Labels, weights and groups may additionally be a Dask Series
# (``Union`` flattens the nested alias, so this is Union[Array, DataFrame, Series]).
_DaskCollection = Union[_DaskMatrixLike, dask_Series]
# A single concrete partition, as handled by ``_concat()`` / ``_predict_part()``.
_DaskPart = Union[np.ndarray, pd_DataFrame, pd_Series, ss.spmatrix]
# Output dtypes accepted by the ``dtype`` argument of ``_predict()``.
_PredictionDtype = Union[Type[np.float32], Type[np.float64], Type[np.int32], Type[np.int64]]
27
28


29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
def _get_dask_client(client: Optional[Client]) -> Client:
    """Resolve which Dask client a computation should use.

    Parameters
    ----------
    client : dask.distributed.Client or None
        An explicitly chosen Dask client, or ``None`` to fall back to the default.

    Returns
    -------
    client : dask.distributed.Client
        ``client`` itself when it is not ``None``, otherwise the result of ``default_client()``.
    """
    return default_client() if client is None else client


48
49
def _find_random_open_port() -> int:
    """Find a random open port on localhost.
50
51
52

    Returns
    -------
53
    port : int
54
        A free port on localhost
55
    """
56
57
58
59
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(('', 0))
        port = s.getsockname()[1]
    return port
60
61


62
def _concat(seq: List[_DaskPart]) -> _DaskPart:
    """Stack a list of partitions row-wise into a single object.

    The concatenation routine is chosen from the type of the first element.

    Parameters
    ----------
    seq : list of numpy arrays, pandas DataFrames/Series, or scipy sparse matrices
        Partitions to combine; all are assumed to share the first element's type.

    Returns
    -------
    combined : numpy array, pandas DataFrame/Series, or scipy CSR matrix
        The partitions joined along axis 0.

    Raises
    ------
    TypeError
        If the first element is not one of the supported types.
    """
    head = seq[0]
    if isinstance(head, np.ndarray):
        return np.concatenate(seq, axis=0)
    if isinstance(head, (pd_DataFrame, pd_Series)):
        return concat(seq, axis=0)
    if isinstance(head, ss.spmatrix):
        return ss.vstack(seq, format='csr')
    raise TypeError('Data must be one of: numpy arrays, pandas dataframes, sparse matrices (from scipy). Got %s.' % str(type(head)))


73
74
75
76
def _train_part(
    params: Dict[str, Any],
    model_factory: Type[LGBMModel],
    list_of_parts: List[Dict[str, _DaskPart]],
    machines: str,
    local_listen_port: int,
    num_machines: int,
    return_model: bool,
    time_out: int = 120,
    **kwargs: Any
) -> Optional[LGBMModel]:
    """Fit a local LightGBM model on the partitions held by one Dask worker.

    Parameters
    ----------
    params : dict
        Parameters for ``model_factory``. This dict is not modified; the network
        parameters below are merged into a copy of it.
    model_factory : lightgbm.LGBMClassifier, lightgbm.LGBMRegressor, or lightgbm.LGBMRanker class
        Class of the local underlying model.
    list_of_parts : list of dict
        Partitions held by this worker. Each dict has ``'data'`` and ``'label'``
        entries and, optionally, ``'weight'`` and ``'group'`` entries.
    machines : str
        Comma-delimited ``host:port`` list of all LightGBM workers.
    local_listen_port : int
        Port this worker listens on for other LightGBM workers.
    num_machines : int
        Number of LightGBM workers.
    return_model : bool
        Whether to return the fitted model from this worker.
    time_out : int, optional (default=120)
        Value passed to LightGBM as the ``time_out`` parameter.
    **kwargs
        Other parameters passed to ``fit`` of the local model.

    Returns
    -------
    model : LGBMModel or None
        The fitted model if ``return_model`` is True, otherwise None.
    """
    # Build a new dict instead of updating the caller's ``params`` in place.
    # Network settings take precedence over any same-named entries in ``params``.
    params = {
        **params,
        'machines': machines,
        'local_listen_port': local_listen_port,
        'time_out': time_out,
        'num_machines': num_machines,
    }

    is_ranker = issubclass(model_factory, LGBMRanker)

    # Concatenate many parts into one
    data = _concat([x['data'] for x in list_of_parts])
    label = _concat([x['label'] for x in list_of_parts])

    # _train() attaches 'weight' / 'group' to every part or to none of them,
    # so checking the first part is enough
    weight = _concat([x['weight'] for x in list_of_parts]) if 'weight' in list_of_parts[0] else None
    group = _concat([x['group'] for x in list_of_parts]) if 'group' in list_of_parts[0] else None

    try:
        model = model_factory(**params)
        if is_ranker:
            model.fit(data, label, sample_weight=weight, group=group, **kwargs)
        else:
            model.fit(data, label, sample_weight=weight, **kwargs)
    finally:
        # call LGBM_NetworkFree whether or not fitting succeeded
        _safe_call(_LIB.LGBM_NetworkFree())

    return model if return_model else None


121
def _split_to_parts(data: _DaskCollection, is_matrix: bool) -> List[_DaskPart]:
    """Turn a Dask collection into a flat list of its delayed partitions.

    Parameters
    ----------
    data : Dask Array, Dask DataFrame or Dask Series
        Collection to split.
    is_matrix : bool
        Whether ``data`` is the 2-D feature matrix (as opposed to a label,
        weight or group vector).

    Returns
    -------
    parts : list
        One delayed object per partition.
    """
    delayed_parts = data.to_delayed()
    if not isinstance(delayed_parts, np.ndarray):
        # presumably already a flat list (e.g. from a Dask DataFrame/Series) — nothing to reshape
        return delayed_parts

    # The grid of blocks must have a single column: only row-wise chunking is
    # supported, so that each block holds complete rows.
    if is_matrix:
        assert delayed_parts.shape[1] == 1
    else:
        assert delayed_parts.ndim == 1 or delayed_parts.shape[1] == 1
    return delayed_parts.flatten().tolist()


132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
def _machines_to_worker_map(machines: str, worker_addresses: List[str]) -> Dict[str, int]:
    """Create a worker_map from machines list.

    Given ``machines`` and a list of Dask worker addresses, return a mapping where the keys are
    ``worker_addresses`` and the values are ports from ``machines``.

    Parameters
    ----------
    machines : str
        A comma-delimited list of workers, of the form ``ip1:port,ip2:port``.
    worker_addresses : list of str
        A list of Dask worker addresses, of the form ``{protocol}{hostname}:{port}``, where ``port`` is the port Dask's scheduler uses to talk to that worker.

    Returns
    -------
    result : Dict[str, int]
        Dictionary where keys are work addresses in the form expected by Dask and values are a port for LightGBM to use.
    """
    machine_addresses = machines.split(",")
    machine_to_port = defaultdict(set)
    for address in machine_addresses:
        host, port = address.split(":")
        machine_to_port[host].add(int(port))

    out = {}
    for address in worker_addresses:
        worker_host = urlparse(address).hostname
        out[address] = machine_to_port[worker_host].pop()

    return out


164
165
166
167
168
169
170
171
172
173
def _train(
    client: Client,
    data: _DaskMatrixLike,
    label: _DaskCollection,
    params: Dict[str, Any],
    model_factory: Type[LGBMModel],
    sample_weight: Optional[_DaskCollection] = None,
    group: Optional[_DaskCollection] = None,
    **kwargs: Any
) -> LGBMModel:
    """Inner train routine.

    Parameters
    ----------
    client : dask.distributed.Client
        Dask client.
    data : Dask Array or Dask DataFrame of shape = [n_samples, n_features]
        Input feature matrix.
    label : Dask Array, Dask DataFrame or Dask Series of shape = [n_samples]
        The target values (class labels in classification, real numbers in regression).
    params : dict
        Parameters passed to constructor of the local underlying model.
    model_factory : lightgbm.LGBMClassifier, lightgbm.LGBMRegressor, or lightgbm.LGBMRanker class
        Class of the local underlying model.
    sample_weight : Dask Array, Dask DataFrame, Dask Series of shape = [n_samples] or None, optional (default=None)
        Weights of training data.
    group : Dask Array, Dask DataFrame, Dask Series of shape = [n_samples] or None, optional (default=None)
        Group/query data.
        Only used in the learning-to-rank task.
        sum(group) = n_samples.
        For example, if you have a 100-document dataset with ``group = [10, 20, 40, 10, 10, 10]``, that means that you have 6 groups,
        where the first 10 records are in the first group, records 11-30 are in the second group, records 31-70 are in the third group, etc.
    **kwargs
        Other parameters passed to ``fit`` method of the local underlying model.

    Returns
    -------
    model : lightgbm.LGBMClassifier, lightgbm.LGBMRegressor, or lightgbm.LGBMRanker class
        Returns fitted underlying model.

    Raises
    ------
    LightGBMError
        If ``local_listen_port`` is given without ``machines`` and some host runs
        more than one Dask worker holding training data.
    Exception
        Any error raised while computing a piece of the training data on the
        cluster is re-raised here.

    Note
    ----

    This method handles setting up the following network parameters based on information
    about the Dask cluster referenced by ``client``.

    * ``local_listen_port``: port that each LightGBM worker opens a listening socket on,
            to accept connections from other workers. This can differ from LightGBM worker
            to LightGBM worker, but does not have to.
    * ``machines``: a comma-delimited list of all workers in the cluster, in the
            form ``ip:port,ip:port``. If running multiple Dask workers on the same host, use different
            ports for each worker. For example, for ``LocalCluster(n_workers=3)``, you might
            pass ``"127.0.0.1:12400,127.0.0.1:12401,127.0.0.1:12402"``.
    * ``num_machines``: number of LightGBM workers.
    * ``time_out``: time in minutes to wait before closing unused sockets.

    The default behavior of this function is to generate ``machines`` from the list of
    Dask workers which hold some piece of the training data, and to search for an open
    port on each worker to be used as ``local_listen_port``.

    If ``machines`` is provided explicitly in ``params``, this function uses the hosts
    and ports in that list directly, and does not do any searching. This means that if
    any of the Dask workers are missing from the list or any of those ports are not free
    when training starts, training will fail.

    If ``local_listen_port`` is provided in ``params`` and ``machines`` is not, this function
    constructs ``machines`` from the list of Dask workers which hold some piece of the
    training data, assuming that each one will use the same ``local_listen_port``.
    """
    params = deepcopy(params)

    # capture whether local_listen_port or its aliases were provided
    listen_port_in_params = any(
        alias in params for alias in _ConfigAliases.get("local_listen_port")
    )

    # capture whether machines or its aliases were provided
    machines_in_params = any(
        alias in params for alias in _ConfigAliases.get("machines")
    )

    # capture whether time_out was provided, so a user-chosen value is kept on the returned model
    time_out_in_params = 'time_out' in params

    params = _choose_param_value(
        main_param_name="tree_learner",
        params=params,
        default_value="data"
    )
    allowed_tree_learners = {
        'data',
        'data_parallel',
        'feature',
        'feature_parallel',
        'voting',
        'voting_parallel'
    }
    if params["tree_learner"] not in allowed_tree_learners:
        _log_warning('Parameter tree_learner set to %s, which is not allowed. Using "data" as default' % params['tree_learner'])
        params['tree_learner'] = 'data'

    if params['tree_learner'] not in {'data', 'data_parallel'}:
        _log_warning(
            'Support for tree_learner %s in lightgbm.dask is experimental and may break in a future release. \n'
            'Use "data" for a stable, well-tested interface.' % params['tree_learner']
        )

    # Some passed-in parameters can be removed:
    #   * 'num_machines': set automatically from Dask worker list
    #   * 'num_threads': overridden to match nthreads on each Dask process
    for param_alias in _ConfigAliases.get('num_machines', 'num_threads'):
        if param_alias in params:
            _log_warning(f"Parameter {param_alias} will be ignored.")
            params.pop(param_alias)

    # Split arrays/dataframes into parts. Arrange parts into dicts to enforce co-locality
    data_parts = _split_to_parts(data=data, is_matrix=True)
    label_parts = _split_to_parts(data=label, is_matrix=False)
    parts = [{'data': x, 'label': y} for (x, y) in zip(data_parts, label_parts)]
    n_parts = len(parts)

    # index-based assignment (rather than zip) so that a part-count mismatch fails loudly
    if sample_weight is not None:
        weight_parts = _split_to_parts(data=sample_weight, is_matrix=False)
        for i in range(n_parts):
            parts[i]['weight'] = weight_parts[i]

    if group is not None:
        group_parts = _split_to_parts(data=group, is_matrix=False)
        for i in range(n_parts):
            parts[i]['group'] = group_parts[i]

    # Start computation in the background
    parts = list(map(delayed, parts))
    parts = client.compute(parts)
    wait(parts)

    for part in parts:
        if part.status == 'error':  # type: ignore
            # result() on an errored future re-raises the worker's exception locally
            part.result()  # type: ignore

    # Find locations of all parts and map them to particular Dask workers
    key_to_part_dict = {part.key: part for part in parts}  # type: ignore
    who_has = client.who_has(parts)
    worker_map = defaultdict(list)
    for key, workers in who_has.items():
        worker_map[next(iter(workers))].append(key_to_part_dict[key])

    master_worker = next(iter(worker_map))
    worker_ncores = client.ncores()

    # resolve aliases for network parameters and pop the result off params.
    # these values are added back in calls to `_train_part()`
    params = _choose_param_value(
        main_param_name="local_listen_port",
        params=params,
        default_value=12400
    )
    local_listen_port = params.pop("local_listen_port")

    params = _choose_param_value(
        main_param_name="machines",
        params=params,
        default_value=None
    )
    machines = params.pop("machines")

    # figure out network params
    worker_addresses = worker_map.keys()
    if machines is not None:
        _log_info("Using passed-in 'machines' parameter")
        worker_address_to_port = _machines_to_worker_map(
            machines=machines,
            worker_addresses=worker_addresses
        )
    else:
        if listen_port_in_params:
            _log_info("Using passed-in 'local_listen_port' for all workers")
            unique_hosts = set(urlparse(a).hostname for a in worker_addresses)
            if len(unique_hosts) < len(worker_addresses):
                msg = (
                    "'local_listen_port' was provided in Dask training parameters, but at least one "
                    "machine in the cluster has multiple Dask worker processes running on it. Please omit "
                    "'local_listen_port' or pass 'machines'."
                )
                raise LightGBMError(msg)

            worker_address_to_port = {
                address: local_listen_port
                for address in worker_addresses
            }
        else:
            _log_info("Finding random open ports for workers")
            worker_address_to_port = client.run(
                _find_random_open_port,
                workers=list(worker_addresses)
            )
        machines = ','.join([
            '%s:%d' % (urlparse(worker_address).hostname, port)
            for worker_address, port
            in worker_address_to_port.items()
        ])

    num_machines = len(worker_address_to_port)

    # Tell each worker to train on the parts that it has locally
    futures_classifiers = [
        client.submit(
            _train_part,
            model_factory=model_factory,
            params={**params, 'num_threads': worker_ncores[worker]},
            list_of_parts=list_of_parts,
            machines=machines,
            local_listen_port=worker_address_to_port[worker],
            num_machines=num_machines,
            time_out=params.get('time_out', 120),
            return_model=(worker == master_worker),
            **kwargs
        )
        for worker, list_of_parts in worker_map.items()
    ]

    # only the master worker returns a model; the others return None
    results = client.gather(futures_classifiers)
    results = [v for v in results if v is not None]
    model = results[0]

    # if network parameters were changed during training, remove them from the
    # returned model so that they're generated dynamically on every run based
    # on the Dask cluster you're connected to and which workers have pieces of
    # the training data
    if not listen_port_in_params:
        for param in _ConfigAliases.get('local_listen_port'):
            model._other_params.pop(param, None)

    if not machines_in_params:
        for param in _ConfigAliases.get('machines'):
            model._other_params.pop(param, None)

    for param in _ConfigAliases.get('num_machines'):
        model._other_params.pop(param, None)

    # the LightGBM parameter is 'time_out'; _train_part() always sets it, so
    # drop it unless the caller chose a value explicitly
    if not time_out_in_params:
        model._other_params.pop('time_out', None)

    return model
402
403


404
405
406
407
408
409
410
411
412
def _predict_part(
    part: _DaskPart,
    model: LGBMModel,
    raw_score: bool,
    pred_proba: bool,
    pred_leaf: bool,
    pred_contrib: bool,
    **kwargs: Any
) -> _DaskPart:
    """Run prediction on a single partition.

    Parameters
    ----------
    part : numpy array, pandas DataFrame, pandas Series or scipy sparse matrix
        One partition of the input feature matrix.
    model : LGBMModel
        Fitted local model.
    raw_score, pred_leaf, pred_contrib : bool
        Forwarded to the model's prediction method.
    pred_proba : bool
        Use ``model.predict_proba()`` instead of ``model.predict()``.
    **kwargs
        Other parameters forwarded to the prediction method.

    Returns
    -------
    result : numpy array, pandas DataFrame or pandas Series
        An empty numpy array for a partition with no rows, otherwise the
        prediction output. When ``part`` is a pandas DataFrame the result is
        wrapped in a DataFrame (for ``pred_proba``, ``pred_contrib`` or
        ``pred_leaf``) or in a Series named ``'predictions'``, indexed like ``part``.
    """
    if part.shape[0] == 0:
        result = np.array([])
    else:
        predict_fn = model.predict_proba if pred_proba else model.predict
        result = predict_fn(
            part,
            raw_score=raw_score,
            pred_leaf=pred_leaf,
            pred_contrib=pred_contrib,
            **kwargs
        )

    # dask.DataFrame.map_partitions() expects each call to return a pandas DataFrame or Series
    if not isinstance(part, pd_DataFrame):
        return result
    if pred_proba or pred_contrib or pred_leaf:
        return pd_DataFrame(result, index=part.index)
    return pd_Series(result, index=part.index, name='predictions')


443
444
445
446
447
448
449
450
451
452
def _predict(
    model: LGBMModel,
    data: _DaskMatrixLike,
    raw_score: bool = False,
    pred_proba: bool = False,
    pred_leaf: bool = False,
    pred_contrib: bool = False,
    dtype: _PredictionDtype = np.float32,
    **kwargs: Any
) -> dask_Array:
    """Inner predict routine.

    Parameters
    ----------
    model : lightgbm.LGBMClassifier, lightgbm.LGBMRegressor, or lightgbm.LGBMRanker class
        Fitted underlying model.
    data : Dask Array or Dask DataFrame of shape = [n_samples, n_features]
        Input feature matrix.
    raw_score : bool, optional (default=False)
        Whether to predict raw scores.
    pred_proba : bool, optional (default=False)
        Should method return results of ``predict_proba`` (``pred_proba=True``) or ``predict`` (``pred_proba=False``).
    pred_leaf : bool, optional (default=False)
        Whether to predict leaf index.
    pred_contrib : bool, optional (default=False)
        Whether to predict feature contributions.
    dtype : np.dtype, optional (default=np.float32)
        Dtype of the output. Only used when ``data`` is a Dask Array.
    **kwargs
        Other parameters passed to ``predict`` or ``predict_proba`` method.

    Returns
    -------
    predicted_result : Dask Array of shape = [n_samples] or shape = [n_samples, n_classes]
        The predicted values.
    X_leaves : Dask Array of shape = [n_samples, n_trees] or shape = [n_samples, n_trees * n_classes]
        If ``pred_leaf=True``, the predicted leaf of every tree for each sample.
    X_SHAP_values : Dask Array of shape = [n_samples, n_features + 1] or shape = [n_samples, (n_features + 1) * n_classes]
        If ``pred_contrib=True``, the feature contributions for each sample.

    Raises
    ------
    LightGBMError
        If dask, pandas or scikit-learn is not installed.
    TypeError
        If ``data`` is neither a Dask Array nor a Dask DataFrame.
    """
    if not all((DASK_INSTALLED, PANDAS_INSTALLED, SKLEARN_INSTALLED)):
        raise LightGBMError('dask, pandas and scikit-learn are required for lightgbm.dask')

    if isinstance(data, dask_DataFrame):
        # each partition yields a pandas DataFrame/Series; return the underlying Dask Array
        partition_predictions = data.map_partitions(
            _predict_part,
            model=model,
            raw_score=raw_score,
            pred_proba=pred_proba,
            pred_leaf=pred_leaf,
            pred_contrib=pred_contrib,
            **kwargs
        )
        return partition_predictions.values

    if not isinstance(data, dask_Array):
        raise TypeError('Data must be either Dask Array or Dask DataFrame. Got %s.' % str(type(data)))

    # tell map_blocks() the shape of each output block
    if pred_proba:
        # one column per class for every row-chunk of the input
        kwargs['chunks'] = (data.chunks[0], (model.n_classes_,))
    else:
        # NOTE(review): the column axis is dropped for every non-proba prediction, but
        # pred_leaf / pred_contrib return 2-D output (see Returns above), so the declared
        # block metadata looks wrong in those cases — confirm and fix.
        kwargs['drop_axis'] = 1

    return data.map_blocks(
        _predict_part,
        model=model,
        raw_score=raw_score,
        pred_proba=pred_proba,
        pred_leaf=pred_leaf,
        pred_contrib=pred_contrib,
        dtype=dtype,
        **kwargs
    )
512
513


514
class _DaskLGBMModel:
    """Dask-specific behavior shared by the Dask estimators.

    Used as a mixin next to a LightGBM scikit-learn estimator (e.g.
    ``DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel)``); its methods rely on
    ``get_params()``, ``set_params()`` and ``_other_params`` from that estimator.
    """

    @property
    def client_(self) -> Client:
        """:obj:`dask.distributed.Client`: Dask client.

        This property can be passed in the constructor or updated
        with ``model.set_params(client=client)``.
        """
        if not getattr(self, "fitted_", False):
            raise LGBMNotFittedError('Cannot access property client_ before calling fit().')

        return _get_dask_client(client=self.client)

    def _lgb_dask_getstate(self) -> Dict[Any, Any]:
        """Remove un-picklable attributes before serialization.

        Returns a deep copy of the instance ``__dict__`` with ``client`` set to
        ``None``. The instance's own ``client`` attribute is restored afterwards,
        even if copying fails. Any ``'client'`` entry in ``_other_params`` is
        removed from the instance.
        """
        client = self.__dict__.pop("client", None)
        try:
            self._other_params.pop("client", None)
            out = deepcopy(self.__dict__)
        finally:
            # put the live client back even if deepcopy() raised
            self.client = client
        out.update({"client": None})
        return out

    def _lgb_dask_fit(
        self,
        model_factory: Type[LGBMModel],
        X: _DaskMatrixLike,
        y: _DaskCollection,
        sample_weight: Optional[_DaskCollection] = None,
        group: Optional[_DaskCollection] = None,
        **kwargs: Any
    ) -> "_DaskLGBMModel":
        """Train on the Dask cluster and copy the fitted model's state onto ``self``.

        Raises
        ------
        LightGBMError
            If dask, pandas or scikit-learn is not installed.
        """
        if not all((DASK_INSTALLED, PANDAS_INSTALLED, SKLEARN_INSTALLED)):
            raise LightGBMError('dask, pandas and scikit-learn are required for lightgbm.dask')

        params = self.get_params(True)
        # the client is not a LightGBM parameter and must not reach the local model
        params.pop("client", None)

        model = _train(
            client=_get_dask_client(self.client),
            data=X,
            label=y,
            params=params,
            model_factory=model_factory,
            sample_weight=sample_weight,
            group=group,
            **kwargs
        )

        self.set_params(**model.get_params())
        self._lgb_dask_copy_extra_params(model, self)

        return self

    def _lgb_dask_to_local(self, model_factory: Type[LGBMModel]) -> LGBMModel:
        """Build a ``model_factory`` instance with this model's parameters and extra attributes."""
        params = self.get_params()
        params.pop("client", None)
        model = model_factory(**params)
        self._lgb_dask_copy_extra_params(self, model)
        # the attribute copy above can bring over a 'client' entry in _other_params
        model._other_params.pop("client", None)
        return model

    @staticmethod
    def _lgb_dask_copy_extra_params(source: Union["_DaskLGBMModel", LGBMModel], dest: Union["_DaskLGBMModel", LGBMModel]) -> None:
        """Copy every attribute of ``source`` whose name is not a key of ``source.get_params()`` onto ``dest``."""
        params = source.get_params()
        attributes = source.__dict__
        extra_param_names = set(attributes.keys()).difference(params.keys())
        for name in extra_param_names:
            setattr(dest, name, attributes[name])
583
584


585
class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
    """Distributed version of lightgbm.LGBMClassifier."""

    def __init__(
        self,
        boosting_type: str = 'gbdt',
        num_leaves: int = 31,
        max_depth: int = -1,
        learning_rate: float = 0.1,
        n_estimators: int = 100,
        subsample_for_bin: int = 200000,
        objective: Optional[Union[Callable, str]] = None,
        class_weight: Optional[Union[dict, str]] = None,
        min_split_gain: float = 0.,
        min_child_weight: float = 1e-3,
        min_child_samples: int = 20,
        subsample: float = 1.,
        subsample_freq: int = 0,
        colsample_bytree: float = 1.,
        reg_alpha: float = 0.,
        reg_lambda: float = 0.,
        random_state: Optional[Union[int, np.random.RandomState]] = None,
        n_jobs: int = -1,
        silent: bool = True,
        importance_type: str = 'split',
        client: Optional[Client] = None,
        **kwargs: Any
    ):
        """Store the Dask client, then delegate everything else to LGBMClassifier."""
        # the client is set on the instance before the parent constructor runs
        self.client = client
        super().__init__(
            boosting_type=boosting_type,
            num_leaves=num_leaves,
            max_depth=max_depth,
            learning_rate=learning_rate,
            n_estimators=n_estimators,
            subsample_for_bin=subsample_for_bin,
            objective=objective,
            class_weight=class_weight,
            min_split_gain=min_split_gain,
            min_child_weight=min_child_weight,
            min_child_samples=min_child_samples,
            subsample=subsample,
            subsample_freq=subsample_freq,
            colsample_bytree=colsample_bytree,
            reg_alpha=reg_alpha,
            reg_lambda=reg_lambda,
            random_state=random_state,
            n_jobs=n_jobs,
            silent=silent,
            importance_type=importance_type,
            **kwargs
        )

    # Build the constructor docstring from LGBMClassifier's, inserting the
    # ``client`` parameter just before ``**kwargs``.
    _base_doc = LGBMClassifier.__init__.__doc__
    _before_kwargs, _kwargs, _after_kwargs = _base_doc.partition('**kwargs')
    _base_doc = (
        _before_kwargs
        + 'client : dask.distributed.Client or None, optional (default=None)\n'
        + ' ' * 12 + 'Dask client. If ``None``, ``distributed.default_client()`` will be used at runtime. The Dask client used by this class will not be saved if the model object is pickled.\n'
        + ' ' * 8 + _kwargs + _after_kwargs
    )

    # the note on custom objective functions in LGBMModel.__init__ is not
    # currently relevant for the Dask estimators
    __init__.__doc__ = _base_doc[:_base_doc.find('Note\n')]

    def __getstate__(self) -> Dict[Any, Any]:
        """Return picklable state, with the Dask client removed."""
        state = self._lgb_dask_getstate()
        return state

    def fit(
        self,
        X: _DaskMatrixLike,
        y: _DaskCollection,
        sample_weight: Optional[_DaskCollection] = None,
        **kwargs: Any
    ) -> "DaskLGBMClassifier":
        """Train an LGBMClassifier across the Dask cluster."""
        return self._lgb_dask_fit(model_factory=LGBMClassifier, X=X, y=y, sample_weight=sample_weight, **kwargs)

    _base_doc = _lgbmmodel_doc_fit.format(
        X_shape="Dask Array or Dask DataFrame of shape = [n_samples, n_features]",
        y_shape="Dask Array, Dask DataFrame or Dask Series of shape = [n_samples]",
        sample_weight_shape="Dask Array, Dask DataFrame, Dask Series of shape = [n_samples] or None, optional (default=None)",
        group_shape="Dask Array, Dask DataFrame, Dask Series of shape = [n_samples] or None, optional (default=None)"
    )

    # DaskLGBMClassifier does not support init_score, evaluation data, or early stopping
    _base_doc = (_base_doc[:_base_doc.find('init_score :')]
                 + _base_doc[_base_doc.find('verbose :'):])

    # DaskLGBMClassifier support for callbacks and init_model is not tested
    fit.__doc__ = (
        _base_doc[:_base_doc.find('callbacks :')]
        + '**kwargs\n'
        + ' ' * 12 + 'Other parameters passed through to ``LGBMClassifier.fit()``.\n'
    )

    def predict(self, X: _DaskMatrixLike, **kwargs: Any) -> dask_Array:
        """Predict class labels lazily with a local copy of the model."""
        local_model = self.to_local()
        label_dtype = self.classes_.dtype
        return _predict(model=local_model, data=X, dtype=label_dtype, **kwargs)

    predict.__doc__ = _lgbmmodel_doc_predict.format(
        description="Return the predicted value for each sample.",
        X_shape="Dask Array or Dask DataFrame of shape = [n_samples, n_features]",
        output_name="predicted_result",
        predicted_result_shape="Dask Array of shape = [n_samples] or shape = [n_samples, n_classes]",
        X_leaves_shape="Dask Array of shape = [n_samples, n_trees] or shape = [n_samples, n_trees * n_classes]",
        X_SHAP_values_shape="Dask Array of shape = [n_samples, n_features + 1] or shape = [n_samples, (n_features + 1) * n_classes]"
    )

    def predict_proba(self, X: _DaskMatrixLike, **kwargs: Any) -> dask_Array:
        """Predict class probabilities lazily with a local copy of the model."""
        local_model = self.to_local()
        return _predict(model=local_model, data=X, pred_proba=True, **kwargs)

    predict_proba.__doc__ = _lgbmmodel_doc_predict.format(
        description="Return the predicted probability for each class for each sample.",
        X_shape="Dask Array or Dask DataFrame of shape = [n_samples, n_features]",
        output_name="predicted_probability",
        predicted_result_shape="Dask Array of shape = [n_samples] or shape = [n_samples, n_classes]",
        X_leaves_shape="Dask Array of shape = [n_samples, n_trees] or shape = [n_samples, n_trees * n_classes]",
        X_SHAP_values_shape="Dask Array of shape = [n_samples, n_features + 1] or shape = [n_samples, (n_features + 1) * n_classes]"
    )

    def to_local(self) -> LGBMClassifier:
        """Convert this distributed estimator into a plain lightgbm.LGBMClassifier.

        Returns
        -------
        model : lightgbm.LGBMClassifier
            Local underlying model, with the same parameters and fitted attributes but no Dask client.
        """
        local_model = self._lgb_dask_to_local(LGBMClassifier)
        return local_model
734
735


736
class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel):
    """Distributed version of lightgbm.LGBMRegressor."""

    def __init__(
        self,
        boosting_type: str = 'gbdt',
        num_leaves: int = 31,
        max_depth: int = -1,
        learning_rate: float = 0.1,
        n_estimators: int = 100,
        subsample_for_bin: int = 200000,
        objective: Optional[Union[Callable, str]] = None,
        class_weight: Optional[Union[dict, str]] = None,
        min_split_gain: float = 0.,
        min_child_weight: float = 1e-3,
        min_child_samples: int = 20,
        subsample: float = 1.,
        subsample_freq: int = 0,
        colsample_bytree: float = 1.,
        reg_alpha: float = 0.,
        reg_lambda: float = 0.,
        random_state: Optional[Union[int, np.random.RandomState]] = None,
        n_jobs: int = -1,
        silent: bool = True,
        importance_type: str = 'split',
        client: Optional[Client] = None,
        **kwargs: Any
    ):
        """Docstring is inherited from the lightgbm.LGBMRegressor.__init__."""
        # ``client`` is Dask-specific, so it is stored here rather than being
        # forwarded to LGBMRegressor.__init__.
        self.client = client
        super().__init__(
            boosting_type=boosting_type,
            num_leaves=num_leaves,
            max_depth=max_depth,
            learning_rate=learning_rate,
            n_estimators=n_estimators,
            subsample_for_bin=subsample_for_bin,
            objective=objective,
            class_weight=class_weight,
            min_split_gain=min_split_gain,
            min_child_weight=min_child_weight,
            min_child_samples=min_child_samples,
            subsample=subsample,
            subsample_freq=subsample_freq,
            colsample_bytree=colsample_bytree,
            reg_alpha=reg_alpha,
            reg_lambda=reg_lambda,
            random_state=random_state,
            n_jobs=n_jobs,
            silent=silent,
            importance_type=importance_type,
            **kwargs
        )

    # Reuse LGBMRegressor's constructor docs, splicing the ``client`` parameter
    # description in immediately before the ``**kwargs`` entry.
    _base_doc = LGBMRegressor.__init__.__doc__
    _before_kwargs, _kwargs, _after_kwargs = _base_doc.partition('**kwargs')
    _base_doc = (
        _before_kwargs
        + 'client : dask.distributed.Client or None, optional (default=None)\n'
        + ' ' * 12 + 'Dask client. If ``None``, ``distributed.default_client()`` will be used at runtime. The Dask client used by this class will not be saved if the model object is pickled.\n'
        + ' ' * 8 + _kwargs + _after_kwargs
    )

    # the note on custom objective functions in LGBMModel.__init__ is not
    # currently relevant for the Dask estimators
    __init__.__doc__ = _base_doc[:_base_doc.find('Note\n')]

    def __getstate__(self) -> Dict[Any, Any]:
        """Return the state used for pickling (delegates to ``_lgb_dask_getstate``)."""
        return self._lgb_dask_getstate()

    def fit(
        self,
        X: _DaskMatrixLike,
        y: _DaskCollection,
        sample_weight: Optional[_DaskCollection] = None,
        **kwargs: Any
    ) -> "DaskLGBMRegressor":
        """Docstring is inherited from the lightgbm.LGBMRegressor.fit."""
        return self._lgb_dask_fit(
            model_factory=LGBMRegressor,
            X=X,
            y=y,
            sample_weight=sample_weight,
            **kwargs
        )

    _base_doc = _lgbmmodel_doc_fit.format(
        X_shape="Dask Array or Dask DataFrame of shape = [n_samples, n_features]",
        y_shape="Dask Array, Dask DataFrame or Dask Series of shape = [n_samples]",
        sample_weight_shape="Dask Array, Dask DataFrame, Dask Series of shape = [n_samples] or None, optional (default=None)",
        group_shape="Dask Array, Dask DataFrame, Dask Series of shape = [n_samples] or None, optional (default=None)"
    )

    # DaskLGBMRegressor does not support init_score, evaluation data, or early stopping.
    # Everything from ``init_score`` up to (not including) ``verbose`` is cut; for a
    # regressor that span also drops the irrelevant ``group`` entry.
    _base_doc = (_base_doc[:_base_doc.find('init_score :')]
                 + _base_doc[_base_doc.find('verbose :'):])

    # DaskLGBMRegressor support for callbacks and init_model is not tested
    fit.__doc__ = (
        _base_doc[:_base_doc.find('callbacks :')]
        + '**kwargs\n'
        + ' ' * 12 + 'Other parameters passed through to ``LGBMRegressor.fit()``.\n'
    )

    def predict(self, X: _DaskMatrixLike, **kwargs: Any) -> dask_Array:
        """Docstring is inherited from the lightgbm.LGBMRegressor.predict."""
        return _predict(
            model=self.to_local(),
            data=X,
            **kwargs
        )

    predict.__doc__ = _lgbmmodel_doc_predict.format(
        description="Return the predicted value for each sample.",
        X_shape="Dask Array or Dask DataFrame of shape = [n_samples, n_features]",
        output_name="predicted_result",
        predicted_result_shape="Dask Array of shape = [n_samples]",
        X_leaves_shape="Dask Array of shape = [n_samples, n_trees]",
        X_SHAP_values_shape="Dask Array of shape = [n_samples, n_features + 1]"
    )

    def to_local(self) -> LGBMRegressor:
        """Create regular version of lightgbm.LGBMRegressor from the distributed version.

        Returns
        -------
        model : lightgbm.LGBMRegressor
            Local underlying model.
        """
        return self._lgb_dask_to_local(LGBMRegressor)
866
867


868
class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel):
    """Distributed version of lightgbm.LGBMRanker."""

    def __init__(
        self,
        boosting_type: str = 'gbdt',
        num_leaves: int = 31,
        max_depth: int = -1,
        learning_rate: float = 0.1,
        n_estimators: int = 100,
        subsample_for_bin: int = 200000,
        objective: Optional[Union[Callable, str]] = None,
        class_weight: Optional[Union[dict, str]] = None,
        min_split_gain: float = 0.,
        min_child_weight: float = 1e-3,
        min_child_samples: int = 20,
        subsample: float = 1.,
        subsample_freq: int = 0,
        colsample_bytree: float = 1.,
        reg_alpha: float = 0.,
        reg_lambda: float = 0.,
        random_state: Optional[Union[int, np.random.RandomState]] = None,
        n_jobs: int = -1,
        silent: bool = True,
        importance_type: str = 'split',
        client: Optional[Client] = None,
        **kwargs: Any
    ):
        """Docstring is inherited from the lightgbm.LGBMRanker.__init__."""
        # ``client`` is Dask-specific, so it is stored here rather than being
        # forwarded to LGBMRanker.__init__.
        self.client = client
        super().__init__(
            boosting_type=boosting_type,
            num_leaves=num_leaves,
            max_depth=max_depth,
            learning_rate=learning_rate,
            n_estimators=n_estimators,
            subsample_for_bin=subsample_for_bin,
            objective=objective,
            class_weight=class_weight,
            min_split_gain=min_split_gain,
            min_child_weight=min_child_weight,
            min_child_samples=min_child_samples,
            subsample=subsample,
            subsample_freq=subsample_freq,
            colsample_bytree=colsample_bytree,
            reg_alpha=reg_alpha,
            reg_lambda=reg_lambda,
            random_state=random_state,
            n_jobs=n_jobs,
            silent=silent,
            importance_type=importance_type,
            **kwargs
        )

    # Reuse LGBMRanker's constructor docs, splicing the ``client`` parameter
    # description in immediately before the ``**kwargs`` entry.
    _base_doc = LGBMRanker.__init__.__doc__
    _before_kwargs, _kwargs, _after_kwargs = _base_doc.partition('**kwargs')
    _base_doc = (
        _before_kwargs
        + 'client : dask.distributed.Client or None, optional (default=None)\n'
        + ' ' * 12 + 'Dask client. If ``None``, ``distributed.default_client()`` will be used at runtime. The Dask client used by this class will not be saved if the model object is pickled.\n'
        + ' ' * 8 + _kwargs + _after_kwargs
    )

    # the note on custom objective functions in LGBMModel.__init__ is not
    # currently relevant for the Dask estimators
    __init__.__doc__ = _base_doc[:_base_doc.find('Note\n')]

    def __getstate__(self) -> Dict[Any, Any]:
        """Return the state used for pickling (delegates to ``_lgb_dask_getstate``)."""
        return self._lgb_dask_getstate()

    def fit(
        self,
        X: _DaskMatrixLike,
        y: _DaskCollection,
        sample_weight: Optional[_DaskCollection] = None,
        init_score: Optional[_DaskCollection] = None,
        group: Optional[_DaskCollection] = None,
        **kwargs: Any
    ) -> "DaskLGBMRanker":
        """Docstring is inherited from the lightgbm.LGBMRanker.fit."""
        # ``init_score`` is part of the signature for API parity with
        # LGBMRanker.fit, but the Dask implementation rejects it explicitly.
        if init_score is not None:
            raise RuntimeError('init_score is not currently supported in lightgbm.dask')

        return self._lgb_dask_fit(
            model_factory=LGBMRanker,
            X=X,
            y=y,
            sample_weight=sample_weight,
            group=group,
            **kwargs
        )

    _base_doc = _lgbmmodel_doc_fit.format(
        X_shape="Dask Array or Dask DataFrame of shape = [n_samples, n_features]",
        y_shape="Dask Array, Dask DataFrame or Dask Series of shape = [n_samples]",
        sample_weight_shape="Dask Array, Dask DataFrame, Dask Series of shape = [n_samples] or None, optional (default=None)",
        group_shape="Dask Array, Dask DataFrame, Dask Series of shape = [n_samples] or None, optional (default=None)"
    )

    # DaskLGBMRanker does not support init_score, evaluation data, or early stopping.
    # Cut the ``init_score`` entry but resume at ``group``, which this method
    # does accept. (Resuming at 'init_score :' itself would make this a no-op and
    # leave the unsupported parameter documented.)
    _base_doc = (_base_doc[:_base_doc.find('init_score :')]
                 + _base_doc[_base_doc.find('group :'):])

    # Cut everything from ``eval_set`` up to (not including) ``verbose``.
    _base_doc = (_base_doc[:_base_doc.find('eval_set :')]
                 + _base_doc[_base_doc.find('verbose :'):])

    # DaskLGBMRanker support for callbacks and init_model is not tested
    fit.__doc__ = (
        _base_doc[:_base_doc.find('callbacks :')]
        + '**kwargs\n'
        + ' ' * 12 + 'Other parameters passed through to ``LGBMRanker.fit()``.\n'
    )

    def predict(self, X: _DaskMatrixLike, **kwargs: Any) -> dask_Array:
        """Docstring is inherited from the lightgbm.LGBMRanker.predict."""
        return _predict(
            model=self.to_local(),
            data=X,
            **kwargs
        )

    predict.__doc__ = _lgbmmodel_doc_predict.format(
        description="Return the predicted value for each sample.",
        X_shape="Dask Array or Dask DataFrame of shape = [n_samples, n_features]",
        output_name="predicted_result",
        predicted_result_shape="Dask Array of shape = [n_samples]",
        X_leaves_shape="Dask Array of shape = [n_samples, n_trees]",
        X_SHAP_values_shape="Dask Array of shape = [n_samples, n_features + 1]"
    )

    def to_local(self) -> LGBMRanker:
        """Create regular version of lightgbm.LGBMRanker from the distributed version.

        Returns
        -------
        model : lightgbm.LGBMRanker
            Local underlying model.
        """
        return self._lgb_dask_to_local(LGBMRanker)