# coding: utf-8
"""Distributed training with LightGBM and dask.distributed.

This module enables you to perform distributed training with LightGBM on
dask.Array and dask.DataFrame collections.

It is based on dask-lightgbm, which was based on dask-xgboost.
"""

import operator
import socket
from collections import defaultdict
from copy import deepcopy
from enum import Enum, auto
from functools import partial
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, Union
from urllib.parse import urlparse

import numpy as np
import scipy.sparse as ss

from .basic import LightGBMError, _choose_param_value, _ConfigAliases, _log_info, _log_warning
from .compat import (
    DASK_INSTALLED,
    PANDAS_INSTALLED,
    SKLEARN_INSTALLED,
    Client,
    Future,
    LGBMNotFittedError,
    concat,
    dask_Array,
    dask_array_from_delayed,
    dask_bag_from_delayed,
    dask_DataFrame,
    dask_Series,
    default_client,
    delayed,
    pd_DataFrame,
    pd_Series,
    wait,
)
from .sklearn import (
    LGBMClassifier,
    LGBMModel,
    LGBMRanker,
    LGBMRegressor,
    _LGBM_ScikitCustomObjectiveFunction,
    _LGBM_ScikitEvalMetricType,
    _lgbmmodel_doc_custom_eval_note,
    _lgbmmodel_doc_fit,
    _lgbmmodel_doc_predict,
)

__all__ = [
    "DaskLGBMClassifier",
    "DaskLGBMRanker",
    "DaskLGBMRegressor",
]

_DaskCollection = Union[dask_Array, dask_DataFrame, dask_Series]
_DaskMatrixLike = Union[dask_Array, dask_DataFrame]
_DaskVectorLike = Union[dask_Array, dask_Series]
_DaskPart = Union[np.ndarray, pd_DataFrame, pd_Series, ss.spmatrix]
_PredictionDtype = Union[Type[np.float32], Type[np.float64], Type[np.int32], Type[np.int64]]


class _RemoteSocket:
    def acquire(self) -> int:
        self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        self.socket.bind(("", 0))
        return self.socket.getsockname()[1]

    def release(self) -> None:
        self.socket.close()


def _acquire_port() -> Tuple[_RemoteSocket, int]:
    s = _RemoteSocket()
    port = s.acquire()
    return s, port
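
# Illustrative sketch of how these two helpers cooperate (port numbers are
# hypothetical): on a worker, _acquire_port() reserves an OS-assigned port and keeps
# the socket open (with SO_REUSEADDR) so nothing else grabs it before training starts:
#
#     >>> remote_socket, port = _acquire_port()  # port is e.g. 49532
#     >>> remote_socket.release()                # freed just before LightGBM binds it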


class _DatasetNames(Enum):
    """Placeholder names used by lightgbm.dask internals to say 'also evaluate the training data'.

    Avoid duplicating the training data when the validation set refers to elements of training data.
    """

    TRAINSET = auto()
    SAMPLE_WEIGHT = auto()
    INIT_SCORE = auto()
    GROUP = auto()
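
# For example, when ``eval_set[i]`` is the training data itself, the corresponding
# entry in a worker part's "eval_set" list is the placeholder _DatasetNames.TRAINSET
# rather than a second copy of the training chunks; _train_part() later resolves the
# placeholder back to the local "data" and "label" parts.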


def _get_dask_client(client: Optional[Client]) -> Client:
    """Choose a Dask client to use.

    Parameters
    ----------
    client : dask.distributed.Client or None
        Dask client.

    Returns
    -------
    client : dask.distributed.Client
        A Dask client.
    """
    if client is None:
        return default_client()
    else:
        return client
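
# e.g. _get_dask_client(None) falls back to distributed.default_client(), so estimators
# constructed without an explicit client pick up whichever Client is active when
# fit() or predict() runs.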


def _assign_open_ports_to_workers(
    client: Client,
    workers: List[str],
) -> Tuple[Dict[str, Future], Dict[str, int]]:
    """Assign an open port to each worker.

    Returns
    -------
    worker_to_socket_future: dict
        mapping from worker address to a future pointing to the remote socket.
    worker_to_port: dict
        mapping from worker address to an open port in the worker's host.
    """
    # Acquire port in worker
    worker_to_future = {}
    for worker in workers:
        worker_to_future[worker] = client.submit(
            _acquire_port,
            workers=[worker],
            allow_other_workers=False,
            pure=False,
        )

    # schedule futures to retrieve each element of the tuple
    worker_to_socket_future = {}
    worker_to_port_future = {}
    for worker, socket_future in worker_to_future.items():
        worker_to_socket_future[worker] = client.submit(operator.itemgetter(0), socket_future)
        worker_to_port_future[worker] = client.submit(operator.itemgetter(1), socket_future)

    # retrieve ports
    worker_to_port = client.gather(worker_to_port_future)

    return worker_to_socket_future, worker_to_port
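
# Illustrative result (addresses and ports are hypothetical): for two workers on one
# host, the returned port mapping might look like
#     {"tcp://127.0.0.1:40545": 49152, "tcp://127.0.0.1:40611": 49153}
# while worker_to_socket_future holds the still-open sockets that reserve those ports.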


def _concat(seq: List[_DaskPart]) -> _DaskPart:
    if isinstance(seq[0], np.ndarray):
        return np.concatenate(seq, axis=0)
    elif isinstance(seq[0], (pd_DataFrame, pd_Series)):
        return concat(seq, axis=0)
    elif isinstance(seq[0], ss.spmatrix):
        return ss.vstack(seq, format="csr")
    else:
        raise TypeError(
            f"Data must be one of: numpy arrays, pandas dataframes, sparse matrices (from scipy). Got {type(seq[0]).__name__}."
        )
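
# Illustrative behavior, assuming numpy parts (pandas and scipy parts follow the
# same pattern via concat() and ss.vstack()):
#
#     >>> _concat([np.array([[1, 2]]), np.array([[3, 4]])])
#     array([[1, 2],
#            [3, 4]])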


def _remove_list_padding(*args: Any) -> List[List[Any]]:
    return [[z for z in arg if z is not None] for arg in args]
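
# e.g. _remove_list_padding([1, None, 2], [None, 3]) == [[1, 2], [3]]
# (the None entries are the padding added in _train() for eval sets with
# fewer chunks than the largest one).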


def _pad_eval_names(lgbm_model: LGBMModel, required_names: List[str]) -> LGBMModel:
    """Append missing (key, value) pairs to a LightGBM model's evals_result_ and best_score_ OrderedDict attrs based on a set of required eval_set names.

    Allows users to rely on expected eval_set names being present when fitting DaskLGBM estimators with ``eval_set``.
    """
    for eval_name in required_names:
        if eval_name not in lgbm_model.evals_result_:
            lgbm_model.evals_result_[eval_name] = {}
        if eval_name not in lgbm_model.best_score_:
            lgbm_model.best_score_[eval_name] = {}

    return lgbm_model
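
# e.g. if a worker only saw chunks of the first of two eval sets, fitting produces
# evals_result_ == {"valid_0": {...}}; _pad_eval_names(model, ["valid_0", "valid_1"])
# then adds the missing key, giving {"valid_0": {...}, "valid_1": {}}.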


182
183
184
185
def _train_part(
    params: Dict[str, Any],
    model_factory: Type[LGBMModel],
    list_of_parts: List[Dict[str, _DaskPart]],
186
187
188
    machines: str,
    local_listen_port: int,
    num_machines: int,
189
    return_model: bool,
190
    time_out: int,
191
    remote_socket: _RemoteSocket,
192
    **kwargs: Any,
193
) -> Optional[LGBMModel]:
194
    network_params = {
195
196
197
198
        "machines": machines,
        "local_listen_port": local_listen_port,
        "time_out": time_out,
        "num_machines": num_machines,
199
    }
200
201
    params.update(network_params)

202
203
    is_ranker = issubclass(model_factory, LGBMRanker)

204
    # Concatenate many parts into one
205
206
    data = _concat([x["data"] for x in list_of_parts])
    label = _concat([x["label"] for x in list_of_parts])
207

208
209
    if "weight" in list_of_parts[0]:
        weight = _concat([x["weight"] for x in list_of_parts])
210
211
212
    else:
        weight = None

213
214
    if "group" in list_of_parts[0]:
        group = _concat([x["group"] for x in list_of_parts])
215
216
    else:
        group = None
217

218
219
    if "init_score" in list_of_parts[0]:
        init_score = _concat([x["init_score"] for x in list_of_parts])
220
221
222
    else:
        init_score = None

223
    # construct local eval_set data.
224
225
226
    n_evals = max(len(x.get("eval_set", [])) for x in list_of_parts)
    eval_names = kwargs.pop("eval_names", None)
    eval_class_weight = kwargs.get("eval_class_weight")
227
228
229
230
231
232
233
    local_eval_set = None
    local_eval_names = None
    local_eval_sample_weight = None
    local_eval_init_score = None
    local_eval_group = None

    if n_evals:
234
235
        has_eval_sample_weight = any(x.get("eval_sample_weight") is not None for x in list_of_parts)
        has_eval_init_score = any(x.get("eval_init_score") is not None for x in list_of_parts)
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256

        local_eval_set = []
        evals_result_names = []
        if has_eval_sample_weight:
            local_eval_sample_weight = []
        if has_eval_init_score:
            local_eval_init_score = []
        if is_ranker:
            local_eval_group = []

        # store indices of eval_set components that were not contained within local parts.
        missing_eval_component_idx = []

        # consolidate parts of each individual eval component.
        for i in range(n_evals):
            x_e = []
            y_e = []
            w_e = []
            init_score_e = []
            g_e = []
            for part in list_of_parts:
257
                if not part.get("eval_set"):
258
259
260
261
262
263
264
                    continue

                # require that eval_name exists in evaluated result data in case dropped due to padding.
                # in distributed training the 'training' eval_set is not detected, will have name 'valid_<index>'.
                if eval_names:
                    evals_result_name = eval_names[i]
                else:
265
                    evals_result_name = f"valid_{i}"
266

267
                eval_set = part["eval_set"][i]
268
                if eval_set is _DatasetNames.TRAINSET:
269
270
                    x_e.append(part["data"])
                    y_e.append(part["label"])
271
272
273
274
275
276
277
                else:
                    x_e.extend(eval_set[0])
                    y_e.extend(eval_set[1])

                if evals_result_name not in evals_result_names:
                    evals_result_names.append(evals_result_name)

278
                eval_weight = part.get("eval_sample_weight")
279
280
                if eval_weight:
                    if eval_weight[i] is _DatasetNames.SAMPLE_WEIGHT:
281
                        w_e.append(part["weight"])
282
283
284
                    else:
                        w_e.extend(eval_weight[i])

285
                eval_init_score = part.get("eval_init_score")
286
287
                if eval_init_score:
                    if eval_init_score[i] is _DatasetNames.INIT_SCORE:
288
                        init_score_e.append(part["init_score"])
289
290
291
                    else:
                        init_score_e.extend(eval_init_score[i])

292
                eval_group = part.get("eval_group")
293
294
                if eval_group:
                    if eval_group[i] is _DatasetNames.GROUP:
295
                        g_e.append(part["group"])
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
                    else:
                        g_e.extend(eval_group[i])

            # filter padding from eval parts then _concat each eval_set component.
            x_e, y_e, w_e, init_score_e, g_e = _remove_list_padding(x_e, y_e, w_e, init_score_e, g_e)
            if x_e:
                local_eval_set.append((_concat(x_e), _concat(y_e)))
            else:
                missing_eval_component_idx.append(i)
                continue

            if w_e:
                local_eval_sample_weight.append(_concat(w_e))
            if init_score_e:
                local_eval_init_score.append(_concat(init_score_e))
            if g_e:
                local_eval_group.append(_concat(g_e))

        # reconstruct eval_set fit args/kwargs depending on which components of eval_set are on worker.
        eval_component_idx = [i for i in range(n_evals) if i not in missing_eval_component_idx]
        if eval_names:
            local_eval_names = [eval_names[i] for i in eval_component_idx]
        if eval_class_weight:
319
            kwargs["eval_class_weight"] = [eval_class_weight[i] for i in eval_component_idx]
320

321
    model = model_factory(**params)
322
323
    if remote_socket is not None:
        remote_socket.release()
324
    try:
325
        if is_ranker:
326
327
328
329
330
331
332
333
334
335
336
            model.fit(
                data,
                label,
                sample_weight=weight,
                init_score=init_score,
                group=group,
                eval_set=local_eval_set,
                eval_sample_weight=local_eval_sample_weight,
                eval_init_score=local_eval_init_score,
                eval_group=local_eval_group,
                eval_names=local_eval_names,
337
                **kwargs,
338
            )
339
        else:
340
341
342
343
344
345
346
347
348
            model.fit(
                data,
                label,
                sample_weight=weight,
                init_score=init_score,
                eval_set=local_eval_set,
                eval_sample_weight=local_eval_sample_weight,
                eval_init_score=local_eval_init_score,
                eval_names=local_eval_names,
349
                **kwargs,
350
            )
351

352
    finally:
353
354
        if getattr(model, "fitted_", False):
            model.booster_.free_network()
355

356
357
358
359
    if n_evals:
        # ensure that expected keys for evals_result_ and best_score_ exist regardless of padding.
        model = _pad_eval_names(model, required_names=evals_result_names)

360
361
362
    return model if return_model else None
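
# Illustrative: with the network_params injected above, each _train_part() call behaves
# like single-machine training configured as, e.g.,
#     model_factory(machines="10.0.0.1:12400,10.0.0.2:12400", local_listen_port=12400,
#                   num_machines=2, time_out=120, ...).fit(...)
# running simultaneously on every worker, with LightGBM coordinating over the network.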


def _split_to_parts(data: _DaskCollection, is_matrix: bool) -> List[_DaskPart]:
    parts = data.to_delayed()
    if isinstance(parts, np.ndarray):
        if is_matrix:
            assert parts.shape[1] == 1
        else:
            assert parts.ndim == 1 or parts.shape[1] == 1
        parts = parts.flatten().tolist()
    return parts
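
# e.g. a dask.Array of shape (1000, 10) stored as 4 row-wise chunks yields a list of
# 4 Delayed objects, one per chunk; each resolves to the local part (numpy, pandas,
# or scipy sparse) on whichever worker ends up holding that chunk.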


def _machines_to_worker_map(machines: str, worker_addresses: Iterable[str]) -> Dict[str, int]:
    """Create a worker_map from machines list.

    Given ``machines`` and a list of Dask worker addresses, return a mapping where the keys are
    ``worker_addresses`` and the values are ports from ``machines``.

    Parameters
    ----------
    machines : str
        A comma-delimited list of workers, of the form ``ip1:port,ip2:port``.
    worker_addresses : list of str
        An iterable of Dask worker addresses, of the form ``{protocol}{hostname}:{port}``,
        where ``port`` is the port Dask's scheduler uses to talk to that worker.

    Returns
    -------
    result : Dict[str, int]
        Dictionary where keys are worker addresses in the form expected by Dask and values are a port for LightGBM to use.
    """
    machine_addresses = machines.split(",")

    if len(set(machine_addresses)) != len(machine_addresses):
        raise ValueError(
            f"Found duplicates in 'machines' ({machines}). Each entry in 'machines' must be a unique IP-port combination."
        )

    machine_to_port = defaultdict(set)
    for address in machine_addresses:
        host, port = address.split(":")
        machine_to_port[host].add(int(port))

    out = {}
    for address in worker_addresses:
        worker_host = urlparse(address).hostname
        if not worker_host:
            raise ValueError(f"Could not parse host name from worker address '{address}'")
        out[address] = machine_to_port[worker_host].pop()

    return out
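
# Illustrative example (addresses are hypothetical): with
#     machines="10.0.0.1:12400,10.0.0.1:12401"
# and two Dask workers on host 10.0.0.1, each worker address is assigned one of the two
# ports, e.g. {"tcp://10.0.0.1:36857": 12400, "tcp://10.0.0.1:36901": 12401}.
# Which port a given worker receives is arbitrary, since set.pop() is unordered.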


def _train(
    client: Client,
    data: _DaskMatrixLike,
    label: _DaskCollection,
    params: Dict[str, Any],
    model_factory: Type[LGBMModel],
    sample_weight: Optional[_DaskVectorLike] = None,
    init_score: Optional[_DaskCollection] = None,
    group: Optional[_DaskVectorLike] = None,
    eval_set: Optional[List[Tuple[_DaskMatrixLike, _DaskCollection]]] = None,
    eval_names: Optional[List[str]] = None,
    eval_sample_weight: Optional[List[_DaskVectorLike]] = None,
    eval_class_weight: Optional[List[Union[dict, str]]] = None,
    eval_init_score: Optional[List[_DaskCollection]] = None,
    eval_group: Optional[List[_DaskVectorLike]] = None,
    eval_metric: Optional[_LGBM_ScikitEvalMetricType] = None,
    eval_at: Optional[Union[List[int], Tuple[int, ...]]] = None,
    **kwargs: Any,
) -> LGBMModel:
    """Inner train routine.

    Parameters
    ----------
    client : dask.distributed.Client
        Dask client.
    data : Dask Array or Dask DataFrame of shape = [n_samples, n_features]
        Input feature matrix.
    label : Dask Array, Dask DataFrame or Dask Series of shape = [n_samples]
        The target values (class labels in classification, real numbers in regression).
    params : dict
        Parameters passed to constructor of the local underlying model.
    model_factory : lightgbm.LGBMClassifier, lightgbm.LGBMRegressor, or lightgbm.LGBMRanker class
        Class of the local underlying model.
    sample_weight : Dask Array or Dask Series of shape = [n_samples] or None, optional (default=None)
        Weights of training data. Weights should be non-negative.
    init_score : Dask Array or Dask Series of shape = [n_samples] or shape = [n_samples * n_classes] (for multi-class task), or Dask Array or Dask DataFrame of shape = [n_samples, n_classes] (for multi-class task), or None, optional (default=None)
        Init score of training data.
    group : Dask Array or Dask Series or None, optional (default=None)
        Group/query data.
        Only used in the learning-to-rank task.
        sum(group) = n_samples.
        For example, if you have a 100-document dataset with ``group = [10, 20, 40, 10, 10, 10]``, that means that you have 6 groups,
        where the first 10 records are in the first group, records 11-30 are in the second group, records 31-70 are in the third group, etc.
    eval_set : list of (X, y) tuples of Dask data collections, or None, optional (default=None)
        List of (X, y) tuple pairs to use as validation sets.
        Note that not all workers may receive chunks of every eval set within ``eval_set``. When the returned
        lightgbm estimator is not trained using any chunks of a particular eval set, its corresponding component
        of ``evals_result_`` and ``best_score_`` will be empty dictionaries.
    eval_names : list of str, or None, optional (default=None)
        Names of eval_set.
    eval_sample_weight : list of Dask Array or Dask Series, or None, optional (default=None)
        Weights for each validation set in eval_set. Weights should be non-negative.
    eval_class_weight : list of dict or str, or None, optional (default=None)
        Class weights, one dict or str for each validation set in eval_set.
    eval_init_score : list of Dask Array, Dask Series or Dask DataFrame (for multi-class task), or None, optional (default=None)
        Initial model score for each validation set in eval_set.
    eval_group : list of Dask Array or Dask Series, or None, optional (default=None)
        Group/query for each validation set in eval_set.
    eval_metric : str, callable, list or None, optional (default=None)
        If str, it should be a built-in evaluation metric to use.
        If callable, it should be a custom evaluation metric, see note below for more details.
        If list, it can be a list of built-in metrics, a list of custom evaluation metrics, or a mix of both.
        In either case, the ``metric`` from the Dask model parameters (or inferred from the objective) will be evaluated and used as well.
        Default: 'l2' for DaskLGBMRegressor, 'binary(multi)_logloss' for DaskLGBMClassifier, 'ndcg' for DaskLGBMRanker.
    eval_at : list or tuple of int, optional (default=None)
        The evaluation positions of the specified ranking metric.
    **kwargs
        Other parameters passed to ``fit`` method of the local underlying model.

    Returns
    -------
    model : lightgbm.LGBMClassifier, lightgbm.LGBMRegressor, or lightgbm.LGBMRanker class
        Returns fitted underlying model.

    Note
    ----
    This method handles setting up the following network parameters based on information
    about the Dask cluster referenced by ``client``.

    * ``local_listen_port``: port that each LightGBM worker opens a listening socket on,
            to accept connections from other workers. This can differ from LightGBM worker
            to LightGBM worker, but does not have to.
    * ``machines``: a comma-delimited list of all workers in the cluster, in the
            form ``ip:port,ip:port``. If running multiple Dask workers on the same host, use different
            ports for each worker. For example, for ``LocalCluster(n_workers=3)``, you might
            pass ``"127.0.0.1:12400,127.0.0.1:12401,127.0.0.1:12402"``.
    * ``num_machines``: number of LightGBM workers.
    * ``timeout``: time in minutes to wait before closing unused sockets.

    The default behavior of this function is to generate ``machines`` from the list of
    Dask workers which hold some piece of the training data, and to search for an open
    port on each worker to be used as ``local_listen_port``.

    If ``machines`` is provided explicitly in ``params``, this function uses the hosts
    and ports in that list directly, and does not do any searching. This means that if
    any of the Dask workers are missing from the list or any of those ports are not free
    when training starts, training will fail.

    If ``local_listen_port`` is provided in ``params`` and ``machines`` is not, this function
    constructs ``machines`` from the list of Dask workers which hold some piece of the
    training data, assuming that each one will use the same ``local_listen_port``.
    """
    params = deepcopy(params)

    # capture whether local_listen_port or its aliases were provided
    listen_port_in_params = any(alias in params for alias in _ConfigAliases.get("local_listen_port"))

    # capture whether machines or its aliases were provided
    machines_in_params = any(alias in params for alias in _ConfigAliases.get("machines"))

    params = _choose_param_value(
        main_param_name="tree_learner",
        params=params,
        default_value="data",
    )
    allowed_tree_learners = {
        "data",
        "data_parallel",
        "feature",
        "feature_parallel",
        "voting",
        "voting_parallel",
    }
    if params["tree_learner"] not in allowed_tree_learners:
        _log_warning(
            f'Parameter tree_learner set to {params["tree_learner"]}, which is not allowed. Using "data" as default'
        )
        params["tree_learner"] = "data"

    # Some passed-in parameters can be removed:
    #   * 'num_machines': set automatically from Dask worker list
    #   * 'num_threads': overridden to match nthreads on each Dask process
    for param_alias in _ConfigAliases.get("num_machines", "num_threads"):
        if param_alias in params:
            _log_warning(f"Parameter {param_alias} will be ignored.")
            params.pop(param_alias)

    # Split arrays/dataframes into parts. Arrange parts into dicts to enforce co-locality
    data_parts = _split_to_parts(data=data, is_matrix=True)
    label_parts = _split_to_parts(data=label, is_matrix=False)
    parts = [{"data": x, "label": y} for (x, y) in zip(data_parts, label_parts)]
    n_parts = len(parts)

    if sample_weight is not None:
        weight_parts = _split_to_parts(data=sample_weight, is_matrix=False)
        for i in range(n_parts):
            parts[i]["weight"] = weight_parts[i]

    if group is not None:
        group_parts = _split_to_parts(data=group, is_matrix=False)
        for i in range(n_parts):
            parts[i]["group"] = group_parts[i]

    if init_score is not None:
        init_score_parts = _split_to_parts(data=init_score, is_matrix=False)
        for i in range(n_parts):
            parts[i]["init_score"] = init_score_parts[i]

    # eval_set will be re-constructed into smaller lists of (X, y) tuples, where
    # X and y are each delayed sub-lists of original eval dask Collections.
    if eval_set:
        # find maximum number of parts in an individual eval set so that we can
        # pad eval sets when they come in different sizes.
        n_largest_eval_parts = max(x[0].npartitions for x in eval_set)

        eval_sets: Dict[
            int, List[Union[_DatasetNames, Tuple[List[Optional[_DaskMatrixLike]], List[Optional[_DaskVectorLike]]]]]
        ] = defaultdict(list)
        if eval_sample_weight:
            eval_sample_weights: Dict[int, List[Union[_DatasetNames, List[Optional[_DaskVectorLike]]]]] = defaultdict(
                list
            )
        if eval_group:
            eval_groups: Dict[int, List[Union[_DatasetNames, List[Optional[_DaskVectorLike]]]]] = defaultdict(list)
        if eval_init_score:
            eval_init_scores: Dict[int, List[Union[_DatasetNames, List[Optional[_DaskMatrixLike]]]]] = defaultdict(list)

        for i, (X_eval, y_eval) in enumerate(eval_set):
            n_this_eval_parts = X_eval.npartitions

            # when individual eval set is equivalent to training data, skip recomputing parts.
            if X_eval is data and y_eval is label:
                for parts_idx in range(n_parts):
                    eval_sets[parts_idx].append(_DatasetNames.TRAINSET)
            else:
                eval_x_parts = _split_to_parts(data=X_eval, is_matrix=True)
                eval_y_parts = _split_to_parts(data=y_eval, is_matrix=False)
                for j in range(n_largest_eval_parts):
                    parts_idx = j % n_parts

                    # add None-padding for individual eval_set member if it is smaller than the largest member.
                    if j < n_this_eval_parts:
                        x_e = eval_x_parts[j]
                        y_e = eval_y_parts[j]
                    else:
                        x_e = None
                        y_e = None

                    if j < n_parts:
                        # first time a chunk of this eval set is added to this part.
                        eval_sets[parts_idx].append(([x_e], [y_e]))
                    else:
                        # append additional chunks of this eval set to this part.
                        eval_sets[parts_idx][-1][0].append(x_e)  # type: ignore[index, union-attr]
                        eval_sets[parts_idx][-1][1].append(y_e)  # type: ignore[index, union-attr]

            if eval_sample_weight:
                if eval_sample_weight[i] is sample_weight:
                    for parts_idx in range(n_parts):
                        eval_sample_weights[parts_idx].append(_DatasetNames.SAMPLE_WEIGHT)
                else:
                    eval_w_parts = _split_to_parts(data=eval_sample_weight[i], is_matrix=False)

                    # ensure that all evaluation parts map uniquely to one part.
                    for j in range(n_largest_eval_parts):
                        if j < n_this_eval_parts:
                            w_e = eval_w_parts[j]
                        else:
                            w_e = None

                        parts_idx = j % n_parts
                        if j < n_parts:
                            eval_sample_weights[parts_idx].append([w_e])
                        else:
                            eval_sample_weights[parts_idx][-1].append(w_e)  # type: ignore[union-attr]

            if eval_init_score:
                if eval_init_score[i] is init_score:
                    for parts_idx in range(n_parts):
                        eval_init_scores[parts_idx].append(_DatasetNames.INIT_SCORE)
                else:
                    eval_init_score_parts = _split_to_parts(data=eval_init_score[i], is_matrix=False)
                    for j in range(n_largest_eval_parts):
                        if j < n_this_eval_parts:
                            init_score_e = eval_init_score_parts[j]
                        else:
                            init_score_e = None

                        parts_idx = j % n_parts
                        if j < n_parts:
                            eval_init_scores[parts_idx].append([init_score_e])
                        else:
                            eval_init_scores[parts_idx][-1].append(init_score_e)  # type: ignore[union-attr]

            if eval_group:
                if eval_group[i] is group:
                    for parts_idx in range(n_parts):
                        eval_groups[parts_idx].append(_DatasetNames.GROUP)
                else:
                    eval_g_parts = _split_to_parts(data=eval_group[i], is_matrix=False)
                    for j in range(n_largest_eval_parts):
                        if j < n_this_eval_parts:
                            g_e = eval_g_parts[j]
                        else:
                            g_e = None

                        parts_idx = j % n_parts
                        if j < n_parts:
                            eval_groups[parts_idx].append([g_e])
                        else:
                            eval_groups[parts_idx][-1].append(g_e)  # type: ignore[union-attr]

        # assign sub-eval_set components to worker parts.
        for parts_idx, e_set in eval_sets.items():
            parts[parts_idx]["eval_set"] = e_set
            if eval_sample_weight:
                parts[parts_idx]["eval_sample_weight"] = eval_sample_weights[parts_idx]
            if eval_init_score:
                parts[parts_idx]["eval_init_score"] = eval_init_scores[parts_idx]
            if eval_group:
                parts[parts_idx]["eval_group"] = eval_groups[parts_idx]

    # Start computation in the background
    parts = list(map(delayed, parts))
    parts = client.compute(parts)
    wait(parts)

    for part in parts:
        if part.status == "error":  # type: ignore
            # trigger error locally
            return part  # type: ignore[return-value]

    # Find locations of all parts and map them to particular Dask workers
    key_to_part_dict = {part.key: part for part in parts}  # type: ignore
    who_has = client.who_has(parts)
    worker_map = defaultdict(list)
    for key, workers in who_has.items():
        worker_map[next(iter(workers))].append(key_to_part_dict[key])

    # Check that all workers were provided some of eval_set. Otherwise warn user that validation
    # data artifacts may not be populated depending on worker returning final estimator.
    if eval_set:
        for worker in worker_map:
            has_eval_set = False
            for part in worker_map[worker]:
                if "eval_set" in part.result():  # type: ignore[attr-defined]
                    has_eval_set = True
                    break

            if not has_eval_set:
                _log_warning(
                    f"Worker {worker} was not allocated eval_set data. Therefore evals_result_ and best_score_ data may be unreliable. "
                    "Try rebalancing data across workers."
                )

    # assign general validation set settings to fit kwargs.
    if eval_names:
        kwargs["eval_names"] = eval_names
    if eval_class_weight:
        kwargs["eval_class_weight"] = eval_class_weight
    if eval_metric:
        kwargs["eval_metric"] = eval_metric
    if eval_at:
        kwargs["eval_at"] = eval_at

    master_worker = next(iter(worker_map))
    worker_ncores = client.ncores()

    # resolve aliases for network parameters and pop the result off params.
    # these values are added back in calls to `_train_part()`
    params = _choose_param_value(
        main_param_name="local_listen_port",
        params=params,
        default_value=12400,
    )
    local_listen_port = params.pop("local_listen_port")

    params = _choose_param_value(
        main_param_name="machines",
        params=params,
        default_value=None,
    )
    machines = params.pop("machines")

    # figure out network params
    worker_to_socket_future: Dict[str, Future] = {}
    worker_addresses = worker_map.keys()
    if machines is not None:
        _log_info("Using passed-in 'machines' parameter")
        worker_address_to_port = _machines_to_worker_map(
            machines=machines,
            worker_addresses=worker_addresses,
        )
    else:
        if listen_port_in_params:
            _log_info("Using passed-in 'local_listen_port' for all workers")
            unique_hosts = {urlparse(a).hostname for a in worker_addresses}
            if len(unique_hosts) < len(worker_addresses):
                msg = (
                    "'local_listen_port' was provided in Dask training parameters, but at least one "
                    "machine in the cluster has multiple Dask worker processes running on it. Please omit "
                    "'local_listen_port' or pass 'machines'."
                )
                raise LightGBMError(msg)

            worker_address_to_port = {address: local_listen_port for address in worker_addresses}
        else:
            _log_info("Finding random open ports for workers")
            worker_to_socket_future, worker_address_to_port = _assign_open_ports_to_workers(
                client, list(worker_map.keys())
            )

        machines = ",".join(
            [f"{urlparse(worker_address).hostname}:{port}" for worker_address, port in worker_address_to_port.items()]
        )

    num_machines = len(worker_address_to_port)

    # Tell each worker to train on the parts that it has locally
    #
    # This code treats ``_train_part()`` calls as not "pure" because:
    #     1. there is randomness in the training process unless parameters ``seed``
    #        and ``deterministic`` are set
    #     2. even with those parameters set, the output of one ``_train_part()`` call
    #        relies on global state (it and all the other LightGBM training processes
    #        coordinate with each other)
    futures_classifiers = [
        client.submit(
            _train_part,
            model_factory=model_factory,
            params={**params, "num_threads": worker_ncores[worker]},
            list_of_parts=list_of_parts,
            machines=machines,
            local_listen_port=worker_address_to_port[worker],
            num_machines=num_machines,
            time_out=params.get("time_out", 120),
            remote_socket=worker_to_socket_future.get(worker, None),
            return_model=(worker == master_worker),
            workers=[worker],
            allow_other_workers=False,
            pure=False,
            **kwargs,
        )
        for worker, list_of_parts in worker_map.items()
    ]

    results = client.gather(futures_classifiers)
    results = [v for v in results if v]
    model = results[0]

    # if network parameters were changed during training, remove them from the
    # returned model so that they're generated dynamically on every run based
    # on the Dask cluster you're connected to and which workers have pieces of
    # the training data
    if not listen_port_in_params:
        for param in _ConfigAliases.get("local_listen_port"):
            model._other_params.pop(param, None)

    if not machines_in_params:
        for param in _ConfigAliases.get("machines"):
            model._other_params.pop(param, None)

    for param in _ConfigAliases.get("num_machines", "timeout"):
        model._other_params.pop(param, None)

    return model
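
# Minimal usage sketch for the network-parameter behavior described in the docstring
# above (addresses and ports are hypothetical):
#
#     DaskLGBMRegressor()                                          # search for open ports
#     DaskLGBMRegressor(local_listen_port=12400)                   # same port on every host
#     DaskLGBMRegressor(machines="10.0.0.1:12400,10.0.0.2:12400")  # fully explicit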


def _predict_part(
    part: _DaskPart,
    model: LGBMModel,
    raw_score: bool,
    pred_proba: bool,
    pred_leaf: bool,
    pred_contrib: bool,
    **kwargs: Any,
) -> _DaskPart:
    result: _DaskPart
    if part.shape[0] == 0:
        result = np.array([])
    elif pred_proba:
        result = model.predict_proba(
            part,
            raw_score=raw_score,
            pred_leaf=pred_leaf,
            pred_contrib=pred_contrib,
            **kwargs,
        )
    else:
        result = model.predict(
            part,
            raw_score=raw_score,
            pred_leaf=pred_leaf,
            pred_contrib=pred_contrib,
            **kwargs,
        )

    # dask.DataFrame.map_partitions() expects each call to return a pandas DataFrame or Series
    if isinstance(part, pd_DataFrame):
        if len(result.shape) == 2:
            result = pd_DataFrame(result, index=part.index)
        else:
            result = pd_Series(result, index=part.index, name="predictions")

    return result


def _predict(
    model: LGBMModel,
    data: _DaskMatrixLike,
    client: Client,
    raw_score: bool = False,
    pred_proba: bool = False,
    pred_leaf: bool = False,
    pred_contrib: bool = False,
    dtype: _PredictionDtype = np.float32,
    **kwargs: Any,
) -> Union[dask_Array, List[dask_Array]]:
    """Inner predict routine.

    Parameters
    ----------
    model : lightgbm.LGBMClassifier, lightgbm.LGBMRegressor, or lightgbm.LGBMRanker class
        Fitted underlying model.
    data : Dask Array or Dask DataFrame of shape = [n_samples, n_features]
        Input feature matrix.
    raw_score : bool, optional (default=False)
        Whether to predict raw scores.
    pred_proba : bool, optional (default=False)
        Should method return results of ``predict_proba`` (``pred_proba=True``) or ``predict`` (``pred_proba=False``).
    pred_leaf : bool, optional (default=False)
        Whether to predict leaf index.
    pred_contrib : bool, optional (default=False)
        Whether to predict feature contributions.
    dtype : np.dtype, optional (default=np.float32)
        Dtype of the output.
    **kwargs
        Other parameters passed to ``predict`` or ``predict_proba`` method.

    Returns
    -------
    predicted_result : Dask Array of shape = [n_samples] or shape = [n_samples, n_classes]
        The predicted values.
    X_leaves : Dask Array of shape = [n_samples, n_trees] or shape = [n_samples, n_trees * n_classes]
        If ``pred_leaf=True``, the predicted leaf of every tree for each sample.
    X_SHAP_values : Dask Array of shape = [n_samples, n_features + 1] or shape = [n_samples, (n_features + 1) * n_classes] or (if multi-class and using sparse inputs) a list of ``n_classes`` Dask Arrays of shape = [n_samples, n_features + 1]
        If ``pred_contrib=True``, the feature contributions for each sample.
    """
    if not all((DASK_INSTALLED, PANDAS_INSTALLED, SKLEARN_INSTALLED)):
        raise LightGBMError("dask, pandas and scikit-learn are required for lightgbm.dask")
    if isinstance(data, dask_DataFrame):
        return data.map_partitions(
            _predict_part,
            model=model,
            raw_score=raw_score,
            pred_proba=pred_proba,
            pred_leaf=pred_leaf,
            pred_contrib=pred_contrib,
            **kwargs,
        ).values
    elif isinstance(data, dask_Array):
        # for multi-class classification with sparse matrices, pred_contrib predictions
        # are returned as a list of sparse matrices (one per class)
        num_classes = model._n_classes

        if num_classes > 2 and pred_contrib and isinstance(data._meta, ss.spmatrix):
            predict_function = partial(
                _predict_part,
                model=model,
                raw_score=False,
                pred_proba=pred_proba,
                pred_leaf=False,
                pred_contrib=True,
                **kwargs,
            )

            delayed_chunks = data.to_delayed()
            bag = dask_bag_from_delayed(delayed_chunks[:, 0])

            @delayed
            def _extract(items: List[Any], i: int) -> Any:
                return items[i]

            preds = bag.map_partitions(predict_function)

            # pred_contrib output will have one column per feature,
            # plus one more for the base value
            num_cols = model.n_features_ + 1

            nrows_per_chunk = data.chunks[0]
            out: List[List[dask_Array]] = [[] for _ in range(num_classes)]

            # need to tell Dask the expected type and shape of individual preds
            pred_meta = data._meta

            for j, partition in enumerate(preds.to_delayed()):
                for i in range(num_classes):
                    part = dask_array_from_delayed(
                        value=_extract(partition, i),
                        shape=(nrows_per_chunk[j], num_cols),
                        meta=pred_meta,
                    )
                    out[i].append(part)

            # by default, dask.array.concatenate() concatenates sparse arrays into a COO matrix
            # the code below is used instead to ensure that the sparse type is preserved during concatenation
            if isinstance(pred_meta, ss.csr_matrix):
                concat_fn = partial(ss.vstack, format="csr")
            elif isinstance(pred_meta, ss.csc_matrix):
                concat_fn = partial(ss.vstack, format="csc")
            else:
                concat_fn = ss.vstack

            # At this point, `out` is a list of lists of delayeds (each of which points to a matrix).
            # Concatenate them to return a list of Dask Arrays.
            out_arrays: List[dask_Array] = []
            for i in range(num_classes):
                out_arrays.append(
                    dask_array_from_delayed(
                        value=delayed(concat_fn)(out[i]),
                        shape=(data.shape[0], num_cols),
                        meta=pred_meta,
                    )
                )

            return out_arrays

        data_row = client.compute(data[[0]]).result()
        predict_fn = partial(
            _predict_part,
            model=model,
            raw_score=raw_score,
            pred_proba=pred_proba,
            pred_leaf=pred_leaf,
            pred_contrib=pred_contrib,
            **kwargs,
        )
        pred_row = predict_fn(data_row)
        chunks: Tuple[int, ...] = (data.chunks[0],)
        map_blocks_kwargs = {}
        if len(pred_row.shape) > 1:
            chunks += (pred_row.shape[1],)
        else:
            map_blocks_kwargs["drop_axis"] = 1
        return data.map_blocks(
            predict_fn,
            chunks=chunks,
            meta=pred_row,
            dtype=dtype,
            **map_blocks_kwargs,
        )
    else:
        raise TypeError(f"Data must be either Dask Array or Dask DataFrame. Got {type(data).__name__}.")
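
# e.g. predicting probabilities for a binary classifier on a Dask Array with
# chunks ((100, 100),) gives pred_row.shape == (1, 2), so map_blocks() is called with
# chunks ((100, 100), 2); a 1-D prediction instead sets drop_axis=1.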


class _DaskLGBMModel:
    @property
    def client_(self) -> Client:
        """:obj:`dask.distributed.Client`: Dask client.

        This property can be passed in the constructor or updated
        with ``model.set_params(client=client)``.
        """
        if not getattr(self, "fitted_", False):
            raise LGBMNotFittedError("Cannot access property client_ before calling fit().")

        return _get_dask_client(client=self.client)

    def _lgb_dask_getstate(self) -> Dict[Any, Any]:
        """Remove un-picklable attributes before serialization."""
        client = self.__dict__.pop("client", None)
        self._other_params.pop("client", None)  # type: ignore[attr-defined]
        out = deepcopy(self.__dict__)
        out.update({"client": None})
        self.client = client
        return out

    def _lgb_dask_fit(
        self,
        model_factory: Type[LGBMModel],
        X: _DaskMatrixLike,
        y: _DaskCollection,
        sample_weight: Optional[_DaskVectorLike] = None,
        init_score: Optional[_DaskCollection] = None,
        group: Optional[_DaskVectorLike] = None,
        eval_set: Optional[List[Tuple[_DaskMatrixLike, _DaskCollection]]] = None,
        eval_names: Optional[List[str]] = None,
        eval_sample_weight: Optional[List[_DaskVectorLike]] = None,
        eval_class_weight: Optional[List[Union[dict, str]]] = None,
        eval_init_score: Optional[List[_DaskCollection]] = None,
        eval_group: Optional[List[_DaskVectorLike]] = None,
        eval_metric: Optional[_LGBM_ScikitEvalMetricType] = None,
        eval_at: Optional[Union[List[int], Tuple[int, ...]]] = None,
        **kwargs: Any,
    ) -> "_DaskLGBMModel":
        if not DASK_INSTALLED:
            raise LightGBMError("dask is required for lightgbm.dask")
        if not all((DASK_INSTALLED, PANDAS_INSTALLED, SKLEARN_INSTALLED)):
            raise LightGBMError("dask, pandas and scikit-learn are required for lightgbm.dask")

        params = self.get_params(True)  # type: ignore[attr-defined]
        params.pop("client", None)

        model = _train(
            client=_get_dask_client(self.client),
            data=X,
            label=y,
            params=params,
            model_factory=model_factory,
            sample_weight=sample_weight,
            init_score=init_score,
            group=group,
            eval_set=eval_set,
            eval_names=eval_names,
            eval_sample_weight=eval_sample_weight,
            eval_class_weight=eval_class_weight,
            eval_init_score=eval_init_score,
            eval_group=eval_group,
            eval_metric=eval_metric,
            eval_at=eval_at,
            **kwargs,
        )

        self.set_params(**model.get_params())  # type: ignore[attr-defined]
        self._lgb_dask_copy_extra_params(model, self)  # type: ignore[attr-defined]

        return self

    def _lgb_dask_to_local(self, model_factory: Type[LGBMModel]) -> LGBMModel:
        params = self.get_params()  # type: ignore[attr-defined]
        params.pop("client", None)
        model = model_factory(**params)
        self._lgb_dask_copy_extra_params(self, model)
        model._other_params.pop("client", None)
        return model

    @staticmethod
    def _lgb_dask_copy_extra_params(
        source: Union["_DaskLGBMModel", LGBMModel],
        dest: Union["_DaskLGBMModel", LGBMModel],
    ) -> None:
        params = source.get_params()  # type: ignore[union-attr]
        attributes = source.__dict__
        extra_param_names = set(attributes.keys()).difference(params.keys())
        for name in extra_param_names:
            setattr(dest, name, attributes[name])
1111
1112


1113
class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
1114
1115
    """Distributed version of lightgbm.LGBMClassifier."""

1116
1117
    def __init__(
        self,
1118
        boosting_type: str = "gbdt",
1119
1120
1121
1122
1123
        num_leaves: int = 31,
        max_depth: int = -1,
        learning_rate: float = 0.1,
        n_estimators: int = 100,
        subsample_for_bin: int = 200000,
1124
        objective: Optional[Union[str, _LGBM_ScikitCustomObjectiveFunction]] = None,
1125
        class_weight: Optional[Union[dict, str]] = None,
1126
        min_split_gain: float = 0.0,
1127
1128
        min_child_weight: float = 1e-3,
        min_child_samples: int = 20,
1129
        subsample: float = 1.0,
1130
        subsample_freq: int = 0,
1131
1132
1133
1134
        colsample_bytree: float = 1.0,
        reg_alpha: float = 0.0,
        reg_lambda: float = 0.0,
        random_state: Optional[Union[int, np.random.RandomState, "np.random.Generator"]] = None,
1135
        n_jobs: Optional[int] = None,
1136
        importance_type: str = "split",
1137
        client: Optional[Client] = None,
1138
        **kwargs: Any,
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
    ):
        """Docstring is inherited from the lightgbm.LGBMClassifier.__init__."""
        self.client = client
        super().__init__(
            boosting_type=boosting_type,
            num_leaves=num_leaves,
            max_depth=max_depth,
            learning_rate=learning_rate,
            n_estimators=n_estimators,
            subsample_for_bin=subsample_for_bin,
            objective=objective,
            class_weight=class_weight,
            min_split_gain=min_split_gain,
            min_child_weight=min_child_weight,
            min_child_samples=min_child_samples,
            subsample=subsample,
            subsample_freq=subsample_freq,
            colsample_bytree=colsample_bytree,
            reg_alpha=reg_alpha,
            reg_lambda=reg_lambda,
            random_state=random_state,
            n_jobs=n_jobs,
            importance_type=importance_type,
1162
            **kwargs,
1163
1164
1165
        )

    _base_doc = LGBMClassifier.__init__.__doc__
1166
    _before_kwargs, _kwargs, _after_kwargs = _base_doc.partition("**kwargs")  # type: ignore
1167
    __init__.__doc__ = f"""
1168
        {_before_kwargs}client : dask.distributed.Client or None, optional (default=None)
1169
        {" ":4}Dask client. If ``None``, ``distributed.default_client()`` will be used at runtime. The Dask client used by this class will not be saved if the model object is pickled.
1170
1171
        {_kwargs}{_after_kwargs}
        """
1172
1173

    def __getstate__(self) -> Dict[Any, Any]:
1174
        return self._lgb_dask_getstate()
1175

1176
    def fit(  # type: ignore[override]
1177
1178
1179
        self,
        X: _DaskMatrixLike,
        y: _DaskCollection,
1180
        sample_weight: Optional[_DaskVectorLike] = None,
1181
        init_score: Optional[_DaskCollection] = None,
1182
1183
        eval_set: Optional[List[Tuple[_DaskMatrixLike, _DaskCollection]]] = None,
        eval_names: Optional[List[str]] = None,
1184
        eval_sample_weight: Optional[List[_DaskVectorLike]] = None,
1185
        eval_class_weight: Optional[List[Union[dict, str]]] = None,
1186
        eval_init_score: Optional[List[_DaskCollection]] = None,
1187
        eval_metric: Optional[_LGBM_ScikitEvalMetricType] = None,
1188
        **kwargs: Any,
1189
    ) -> "DaskLGBMClassifier":
1190
        """Docstring is inherited from the lightgbm.LGBMClassifier.fit."""
1191
        self._lgb_dask_fit(
1192
1193
1194
1195
            model_factory=LGBMClassifier,
            X=X,
            y=y,
            sample_weight=sample_weight,
1196
            init_score=init_score,
1197
1198
1199
1200
1201
1202
            eval_set=eval_set,
            eval_names=eval_names,
            eval_sample_weight=eval_sample_weight,
            eval_class_weight=eval_class_weight,
            eval_init_score=eval_init_score,
            eval_metric=eval_metric,
1203
            **kwargs,
1204
        )
1205
        return self
1206

1207
1208
1209
    _base_doc = _lgbmmodel_doc_fit.format(
        X_shape="Dask Array or Dask DataFrame of shape = [n_samples, n_features]",
        y_shape="Dask Array, Dask DataFrame or Dask Series of shape = [n_samples]",
1210
        sample_weight_shape="Dask Array or Dask Series of shape = [n_samples] or None, optional (default=None)",
1211
        init_score_shape="Dask Array or Dask Series of shape = [n_samples] or shape = [n_samples * n_classes] (for multi-class task), or Dask Array or Dask DataFrame of shape = [n_samples, n_classes] (for multi-class task), or None, optional (default=None)",
1212
        group_shape="Dask Array or Dask Series or None, optional (default=None)",
1213
        eval_sample_weight_shape="list of Dask Array or Dask Series, or None, optional (default=None)",
1214
        eval_init_score_shape="list of Dask Array, Dask Series or Dask DataFrame (for multi-class task), or None, optional (default=None)",
1215
        eval_group_shape="list of Dask Array or Dask Series, or None, optional (default=None)",
1216
1217
    )

1218
    # DaskLGBMClassifier does not support group, eval_group.
1219
    _base_doc = _base_doc[: _base_doc.find("group :")] + _base_doc[_base_doc.find("eval_set :") :]
1220

1221
    _base_doc = _base_doc[: _base_doc.find("eval_group :")] + _base_doc[_base_doc.find("eval_metric :") :]
1222

1223
    # DaskLGBMClassifier support for callbacks and init_model is not tested
1224
    fit.__doc__ = f"""{_base_doc[: _base_doc.find("callbacks :")]}**kwargs
1225
        Other parameters passed through to ``LGBMClassifier.fit()``.
1226

1227
1228
1229
1230
1231
    Returns
    -------
    self : lightgbm.DaskLGBMClassifier
        Returns self.

1232
    {_lgbmmodel_doc_custom_eval_note}
1233
        """
1234

1235
1236
    def predict(
        self,
1237
        X: _DaskMatrixLike,  # type: ignore[override]
1238
1239
1240
1241
1242
1243
        raw_score: bool = False,
        start_iteration: int = 0,
        num_iteration: Optional[int] = None,
        pred_leaf: bool = False,
        pred_contrib: bool = False,
        validate_features: bool = False,
1244
        **kwargs: Any,
1245
    ) -> dask_Array:
1246
        """Docstring is inherited from the lightgbm.LGBMClassifier.predict."""
1247
1248
1249
1250
        return _predict(
            model=self.to_local(),
            data=X,
            dtype=self.classes_.dtype,
1251
            client=_get_dask_client(self.client),
1252
1253
1254
1255
1256
1257
            raw_score=raw_score,
            start_iteration=start_iteration,
            num_iteration=num_iteration,
            pred_leaf=pred_leaf,
            pred_contrib=pred_contrib,
            validate_features=validate_features,
1258
            **kwargs,
1259
1260
        )

1261
1262
1263
1264
1265
1266
    predict.__doc__ = _lgbmmodel_doc_predict.format(
        description="Return the predicted value for each sample.",
        X_shape="Dask Array or Dask DataFrame of shape = [n_samples, n_features]",
        output_name="predicted_result",
        predicted_result_shape="Dask Array of shape = [n_samples] or shape = [n_samples, n_classes]",
        X_leaves_shape="Dask Array of shape = [n_samples, n_trees] or shape = [n_samples, n_trees * n_classes]",
1267
        X_SHAP_values_shape="Dask Array of shape = [n_samples, n_features + 1] or shape = [n_samples, (n_features + 1) * n_classes] or (if multi-class and using sparse inputs) a list of ``n_classes`` Dask Arrays of shape = [n_samples, n_features + 1]",
1268
    )
1269

    def predict_proba(
        self,
        X: _DaskMatrixLike,  # type: ignore[override]
        raw_score: bool = False,
        start_iteration: int = 0,
        num_iteration: Optional[int] = None,
        pred_leaf: bool = False,
        pred_contrib: bool = False,
        validate_features: bool = False,
        **kwargs: Any,
    ) -> dask_Array:
        """Docstring is inherited from the lightgbm.LGBMClassifier.predict_proba."""
        return _predict(
            model=self.to_local(),
            data=X,
            pred_proba=True,
            client=_get_dask_client(self.client),
            raw_score=raw_score,
            start_iteration=start_iteration,
            num_iteration=num_iteration,
            pred_leaf=pred_leaf,
            pred_contrib=pred_contrib,
            validate_features=validate_features,
            **kwargs,
        )

    predict_proba.__doc__ = _lgbmmodel_doc_predict.format(
        description="Return the predicted probability for each class for each sample.",
        X_shape="Dask Array or Dask DataFrame of shape = [n_samples, n_features]",
        output_name="predicted_probability",
        predicted_result_shape="Dask Array of shape = [n_samples] or shape = [n_samples, n_classes]",
        X_leaves_shape="Dask Array of shape = [n_samples, n_trees] or shape = [n_samples, n_trees * n_classes]",
        X_SHAP_values_shape="Dask Array of shape = [n_samples, n_features + 1] or shape = [n_samples, (n_features + 1) * n_classes] or (if multi-class and using sparse inputs) a list of ``n_classes`` Dask Arrays of shape = [n_samples, n_features + 1]",
    )
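
    # A minimal usage sketch (kept as a comment so nothing runs at import time);
    # it assumes a running dask.distributed cluster and is illustrative only:
    #
    #   import dask.array as da
    #   from distributed import Client, LocalCluster
    #   from lightgbm import DaskLGBMClassifier
    #
    #   client = Client(LocalCluster(n_workers=2))
    #   dX = da.random.random((1_000, 10), chunks=(250, 10))
    #   dy = (da.random.random(1_000, chunks=250) > 0.5).astype(int)
    #   clf = DaskLGBMClassifier(n_estimators=10).fit(dX, dy)
    #   proba = clf.predict_proba(dX)   # lazy Dask Array of shape (1_000, 2)
    #   proba.compute()                 # materialize as a local numpy array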

    def to_local(self) -> LGBMClassifier:
        """Create regular version of lightgbm.LGBMClassifier from the distributed version.

        Returns
        -------
        model : lightgbm.LGBMClassifier
            Local underlying model.
        """
        return self._lgb_dask_to_local(LGBMClassifier)
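
    # Sketch: once fit, the distributed model can be turned into a plain
    # lightgbm.LGBMClassifier for use without any Dask cluster
    # (``clf`` below is assumed to be an already-fitted DaskLGBMClassifier):
    #
    #   local_clf = clf.to_local()      # plain LGBMClassifier
    #   local_clf.predict(X_numpy)      # X_numpy: an in-memory 2-D array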


class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel):
    """Distributed version of lightgbm.LGBMRegressor."""

    def __init__(
        self,
        boosting_type: str = "gbdt",
        num_leaves: int = 31,
        max_depth: int = -1,
        learning_rate: float = 0.1,
        n_estimators: int = 100,
        subsample_for_bin: int = 200000,
        objective: Optional[Union[str, _LGBM_ScikitCustomObjectiveFunction]] = None,
        class_weight: Optional[Union[dict, str]] = None,
        min_split_gain: float = 0.0,
        min_child_weight: float = 1e-3,
        min_child_samples: int = 20,
        subsample: float = 1.0,
        subsample_freq: int = 0,
        colsample_bytree: float = 1.0,
        reg_alpha: float = 0.0,
        reg_lambda: float = 0.0,
        random_state: Optional[Union[int, np.random.RandomState, "np.random.Generator"]] = None,
        n_jobs: Optional[int] = None,
        importance_type: str = "split",
        client: Optional[Client] = None,
        **kwargs: Any,
    ):
        """Docstring is inherited from the lightgbm.LGBMRegressor.__init__."""
        self.client = client
        super().__init__(
            boosting_type=boosting_type,
            num_leaves=num_leaves,
            max_depth=max_depth,
            learning_rate=learning_rate,
            n_estimators=n_estimators,
            subsample_for_bin=subsample_for_bin,
            objective=objective,
            class_weight=class_weight,
            min_split_gain=min_split_gain,
            min_child_weight=min_child_weight,
            min_child_samples=min_child_samples,
            subsample=subsample,
            subsample_freq=subsample_freq,
            colsample_bytree=colsample_bytree,
            reg_alpha=reg_alpha,
            reg_lambda=reg_lambda,
            random_state=random_state,
            n_jobs=n_jobs,
            importance_type=importance_type,
            **kwargs,
        )

    _base_doc = LGBMRegressor.__init__.__doc__
    _before_kwargs, _kwargs, _after_kwargs = _base_doc.partition("**kwargs")  # type: ignore
    __init__.__doc__ = f"""
        {_before_kwargs}client : dask.distributed.Client or None, optional (default=None)
        {" ":4}Dask client. If ``None``, ``distributed.default_client()`` will be used at runtime. The Dask client used by this class will not be saved if the model object is pickled.
        {_kwargs}{_after_kwargs}
        """
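
    # Sketch of the ``client`` parameter (the scheduler address below is a
    # placeholder, not a real endpoint):
    #
    #   from distributed import Client
    #   client = Client("tcp://127.0.0.1:8786")   # connect to an existing cluster
    #   reg = DaskLGBMRegressor(client=client)    # pin this model to that client
    #   reg_default = DaskLGBMRegressor()         # uses default_client() at fit() time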

    def __getstate__(self) -> Dict[Any, Any]:
        return self._lgb_dask_getstate()
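
    # Because __getstate__ drops the Dask client (clients are not picklable),
    # a fitted model can be serialized; the default client is looked up again
    # when the unpickled model next needs one. A sketch, assuming a fitted ``reg``:
    #
    #   import pickle
    #   reg2 = pickle.loads(pickle.dumps(reg))   # round-trips without the client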

    def fit(  # type: ignore[override]
        self,
        X: _DaskMatrixLike,
        y: _DaskCollection,
        sample_weight: Optional[_DaskVectorLike] = None,
        init_score: Optional[_DaskVectorLike] = None,
        eval_set: Optional[List[Tuple[_DaskMatrixLike, _DaskCollection]]] = None,
        eval_names: Optional[List[str]] = None,
        eval_sample_weight: Optional[List[_DaskVectorLike]] = None,
        eval_init_score: Optional[List[_DaskVectorLike]] = None,
        eval_metric: Optional[_LGBM_ScikitEvalMetricType] = None,
        **kwargs: Any,
    ) -> "DaskLGBMRegressor":
        """Docstring is inherited from the lightgbm.LGBMRegressor.fit."""
        self._lgb_dask_fit(
            model_factory=LGBMRegressor,
            X=X,
            y=y,
            sample_weight=sample_weight,
            init_score=init_score,
            eval_set=eval_set,
            eval_names=eval_names,
            eval_sample_weight=eval_sample_weight,
            eval_init_score=eval_init_score,
            eval_metric=eval_metric,
            **kwargs,
        )
        return self

    _base_doc = _lgbmmodel_doc_fit.format(
        X_shape="Dask Array or Dask DataFrame of shape = [n_samples, n_features]",
        y_shape="Dask Array, Dask DataFrame or Dask Series of shape = [n_samples]",
        sample_weight_shape="Dask Array or Dask Series of shape = [n_samples] or None, optional (default=None)",
        init_score_shape="Dask Array or Dask Series of shape = [n_samples] or None, optional (default=None)",
        group_shape="Dask Array or Dask Series or None, optional (default=None)",
        eval_sample_weight_shape="list of Dask Array or Dask Series, or None, optional (default=None)",
        eval_init_score_shape="list of Dask Array or Dask Series, or None, optional (default=None)",
        eval_group_shape="list of Dask Array or Dask Series, or None, optional (default=None)",
    )

    # DaskLGBMRegressor does not support group, eval_class_weight, eval_group.
    _base_doc = _base_doc[: _base_doc.find("group :")] + _base_doc[_base_doc.find("eval_set :") :]

    _base_doc = _base_doc[: _base_doc.find("eval_class_weight :")] + _base_doc[_base_doc.find("eval_init_score :") :]

    _base_doc = _base_doc[: _base_doc.find("eval_group :")] + _base_doc[_base_doc.find("eval_metric :") :]

    # DaskLGBMRegressor support for callbacks and init_model is not tested
    fit.__doc__ = f"""{_base_doc[: _base_doc.find("callbacks :")]}**kwargs
        Other parameters passed through to ``LGBMRegressor.fit()``.

    Returns
    -------
    self : lightgbm.DaskLGBMRegressor
        Returns self.

    {_lgbmmodel_doc_custom_eval_note}
        """
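
    # End-to-end training sketch (illustrative; assumes a running
    # dask.distributed cluster already registered as the default client):
    #
    #   import dask.array as da
    #   from lightgbm import DaskLGBMRegressor
    #
    #   dX = da.random.random((1_000, 20), chunks=(250, 20))
    #   dy = da.random.random(1_000, chunks=250)
    #   reg = DaskLGBMRegressor(n_estimators=50).fit(dX, dy)
    #   preds = reg.predict(dX)   # lazy Dask Array of shape (1_000,)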

    def predict(
        self,
        X: _DaskMatrixLike,  # type: ignore[override]
        raw_score: bool = False,
        start_iteration: int = 0,
        num_iteration: Optional[int] = None,
        pred_leaf: bool = False,
        pred_contrib: bool = False,
        validate_features: bool = False,
        **kwargs: Any,
    ) -> dask_Array:
        """Docstring is inherited from the lightgbm.LGBMRegressor.predict."""
        return _predict(
            model=self.to_local(),
            data=X,
            client=_get_dask_client(self.client),
            raw_score=raw_score,
            start_iteration=start_iteration,
            num_iteration=num_iteration,
            pred_leaf=pred_leaf,
            pred_contrib=pred_contrib,
            validate_features=validate_features,
            **kwargs,
        )

    predict.__doc__ = _lgbmmodel_doc_predict.format(
        description="Return the predicted value for each sample.",
        X_shape="Dask Array or Dask DataFrame of shape = [n_samples, n_features]",
        output_name="predicted_result",
        predicted_result_shape="Dask Array of shape = [n_samples]",
        X_leaves_shape="Dask Array of shape = [n_samples, n_trees]",
        X_SHAP_values_shape="Dask Array of shape = [n_samples, n_features + 1]",
    )

    def to_local(self) -> LGBMRegressor:
        """Create regular version of lightgbm.LGBMRegressor from the distributed version.

        Returns
        -------
        model : lightgbm.LGBMRegressor
            Local underlying model.
        """
        return self._lgb_dask_to_local(LGBMRegressor)


class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel):
    """Distributed version of lightgbm.LGBMRanker."""

    def __init__(
        self,
        boosting_type: str = "gbdt",
        num_leaves: int = 31,
        max_depth: int = -1,
        learning_rate: float = 0.1,
        n_estimators: int = 100,
        subsample_for_bin: int = 200000,
        objective: Optional[Union[str, _LGBM_ScikitCustomObjectiveFunction]] = None,
        class_weight: Optional[Union[dict, str]] = None,
        min_split_gain: float = 0.0,
        min_child_weight: float = 1e-3,
        min_child_samples: int = 20,
        subsample: float = 1.0,
        subsample_freq: int = 0,
        colsample_bytree: float = 1.0,
        reg_alpha: float = 0.0,
        reg_lambda: float = 0.0,
        random_state: Optional[Union[int, np.random.RandomState, "np.random.Generator"]] = None,
        n_jobs: Optional[int] = None,
        importance_type: str = "split",
        client: Optional[Client] = None,
        **kwargs: Any,
    ):
        """Docstring is inherited from the lightgbm.LGBMRanker.__init__."""
        self.client = client
        super().__init__(
            boosting_type=boosting_type,
            num_leaves=num_leaves,
            max_depth=max_depth,
            learning_rate=learning_rate,
            n_estimators=n_estimators,
            subsample_for_bin=subsample_for_bin,
            objective=objective,
            class_weight=class_weight,
            min_split_gain=min_split_gain,
            min_child_weight=min_child_weight,
            min_child_samples=min_child_samples,
            subsample=subsample,
            subsample_freq=subsample_freq,
            colsample_bytree=colsample_bytree,
            reg_alpha=reg_alpha,
            reg_lambda=reg_lambda,
            random_state=random_state,
            n_jobs=n_jobs,
            importance_type=importance_type,
            **kwargs,
        )

    _base_doc = LGBMRanker.__init__.__doc__
    _before_kwargs, _kwargs, _after_kwargs = _base_doc.partition("**kwargs")  # type: ignore
    __init__.__doc__ = f"""
        {_before_kwargs}client : dask.distributed.Client or None, optional (default=None)
        {" ":4}Dask client. If ``None``, ``distributed.default_client()`` will be used at runtime. The Dask client used by this class will not be saved if the model object is pickled.
        {_kwargs}{_after_kwargs}
        """

    def __getstate__(self) -> Dict[Any, Any]:
        return self._lgb_dask_getstate()

    def fit(  # type: ignore[override]
        self,
        X: _DaskMatrixLike,
        y: _DaskCollection,
        sample_weight: Optional[_DaskVectorLike] = None,
        init_score: Optional[_DaskVectorLike] = None,
        group: Optional[_DaskVectorLike] = None,
        eval_set: Optional[List[Tuple[_DaskMatrixLike, _DaskCollection]]] = None,
        eval_names: Optional[List[str]] = None,
        eval_sample_weight: Optional[List[_DaskVectorLike]] = None,
        eval_init_score: Optional[List[_DaskVectorLike]] = None,
        eval_group: Optional[List[_DaskVectorLike]] = None,
        eval_metric: Optional[_LGBM_ScikitEvalMetricType] = None,
        eval_at: Union[List[int], Tuple[int, ...]] = (1, 2, 3, 4, 5),
        **kwargs: Any,
    ) -> "DaskLGBMRanker":
        """Docstring is inherited from the lightgbm.LGBMRanker.fit."""
        self._lgb_dask_fit(
            model_factory=LGBMRanker,
            X=X,
            y=y,
            sample_weight=sample_weight,
            init_score=init_score,
            group=group,
            eval_set=eval_set,
            eval_names=eval_names,
            eval_sample_weight=eval_sample_weight,
            eval_init_score=eval_init_score,
            eval_group=eval_group,
            eval_metric=eval_metric,
            eval_at=eval_at,
            **kwargs,
        )
        return self

    _base_doc = _lgbmmodel_doc_fit.format(
        X_shape="Dask Array or Dask DataFrame of shape = [n_samples, n_features]",
        y_shape="Dask Array, Dask DataFrame or Dask Series of shape = [n_samples]",
        sample_weight_shape="Dask Array or Dask Series of shape = [n_samples] or None, optional (default=None)",
        init_score_shape="Dask Array or Dask Series of shape = [n_samples] or None, optional (default=None)",
        group_shape="Dask Array or Dask Series or None, optional (default=None)",
        eval_sample_weight_shape="list of Dask Array or Dask Series, or None, optional (default=None)",
        eval_init_score_shape="list of Dask Array or Dask Series, or None, optional (default=None)",
        eval_group_shape="list of Dask Array or Dask Series, or None, optional (default=None)",
    )

    # DaskLGBMRanker does not support eval_class_weight or early stopping
    _base_doc = _base_doc[: _base_doc.find("eval_class_weight :")] + _base_doc[_base_doc.find("eval_init_score :") :]

    _base_doc = (
        _base_doc[: _base_doc.find("feature_name :")]
        + "eval_at : list or tuple of int, optional (default=(1, 2, 3, 4, 5))\n"
        + f"{' ':8}The evaluation positions of the specified metric.\n"
        + f"{' ':4}{_base_doc[_base_doc.find('feature_name :') :]}"
    )

    # DaskLGBMRanker support for callbacks and init_model is not tested
    fit.__doc__ = f"""{_base_doc[: _base_doc.find("callbacks :")]}**kwargs
        Other parameters passed through to ``LGBMRanker.fit()``.

    Returns
    -------
    self : lightgbm.DaskLGBMRanker
        Returns self.

    {_lgbmmodel_doc_custom_eval_note}
        """
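
    # Learning-to-rank sketch: ``group`` gives query-group sizes, and the data
    # should be arranged so that no query group spans two partitions. Chunk
    # sizes below are chosen so each 50-row partition holds exactly 5 groups
    # of 10 rows (illustrative only, assumes a running cluster):
    #
    #   import numpy as np
    #   import dask.array as da
    #   from lightgbm import DaskLGBMRanker
    #
    #   dX = da.random.random((100, 10), chunks=(50, 10))
    #   dy = da.random.randint(0, 4, size=(100,), chunks=(50,))
    #   dg = da.from_array(np.full(10, 10), chunks=5)   # 10 groups of size 10
    #   rnk = DaskLGBMRanker(n_estimators=10).fit(dX, dy, group=dg)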

    def predict(
        self,
        X: _DaskMatrixLike,  # type: ignore[override]
        raw_score: bool = False,
        start_iteration: int = 0,
        num_iteration: Optional[int] = None,
        pred_leaf: bool = False,
        pred_contrib: bool = False,
        validate_features: bool = False,
        **kwargs: Any,
    ) -> dask_Array:
        """Docstring is inherited from the lightgbm.LGBMRanker.predict."""
        return _predict(
            model=self.to_local(),
            data=X,
            client=_get_dask_client(self.client),
            raw_score=raw_score,
            start_iteration=start_iteration,
            num_iteration=num_iteration,
            pred_leaf=pred_leaf,
            pred_contrib=pred_contrib,
            validate_features=validate_features,
            **kwargs,
        )

    predict.__doc__ = _lgbmmodel_doc_predict.format(
        description="Return the predicted value for each sample.",
        X_shape="Dask Array or Dask DataFrame of shape = [n_samples, n_features]",
        output_name="predicted_result",
        predicted_result_shape="Dask Array of shape = [n_samples]",
        X_leaves_shape="Dask Array of shape = [n_samples, n_trees]",
        X_SHAP_values_shape="Dask Array of shape = [n_samples, n_features + 1]",
    )

    def to_local(self) -> LGBMRanker:
        """Create regular version of lightgbm.LGBMRanker from the distributed version.

        Returns
        -------
        model : lightgbm.LGBMRanker
            Local underlying model.
        """
        return self._lgb_dask_to_local(LGBMRanker)