"vscode:/vscode.git/clone" did not exist on "5d79ff20d1b7ae226531e2445b17d747b253a637"
dask.py 65.4 KB
Newer Older
1
# coding: utf-8
2
"""Distributed training with LightGBM and dask.distributed.
3

4
This module enables you to perform distributed training with LightGBM on
5
dask.Array and dask.DataFrame collections.
6
7

It is based on dask-lightgbm, which was based on dask-xgboost.
8
"""

import operator
import socket
from collections import defaultdict
from copy import deepcopy
from enum import Enum, auto
from functools import partial
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, Union
from urllib.parse import urlparse

import numpy as np
import scipy.sparse as ss

from .basic import LightGBMError, _choose_param_value, _ConfigAliases, _log_info, _log_warning
from .compat import (
    DASK_INSTALLED,
    PANDAS_INSTALLED,
    SKLEARN_INSTALLED,
    Client,
    Future,
    LGBMNotFittedError,
    concat,
    dask_Array,
    dask_array_from_delayed,
    dask_bag_from_delayed,
    dask_DataFrame,
    dask_Series,
    default_client,
    delayed,
    pd_DataFrame,
    pd_Series,
    wait,
)
from .sklearn import (
    LGBMClassifier,
    LGBMModel,
    LGBMRanker,
    LGBMRegressor,
    _LGBM_ScikitCustomObjectiveFunction,
    _LGBM_ScikitEvalMetricType,
    _lgbmmodel_doc_custom_eval_note,
    _lgbmmodel_doc_fit,
    _lgbmmodel_doc_predict,
)

__all__ = [
    "DaskLGBMClassifier",
    "DaskLGBMRanker",
    "DaskLGBMRegressor",
]

_DaskCollection = Union[dask_Array, dask_DataFrame, dask_Series]
_DaskMatrixLike = Union[dask_Array, dask_DataFrame]
_DaskVectorLike = Union[dask_Array, dask_Series]
_DaskPart = Union[np.ndarray, pd_DataFrame, pd_Series, ss.spmatrix]
_PredictionDtype = Union[Type[np.float32], Type[np.float64], Type[np.int32], Type[np.int64]]


class _RemoteSocket:
    def acquire(self) -> int:
        self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        self.socket.bind(("", 0))
        return self.socket.getsockname()[1]

    def release(self) -> None:
        self.socket.close()


def _acquire_port() -> Tuple[_RemoteSocket, int]:
    s = _RemoteSocket()
    port = s.acquire()
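    # note: the socket is returned still bound, so the chosen port stays reserved
    # for LightGBM; _train_part() releases it just before training begins.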
    return s, port


class _DatasetNames(Enum):
    """Placeholder names used by lightgbm.dask internals to say 'also evaluate the training data'.

    These placeholders avoid duplicating the training data when a validation set refers to elements of the training data.
    """

    TRAINSET = auto()
    SAMPLE_WEIGHT = auto()
    INIT_SCORE = auto()
    GROUP = auto()


def _get_dask_client(client: Optional[Client]) -> Client:
    """Choose a Dask client to use.

    Parameters
    ----------
    client : dask.distributed.Client or None
        Dask client.

    Returns
    -------
    client : dask.distributed.Client
        A Dask client.
    """
    if client is None:
        return default_client()
    else:
        return client


def _assign_open_ports_to_workers(
    client: Client,
    workers: List[str],
) -> Tuple[Dict[str, Future], Dict[str, int]]:
    """Assign an open port to each worker.

    Returns
    -------
    worker_to_socket_future: dict
        mapping from worker address to a future pointing to the remote socket.
    worker_to_port: dict
        mapping from worker address to an open port in the worker's host.
    """
    # Acquire port in worker
    worker_to_future = {}
    for worker in workers:
        worker_to_future[worker] = client.submit(
            _acquire_port,
            workers=[worker],
            allow_other_workers=False,
            pure=False,
        )

    # schedule futures to retrieve each element of the tuple
    worker_to_socket_future = {}
    worker_to_port_future = {}
    for worker, socket_future in worker_to_future.items():
        worker_to_socket_future[worker] = client.submit(operator.itemgetter(0), socket_future)
        worker_to_port_future[worker] = client.submit(operator.itemgetter(1), socket_future)

    # retrieve ports
    worker_to_port = client.gather(worker_to_port_future)

    return worker_to_socket_future, worker_to_port


def _concat(seq: List[_DaskPart]) -> _DaskPart:
    if isinstance(seq[0], np.ndarray):
        return np.concatenate(seq, axis=0)
    elif isinstance(seq[0], (pd_DataFrame, pd_Series)):
        return concat(seq, axis=0)
    elif isinstance(seq[0], ss.spmatrix):
        return ss.vstack(seq, format="csr")
    else:
        raise TypeError(
            f"Data must be one of: numpy arrays, pandas dataframes, sparse matrices (from scipy). Got {type(seq[0]).__name__}."
        )


def _remove_list_padding(*args: Any) -> List[List[Any]]:
    return [[z for z in arg if z is not None] for arg in args]


def _pad_eval_names(lgbm_model: LGBMModel, required_names: List[str]) -> LGBMModel:
    """Append missing (key, value) pairs to a LightGBM model's evals_result_ and best_score_ OrderedDict attrs based on a set of required eval_set names.

    Allows users to rely on expected eval_set names being present when fitting DaskLGBM estimators with ``eval_set``.
    """
    for eval_name in required_names:
        if eval_name not in lgbm_model.evals_result_:
            lgbm_model.evals_result_[eval_name] = {}
        if eval_name not in lgbm_model.best_score_:
            lgbm_model.best_score_[eval_name] = {}

    return lgbm_model


def _train_part(
    params: Dict[str, Any],
    model_factory: Type[LGBMModel],
    list_of_parts: List[Dict[str, _DaskPart]],
    machines: str,
    local_listen_port: int,
    num_machines: int,
    return_model: bool,
    time_out: int,
    remote_socket: _RemoteSocket,
    **kwargs: Any,
) -> Optional[LGBMModel]:
    network_params = {
        "machines": machines,
        "local_listen_port": local_listen_port,
        "time_out": time_out,
        "num_machines": num_machines,
    }
    params.update(network_params)

    is_ranker = issubclass(model_factory, LGBMRanker)

    # Concatenate many parts into one
    data = _concat([x["data"] for x in list_of_parts])
    label = _concat([x["label"] for x in list_of_parts])

    if "weight" in list_of_parts[0]:
        weight = _concat([x["weight"] for x in list_of_parts])
    else:
        weight = None

    if "group" in list_of_parts[0]:
        group = _concat([x["group"] for x in list_of_parts])
    else:
        group = None

    if "init_score" in list_of_parts[0]:
        init_score = _concat([x["init_score"] for x in list_of_parts])
    else:
        init_score = None

    # construct local eval_set data.
    n_evals = max(len(x.get("eval_set", [])) for x in list_of_parts)
    eval_names = kwargs.pop("eval_names", None)
    eval_class_weight = kwargs.get("eval_class_weight")
    local_eval_set = None
    local_eval_names = None
    local_eval_sample_weight = None
    local_eval_init_score = None
    local_eval_group = None

    if n_evals:
        has_eval_sample_weight = any(x.get("eval_sample_weight") is not None for x in list_of_parts)
        has_eval_init_score = any(x.get("eval_init_score") is not None for x in list_of_parts)

        local_eval_set = []
        evals_result_names = []
        if has_eval_sample_weight:
            local_eval_sample_weight = []
        if has_eval_init_score:
            local_eval_init_score = []
        if is_ranker:
            local_eval_group = []

        # store indices of eval_set components that were not contained within local parts.
        missing_eval_component_idx = []

        # consolidate parts of each individual eval component.
        for i in range(n_evals):
            x_e = []
            y_e = []
            w_e = []
            init_score_e = []
            g_e = []
            for part in list_of_parts:
                if not part.get("eval_set"):
                    continue

                # require that eval_name exists in the evaluated result data in case it was dropped due to padding.
                # in distributed training the 'training' eval_set is not detected, so it will have the name 'valid_<index>'.
                if eval_names:
                    evals_result_name = eval_names[i]
                else:
                    evals_result_name = f"valid_{i}"

                eval_set = part["eval_set"][i]
                if eval_set is _DatasetNames.TRAINSET:
                    x_e.append(part["data"])
                    y_e.append(part["label"])
                else:
                    x_e.extend(eval_set[0])
                    y_e.extend(eval_set[1])

                if evals_result_name not in evals_result_names:
                    evals_result_names.append(evals_result_name)

                eval_weight = part.get("eval_sample_weight")
                if eval_weight:
                    if eval_weight[i] is _DatasetNames.SAMPLE_WEIGHT:
                        w_e.append(part["weight"])
                    else:
                        w_e.extend(eval_weight[i])

                eval_init_score = part.get("eval_init_score")
                if eval_init_score:
                    if eval_init_score[i] is _DatasetNames.INIT_SCORE:
                        init_score_e.append(part["init_score"])
                    else:
                        init_score_e.extend(eval_init_score[i])

                eval_group = part.get("eval_group")
                if eval_group:
                    if eval_group[i] is _DatasetNames.GROUP:
                        g_e.append(part["group"])
                    else:
                        g_e.extend(eval_group[i])

            # filter padding from eval parts then _concat each eval_set component.
            x_e, y_e, w_e, init_score_e, g_e = _remove_list_padding(x_e, y_e, w_e, init_score_e, g_e)
            if x_e:
                local_eval_set.append((_concat(x_e), _concat(y_e)))
            else:
                missing_eval_component_idx.append(i)
                continue

            if w_e:
                local_eval_sample_weight.append(_concat(w_e))
            if init_score_e:
                local_eval_init_score.append(_concat(init_score_e))
            if g_e:
                local_eval_group.append(_concat(g_e))

        # reconstruct eval_set fit args/kwargs depending on which components of eval_set are on worker.
        eval_component_idx = [i for i in range(n_evals) if i not in missing_eval_component_idx]
        if eval_names:
            local_eval_names = [eval_names[i] for i in eval_component_idx]
        if eval_class_weight:
            kwargs["eval_class_weight"] = [eval_class_weight[i] for i in eval_component_idx]

    model = model_factory(**params)
    if remote_socket is not None:
        remote_socket.release()
    try:
        if is_ranker:
            model.fit(
                data,
                label,
                sample_weight=weight,
                init_score=init_score,
                group=group,
                eval_set=local_eval_set,
                eval_sample_weight=local_eval_sample_weight,
                eval_init_score=local_eval_init_score,
                eval_group=local_eval_group,
                eval_names=local_eval_names,
                **kwargs,
            )
        else:
            model.fit(
                data,
                label,
                sample_weight=weight,
                init_score=init_score,
                eval_set=local_eval_set,
                eval_sample_weight=local_eval_sample_weight,
                eval_init_score=local_eval_init_score,
                eval_names=local_eval_names,
                **kwargs,
            )

    finally:
        if getattr(model, "fitted_", False):
            model.booster_.free_network()

    if n_evals:
        # ensure that expected keys for evals_result_ and best_score_ exist regardless of padding.
        model = _pad_eval_names(model, required_names=evals_result_names)

    return model if return_model else None


def _split_to_parts(data: _DaskCollection, is_matrix: bool) -> List[_DaskPart]:
    parts = data.to_delayed()
    if isinstance(parts, np.ndarray):
        if is_matrix:
            assert parts.shape[1] == 1
        else:
            assert parts.ndim == 1 or parts.shape[1] == 1
        parts = parts.flatten().tolist()
    return parts


def _machines_to_worker_map(machines: str, worker_addresses: Iterable[str]) -> Dict[str, int]:
    """Create a worker_map from machines list.

    Given ``machines`` and a list of Dask worker addresses, return a mapping where the keys are
    ``worker_addresses`` and the values are ports from ``machines``.

    Parameters
    ----------
    machines : str
        A comma-delimited list of workers, of the form ``ip1:port,ip2:port``.
    worker_addresses : list of str
        An iterable of Dask worker addresses, of the form ``{protocol}{hostname}:{port}``, where ``port`` is the port Dask's scheduler uses to talk to that worker.

    Returns
    -------
    result : Dict[str, int]
        Dictionary where keys are worker addresses in the form expected by Dask and values are a port for LightGBM to use.
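
    For example (an illustration with made-up addresses),
    ``machines="10.0.0.1:12401,10.0.0.2:12402"`` with
    ``worker_addresses=["tcp://10.0.0.1:40000", "tcp://10.0.0.2:40000"]``
    yields ``{"tcp://10.0.0.1:40000": 12401, "tcp://10.0.0.2:40000": 12402}``.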
    """
    machine_addresses = machines.split(",")

    if len(set(machine_addresses)) != len(machine_addresses):
        raise ValueError(
            f"Found duplicates in 'machines' ({machines}). Each entry in 'machines' must be a unique IP-port combination."
        )

    machine_to_port = defaultdict(set)
    for address in machine_addresses:
        host, port = address.split(":")
        machine_to_port[host].add(int(port))

    out = {}
    for address in worker_addresses:
        worker_host = urlparse(address).hostname
        if not worker_host:
            raise ValueError(f"Could not parse host name from worker address '{address}'")
        out[address] = machine_to_port[worker_host].pop()

    return out


def _train(
    client: Client,
    data: _DaskMatrixLike,
    label: _DaskCollection,
    params: Dict[str, Any],
    model_factory: Type[LGBMModel],
    sample_weight: Optional[_DaskVectorLike] = None,
    init_score: Optional[_DaskCollection] = None,
    group: Optional[_DaskVectorLike] = None,
    eval_set: Optional[List[Tuple[_DaskMatrixLike, _DaskCollection]]] = None,
    eval_names: Optional[List[str]] = None,
    eval_sample_weight: Optional[List[_DaskVectorLike]] = None,
    eval_class_weight: Optional[List[Union[dict, str]]] = None,
    eval_init_score: Optional[List[_DaskCollection]] = None,
    eval_group: Optional[List[_DaskVectorLike]] = None,
    eval_metric: Optional[_LGBM_ScikitEvalMetricType] = None,
    eval_at: Optional[Union[List[int], Tuple[int, ...]]] = None,
    **kwargs: Any,
) -> LGBMModel:
    """Inner train routine.

    Parameters
    ----------
    client : dask.distributed.Client
        Dask client.
    data : Dask Array or Dask DataFrame of shape = [n_samples, n_features]
        Input feature matrix.
    label : Dask Array, Dask DataFrame or Dask Series of shape = [n_samples]
        The target values (class labels in classification, real numbers in regression).
    params : dict
        Parameters passed to constructor of the local underlying model.
    model_factory : lightgbm.LGBMClassifier, lightgbm.LGBMRegressor, or lightgbm.LGBMRanker class
        Class of the local underlying model.
    sample_weight : Dask Array or Dask Series of shape = [n_samples] or None, optional (default=None)
        Weights of training data. Weights should be non-negative.
    init_score : Dask Array or Dask Series of shape = [n_samples] or shape = [n_samples * n_classes] (for multi-class task), or Dask Array or Dask DataFrame of shape = [n_samples, n_classes] (for multi-class task), or None, optional (default=None)
        Init score of training data.
    group : Dask Array or Dask Series or None, optional (default=None)
        Group/query data.
        Only used in the learning-to-rank task.
        sum(group) = n_samples.
        For example, if you have a 100-document dataset with ``group = [10, 20, 40, 10, 10, 10]``, that means that you have 6 groups,
        where the first 10 records are in the first group, records 11-30 are in the second group, records 31-70 are in the third group, etc.
    eval_set : list of (X, y) tuples of Dask data collections, or None, optional (default=None)
        List of (X, y) tuple pairs to use as validation sets.
        Note that not all workers may receive chunks of every eval set within ``eval_set``. When the returned
        lightgbm estimator is not trained using any chunks of a particular eval set, its corresponding component
        of ``evals_result_`` and ``best_score_`` will be empty dictionaries.
    eval_names : list of str, or None, optional (default=None)
        Names of eval_set.
    eval_sample_weight : list of Dask Array or Dask Series, or None, optional (default=None)
        Weights for each validation set in eval_set. Weights should be non-negative.
    eval_class_weight : list of dict or str, or None, optional (default=None)
        Class weights, one dict or str for each validation set in eval_set.
    eval_init_score : list of Dask Array, Dask Series or Dask DataFrame (for multi-class task), or None, optional (default=None)
        Initial model score for each validation set in eval_set.
    eval_group : list of Dask Array or Dask Series, or None, optional (default=None)
        Group/query for each validation set in eval_set.
    eval_metric : str, callable, list or None, optional (default=None)
        If str, it should be a built-in evaluation metric to use.
        If callable, it should be a custom evaluation metric, see note below for more details.
        If list, it can be a list of built-in metrics, a list of custom evaluation metrics, or a mix of both.
        In either case, the ``metric`` from the Dask model parameters (or inferred from the objective) will be evaluated and used as well.
        Default: 'l2' for DaskLGBMRegressor, 'binary(multi)_logloss' for DaskLGBMClassifier, 'ndcg' for DaskLGBMRanker.
    eval_at : list or tuple of int, optional (default=None)
        The evaluation positions of the specified ranking metric.
    **kwargs
        Other parameters passed to ``fit`` method of the local underlying model.

    Returns
    -------
    model : lightgbm.LGBMClassifier, lightgbm.LGBMRegressor, or lightgbm.LGBMRanker class
        Returns fitted underlying model.

    Note
    ----

    This method handles setting up the following network parameters based on information
    about the Dask cluster referenced by ``client``.

    * ``local_listen_port``: port that each LightGBM worker opens a listening socket on,
            to accept connections from other workers. This can differ from LightGBM worker
            to LightGBM worker, but does not have to.
    * ``machines``: a comma-delimited list of all workers in the cluster, in the
            form ``ip:port,ip:port``. If running multiple Dask workers on the same host, use different
            ports for each worker. For example, for ``LocalCluster(n_workers=3)``, you might
            pass ``"127.0.0.1:12400,127.0.0.1:12401,127.0.0.1:12402"``.
    * ``num_machines``: number of LightGBM workers.
    * ``timeout``: time in minutes to wait before closing unused sockets.

    The default behavior of this function is to generate ``machines`` from the list of
    Dask workers which hold some piece of the training data, and to search for an open
    port on each worker to be used as ``local_listen_port``.

    If ``machines`` is provided explicitly in ``params``, this function uses the hosts
    and ports in that list directly, and does not do any searching. This means that if
    any of the Dask workers are missing from the list or any of those ports are not free
    when training starts, training will fail.

    If ``local_listen_port`` is provided in ``params`` and ``machines`` is not, this function
    constructs ``machines`` from the list of Dask workers which hold some piece of the
    training data, assuming that each one will use the same ``local_listen_port``.
    """
    params = deepcopy(params)

    # capture whether local_listen_port or its aliases were provided
    listen_port_in_params = any(alias in params for alias in _ConfigAliases.get("local_listen_port"))

    # capture whether machines or its aliases were provided
    machines_in_params = any(alias in params for alias in _ConfigAliases.get("machines"))

    params = _choose_param_value(
        main_param_name="tree_learner",
        params=params,
        default_value="data",
    )
    allowed_tree_learners = {
        "data",
        "data_parallel",
        "feature",
        "feature_parallel",
        "voting",
        "voting_parallel",
    }
    if params["tree_learner"] not in allowed_tree_learners:
        _log_warning(
            f'Parameter tree_learner set to {params["tree_learner"]}, which is not allowed. Using "data" as default'
        )
        params["tree_learner"] = "data"

    # Some passed-in parameters can be removed:
    #   * 'num_machines': set automatically from Dask worker list
    #   * 'num_threads': overridden to match nthreads on each Dask process
    for param_alias in _ConfigAliases.get("num_machines", "num_threads"):
        if param_alias in params:
            _log_warning(f"Parameter {param_alias} will be ignored.")
            params.pop(param_alias)

    # Split arrays/dataframes into parts. Arrange parts into dicts to enforce co-locality
    data_parts = _split_to_parts(data=data, is_matrix=True)
    label_parts = _split_to_parts(data=label, is_matrix=False)
    parts = [{"data": x, "label": y} for (x, y) in zip(data_parts, label_parts)]
    n_parts = len(parts)

    if sample_weight is not None:
        weight_parts = _split_to_parts(data=sample_weight, is_matrix=False)
        for i in range(n_parts):
            parts[i]["weight"] = weight_parts[i]

    if group is not None:
        group_parts = _split_to_parts(data=group, is_matrix=False)
        for i in range(n_parts):
            parts[i]["group"] = group_parts[i]

    if init_score is not None:
        init_score_parts = _split_to_parts(data=init_score, is_matrix=False)
        for i in range(n_parts):
            parts[i]["init_score"] = init_score_parts[i]

    # eval_set will be re-constructed into smaller lists of (X, y) tuples, where
    # X and y are each delayed sub-lists of original eval dask Collections.
    if eval_set:
        # find maximum number of parts in an individual eval set so that we can
        # pad eval sets when they come in different sizes.
        n_largest_eval_parts = max(x[0].npartitions for x in eval_set)

        eval_sets: Dict[
            int, List[Union[_DatasetNames, Tuple[List[Optional[_DaskMatrixLike]], List[Optional[_DaskVectorLike]]]]]
        ] = defaultdict(list)
        if eval_sample_weight:
            eval_sample_weights: Dict[int, List[Union[_DatasetNames, List[Optional[_DaskVectorLike]]]]] = defaultdict(
                list
            )
        if eval_group:
            eval_groups: Dict[int, List[Union[_DatasetNames, List[Optional[_DaskVectorLike]]]]] = defaultdict(list)
        if eval_init_score:
            eval_init_scores: Dict[int, List[Union[_DatasetNames, List[Optional[_DaskMatrixLike]]]]] = defaultdict(list)

        for i, (X_eval, y_eval) in enumerate(eval_set):
            n_this_eval_parts = X_eval.npartitions

            # when individual eval set is equivalent to training data, skip recomputing parts.
            if X_eval is data and y_eval is label:
                for parts_idx in range(n_parts):
                    eval_sets[parts_idx].append(_DatasetNames.TRAINSET)
            else:
                eval_x_parts = _split_to_parts(data=X_eval, is_matrix=True)
                eval_y_parts = _split_to_parts(data=y_eval, is_matrix=False)
                for j in range(n_largest_eval_parts):
                    parts_idx = j % n_parts

                    # add None-padding for individual eval_set member if it is smaller than the largest member.
                    if j < n_this_eval_parts:
                        x_e = eval_x_parts[j]
                        y_e = eval_y_parts[j]
                    else:
                        x_e = None
                        y_e = None

                    if j < n_parts:
                        # first time a chunk of this eval set is added to this part.
                        eval_sets[parts_idx].append(([x_e], [y_e]))
                    else:
                        # append additional chunks of this eval set to this part.
                        eval_sets[parts_idx][-1][0].append(x_e)  # type: ignore[index, union-attr]
                        eval_sets[parts_idx][-1][1].append(y_e)  # type: ignore[index, union-attr]

            if eval_sample_weight:
                if eval_sample_weight[i] is sample_weight:
                    for parts_idx in range(n_parts):
                        eval_sample_weights[parts_idx].append(_DatasetNames.SAMPLE_WEIGHT)
                else:
                    eval_w_parts = _split_to_parts(data=eval_sample_weight[i], is_matrix=False)

                    # ensure that all evaluation parts map uniquely to one part.
                    for j in range(n_largest_eval_parts):
                        if j < n_this_eval_parts:
                            w_e = eval_w_parts[j]
                        else:
                            w_e = None

                        parts_idx = j % n_parts
                        if j < n_parts:
                            eval_sample_weights[parts_idx].append([w_e])
                        else:
                            eval_sample_weights[parts_idx][-1].append(w_e)  # type: ignore[union-attr]

            if eval_init_score:
                if eval_init_score[i] is init_score:
                    for parts_idx in range(n_parts):
                        eval_init_scores[parts_idx].append(_DatasetNames.INIT_SCORE)
                else:
                    eval_init_score_parts = _split_to_parts(data=eval_init_score[i], is_matrix=False)
                    for j in range(n_largest_eval_parts):
                        if j < n_this_eval_parts:
                            init_score_e = eval_init_score_parts[j]
                        else:
                            init_score_e = None

                        parts_idx = j % n_parts
                        if j < n_parts:
                            eval_init_scores[parts_idx].append([init_score_e])
                        else:
                            eval_init_scores[parts_idx][-1].append(init_score_e)  # type: ignore[union-attr]

            if eval_group:
                if eval_group[i] is group:
                    for parts_idx in range(n_parts):
                        eval_groups[parts_idx].append(_DatasetNames.GROUP)
                else:
                    eval_g_parts = _split_to_parts(data=eval_group[i], is_matrix=False)
                    for j in range(n_largest_eval_parts):
                        if j < n_this_eval_parts:
                            g_e = eval_g_parts[j]
                        else:
                            g_e = None

                        parts_idx = j % n_parts
                        if j < n_parts:
                            eval_groups[parts_idx].append([g_e])
                        else:
                            eval_groups[parts_idx][-1].append(g_e)  # type: ignore[union-attr]

        # assign sub-eval_set components to worker parts.
        for parts_idx, e_set in eval_sets.items():
            parts[parts_idx]["eval_set"] = e_set
            if eval_sample_weight:
                parts[parts_idx]["eval_sample_weight"] = eval_sample_weights[parts_idx]
            if eval_init_score:
                parts[parts_idx]["eval_init_score"] = eval_init_scores[parts_idx]
            if eval_group:
                parts[parts_idx]["eval_group"] = eval_groups[parts_idx]

    # Start computation in the background
    parts = list(map(delayed, parts))
    parts = client.compute(parts)
    wait(parts)

    for part in parts:
        if part.status == "error":  # type: ignore
            # trigger error locally
            return part  # type: ignore[return-value]

    # Find locations of all parts and map them to particular Dask workers
    key_to_part_dict = {part.key: part for part in parts}  # type: ignore
    who_has = client.who_has(parts)
    worker_map = defaultdict(list)
    for key, workers in who_has.items():
        worker_map[next(iter(workers))].append(key_to_part_dict[key])

    # Check that all workers were provided some of eval_set. Otherwise warn user that validation
    # data artifacts may not be populated, depending on which worker returns the final estimator.
    if eval_set:
        for worker in worker_map:
            has_eval_set = False
            for part in worker_map[worker]:
                if "eval_set" in part.result():  # type: ignore[attr-defined]
                    has_eval_set = True
                    break

            if not has_eval_set:
                _log_warning(
                    f"Worker {worker} was not allocated eval_set data. Therefore evals_result_ and best_score_ data may be unreliable. "
                    "Try rebalancing data across workers."
                )

    # assign general validation set settings to fit kwargs.
    if eval_names:
        kwargs["eval_names"] = eval_names
    if eval_class_weight:
        kwargs["eval_class_weight"] = eval_class_weight
    if eval_metric:
        kwargs["eval_metric"] = eval_metric
    if eval_at:
        kwargs["eval_at"] = eval_at

    master_worker = next(iter(worker_map))
    worker_ncores = client.ncores()

    # resolve aliases for network parameters and pop the result off params.
    # these values are added back in calls to `_train_part()`
    params = _choose_param_value(
        main_param_name="local_listen_port",
        params=params,
        default_value=12400,
    )
    local_listen_port = params.pop("local_listen_port")

    params = _choose_param_value(
        main_param_name="machines",
        params=params,
        default_value=None,
    )
    machines = params.pop("machines")

    # figure out network params
    worker_to_socket_future: Dict[str, Future] = {}
    worker_addresses = worker_map.keys()
    if machines is not None:
        _log_info("Using passed-in 'machines' parameter")
        worker_address_to_port = _machines_to_worker_map(
            machines=machines,
            worker_addresses=worker_addresses,
        )
    else:
        if listen_port_in_params:
            _log_info("Using passed-in 'local_listen_port' for all workers")
            unique_hosts = {urlparse(a).hostname for a in worker_addresses}
            if len(unique_hosts) < len(worker_addresses):
                msg = (
                    "'local_listen_port' was provided in Dask training parameters, but at least one "
                    "machine in the cluster has multiple Dask worker processes running on it. Please omit "
                    "'local_listen_port' or pass 'machines'."
                )
                raise LightGBMError(msg)

            worker_address_to_port = {address: local_listen_port for address in worker_addresses}
        else:
            _log_info("Finding random open ports for workers")
            worker_to_socket_future, worker_address_to_port = _assign_open_ports_to_workers(
                client, list(worker_map.keys())
            )

        machines = ",".join(
            [f"{urlparse(worker_address).hostname}:{port}" for worker_address, port in worker_address_to_port.items()]
        )

    num_machines = len(worker_address_to_port)

    # Tell each worker to train on the parts that it has locally
    #
    # This code treats ``_train_part()`` calls as not "pure" because:
    #     1. there is randomness in the training process unless parameters ``seed``
    #        and ``deterministic`` are set
    #     2. even with those parameters set, the output of one ``_train_part()`` call
    #        relies on global state (it and all the other LightGBM training processes
    #        coordinate with each other)
    futures_classifiers = [
        client.submit(
            _train_part,
            model_factory=model_factory,
            params={**params, "num_threads": worker_ncores[worker]},
            list_of_parts=list_of_parts,
            machines=machines,
            local_listen_port=worker_address_to_port[worker],
            num_machines=num_machines,
            time_out=params.get("time_out", 120),
            remote_socket=worker_to_socket_future.get(worker, None),
            return_model=(worker == master_worker),
            workers=[worker],
            allow_other_workers=False,
            pure=False,
            **kwargs,
        )
        for worker, list_of_parts in worker_map.items()
    ]

    results = client.gather(futures_classifiers)
    results = [v for v in results if v]
    model = results[0]

    # if network parameters were changed during training, remove them from the
    # returned model so that they're generated dynamically on every run based
    # on the Dask cluster you're connected to and which workers have pieces of
    # the training data
    if not listen_port_in_params:
        for param in _ConfigAliases.get("local_listen_port"):
            model._other_params.pop(param, None)

    if not machines_in_params:
        for param in _ConfigAliases.get("machines"):
            model._other_params.pop(param, None)

    for param in _ConfigAliases.get("num_machines", "timeout"):
        model._other_params.pop(param, None)

    return model


def _predict_part(
    part: _DaskPart,
    model: LGBMModel,
    raw_score: bool,
    pred_proba: bool,
    pred_leaf: bool,
    pred_contrib: bool,
    **kwargs: Any,
) -> _DaskPart:
    result: _DaskPart
    if part.shape[0] == 0:
        result = np.array([])
    elif pred_proba:
        result = model.predict_proba(
            part,
            raw_score=raw_score,
            pred_leaf=pred_leaf,
            pred_contrib=pred_contrib,
            **kwargs,
        )
    else:
        result = model.predict(
            part,
            raw_score=raw_score,
            pred_leaf=pred_leaf,
            pred_contrib=pred_contrib,
            **kwargs,
        )

    # dask.DataFrame.map_partitions() expects each call to return a pandas DataFrame or Series
    if isinstance(part, pd_DataFrame):
        if len(result.shape) == 2:
            result = pd_DataFrame(result, index=part.index)
        else:
            result = pd_Series(result, index=part.index, name="predictions")

    return result


def _predict(
    model: LGBMModel,
    data: _DaskMatrixLike,
    client: Client,
    raw_score: bool = False,
    pred_proba: bool = False,
    pred_leaf: bool = False,
    pred_contrib: bool = False,
    dtype: _PredictionDtype = np.float32,
    **kwargs: Any,
) -> Union[dask_Array, List[dask_Array]]:
    """Inner predict routine.

    Parameters
    ----------
    model : lightgbm.LGBMClassifier, lightgbm.LGBMRegressor, or lightgbm.LGBMRanker class
        Fitted underlying model.
    data : Dask Array or Dask DataFrame of shape = [n_samples, n_features]
        Input feature matrix.
    raw_score : bool, optional (default=False)
        Whether to predict raw scores.
    pred_proba : bool, optional (default=False)
        Should method return results of ``predict_proba`` (``pred_proba=True``) or ``predict`` (``pred_proba=False``).
    pred_leaf : bool, optional (default=False)
        Whether to predict leaf index.
    pred_contrib : bool, optional (default=False)
        Whether to predict feature contributions.
    dtype : np.dtype, optional (default=np.float32)
        Dtype of the output.
    **kwargs
        Other parameters passed to ``predict`` or ``predict_proba`` method.

    Returns
    -------
    predicted_result : Dask Array of shape = [n_samples] or shape = [n_samples, n_classes]
        The predicted values.
    X_leaves : Dask Array of shape = [n_samples, n_trees] or shape = [n_samples, n_trees * n_classes]
        If ``pred_leaf=True``, the predicted leaf of every tree for each sample.
    X_SHAP_values : Dask Array of shape = [n_samples, n_features + 1] or shape = [n_samples, (n_features + 1) * n_classes] or (if multi-class and using sparse inputs) a list of ``n_classes`` Dask Arrays of shape = [n_samples, n_features + 1]
        If ``pred_contrib=True``, the feature contributions for each sample.
    """
    if not all((DASK_INSTALLED, PANDAS_INSTALLED, SKLEARN_INSTALLED)):
        raise LightGBMError("dask, pandas and scikit-learn are required for lightgbm.dask")
    if isinstance(data, dask_DataFrame):
        return data.map_partitions(
            _predict_part,
            model=model,
            raw_score=raw_score,
            pred_proba=pred_proba,
            pred_leaf=pred_leaf,
            pred_contrib=pred_contrib,
            **kwargs,
        ).values
    elif isinstance(data, dask_Array):
        # for multi-class classification with sparse matrices, pred_contrib predictions
        # are returned as a list of sparse matrices (one per class)
        num_classes = model._n_classes

        if num_classes > 2 and pred_contrib and isinstance(data._meta, ss.spmatrix):
            predict_function = partial(
                _predict_part,
                model=model,
                raw_score=False,
                pred_proba=pred_proba,
                pred_leaf=False,
                pred_contrib=True,
                **kwargs,
            )

            delayed_chunks = data.to_delayed()
            bag = dask_bag_from_delayed(delayed_chunks[:, 0])

            @delayed
            def _extract(items: List[Any], i: int) -> Any:
                return items[i]

            preds = bag.map_partitions(predict_function)

            # pred_contrib output will have one column per feature,
            # plus one more for the base value
            num_cols = model.n_features_ + 1

            nrows_per_chunk = data.chunks[0]
            out: List[List[dask_Array]] = [[] for _ in range(num_classes)]
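            # out[i][j] will hold the delayed class-i contribution matrix for data chunk j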

            # need to tell Dask the expected type and shape of individual preds
            pred_meta = data._meta

            for j, partition in enumerate(preds.to_delayed()):
                for i in range(num_classes):
                    part = dask_array_from_delayed(
                        value=_extract(partition, i),
                        shape=(nrows_per_chunk[j], num_cols),
                        meta=pred_meta,
                    )
                    out[i].append(part)

            # by default, dask.array.concatenate() concatenates sparse arrays into a COO matrix
            # the code below is used instead to ensure that the sparse type is preserved during concatenation
            if isinstance(pred_meta, ss.csr_matrix):
                concat_fn = partial(ss.vstack, format="csr")
            elif isinstance(pred_meta, ss.csc_matrix):
                concat_fn = partial(ss.vstack, format="csc")
            else:
                concat_fn = ss.vstack

            # At this point, `out` is a list of lists of delayeds (each of which points to a matrix).
            # Concatenate them to return a list of Dask Arrays.
            out_arrays: List[dask_Array] = []
            for i in range(num_classes):
                out_arrays.append(
                    dask_array_from_delayed(
                        value=delayed(concat_fn)(out[i]),
                        shape=(data.shape[0], num_cols),
                        meta=pred_meta,
                    )
                )

            return out_arrays

        data_row = client.compute(data[[0]]).result()
        predict_fn = partial(
            _predict_part,
            model=model,
            raw_score=raw_score,
            pred_proba=pred_proba,
            pred_leaf=pred_leaf,
            pred_contrib=pred_contrib,
            **kwargs,
        )
        pred_row = predict_fn(data_row)
        chunks: Tuple[int, ...] = (data.chunks[0],)
        map_blocks_kwargs = {}
        if len(pred_row.shape) > 1:
            chunks += (pred_row.shape[1],)
        else:
            map_blocks_kwargs["drop_axis"] = 1
        return data.map_blocks(
            predict_fn,
            chunks=chunks,
            meta=pred_row,
            dtype=dtype,
            **map_blocks_kwargs,
        )
    else:
        raise TypeError(f"Data must be either Dask Array or Dask DataFrame. Got {type(data).__name__}.")


class _DaskLGBMModel:
    @property
    def client_(self) -> Client:
        """:obj:`dask.distributed.Client`: Dask client.

        This property can be passed in the constructor or updated
        with ``model.set_params(client=client)``.
        """
        if not getattr(self, "fitted_", False):
            raise LGBMNotFittedError("Cannot access property client_ before calling fit().")

        return _get_dask_client(client=self.client)

    def _lgb_dask_getstate(self) -> Dict[Any, Any]:
        """Remove un-picklable attributes before serialization."""
        client = self.__dict__.pop("client", None)
        self._other_params.pop("client", None)  # type: ignore[attr-defined]
        out = deepcopy(self.__dict__)
        out.update({"client": None})
        self.client = client
        return out

    def _lgb_dask_fit(
        self,
        model_factory: Type[LGBMModel],
        X: _DaskMatrixLike,
        y: _DaskCollection,
        sample_weight: Optional[_DaskVectorLike] = None,
        init_score: Optional[_DaskCollection] = None,
        group: Optional[_DaskVectorLike] = None,
        eval_set: Optional[List[Tuple[_DaskMatrixLike, _DaskCollection]]] = None,
        eval_names: Optional[List[str]] = None,
        eval_sample_weight: Optional[List[_DaskVectorLike]] = None,
        eval_class_weight: Optional[List[Union[dict, str]]] = None,
        eval_init_score: Optional[List[_DaskCollection]] = None,
        eval_group: Optional[List[_DaskVectorLike]] = None,
        eval_metric: Optional[_LGBM_ScikitEvalMetricType] = None,
        eval_at: Optional[Union[List[int], Tuple[int, ...]]] = None,
        **kwargs: Any,
    ) -> "_DaskLGBMModel":
        if not DASK_INSTALLED:
            raise LightGBMError("dask is required for lightgbm.dask")
        if not all((DASK_INSTALLED, PANDAS_INSTALLED, SKLEARN_INSTALLED)):
            raise LightGBMError("dask, pandas and scikit-learn are required for lightgbm.dask")

        params = self.get_params(True)  # type: ignore[attr-defined]
        params.pop("client", None)

        model = _train(
            client=_get_dask_client(self.client),
            data=X,
            label=y,
            params=params,
            model_factory=model_factory,
            sample_weight=sample_weight,
            init_score=init_score,
            group=group,
            eval_set=eval_set,
            eval_names=eval_names,
            eval_sample_weight=eval_sample_weight,
            eval_class_weight=eval_class_weight,
            eval_init_score=eval_init_score,
            eval_group=eval_group,
            eval_metric=eval_metric,
            eval_at=eval_at,
            **kwargs,
        )

        self.set_params(**model.get_params())  # type: ignore[attr-defined]
        self._lgb_dask_copy_extra_params(model, self)  # type: ignore[attr-defined]

        return self

    def _lgb_dask_to_local(self, model_factory: Type[LGBMModel]) -> LGBMModel:
        params = self.get_params()  # type: ignore[attr-defined]
        params.pop("client", None)
        model = model_factory(**params)
        self._lgb_dask_copy_extra_params(self, model)
        model._other_params.pop("client", None)
        return model

    @staticmethod
    def _lgb_dask_copy_extra_params(
        source: Union["_DaskLGBMModel", LGBMModel],
        dest: Union["_DaskLGBMModel", LGBMModel],
    ) -> None:
        params = source.get_params()  # type: ignore[union-attr]
        attributes = source.__dict__
        extra_param_names = set(attributes.keys()).difference(params.keys())
        for name in extra_param_names:
            setattr(dest, name, attributes[name])


class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
    """Distributed version of lightgbm.LGBMClassifier."""

    def __init__(
        self,
        boosting_type: str = "gbdt",
        num_leaves: int = 31,
        max_depth: int = -1,
        learning_rate: float = 0.1,
        n_estimators: int = 100,
        subsample_for_bin: int = 200000,
        objective: Optional[Union[str, _LGBM_ScikitCustomObjectiveFunction]] = None,
        class_weight: Optional[Union[dict, str]] = None,
        min_split_gain: float = 0.0,
        min_child_weight: float = 1e-3,
        min_child_samples: int = 20,
        subsample: float = 1.0,
        subsample_freq: int = 0,
        colsample_bytree: float = 1.0,
        reg_alpha: float = 0.0,
        reg_lambda: float = 0.0,
        random_state: Optional[Union[int, np.random.RandomState, "np.random.Generator"]] = None,
        n_jobs: Optional[int] = None,
        importance_type: str = "split",
        client: Optional[Client] = None,
        **kwargs: Any,
    ):
        """Docstring is inherited from the lightgbm.LGBMClassifier.__init__."""
        self.client = client
        super().__init__(
            boosting_type=boosting_type,
            num_leaves=num_leaves,
            max_depth=max_depth,
            learning_rate=learning_rate,
            n_estimators=n_estimators,
            subsample_for_bin=subsample_for_bin,
            objective=objective,
            class_weight=class_weight,
            min_split_gain=min_split_gain,
            min_child_weight=min_child_weight,
            min_child_samples=min_child_samples,
            subsample=subsample,
            subsample_freq=subsample_freq,
            colsample_bytree=colsample_bytree,
            reg_alpha=reg_alpha,
            reg_lambda=reg_lambda,
            random_state=random_state,
            n_jobs=n_jobs,
            importance_type=importance_type,
            **kwargs,
        )

    _base_doc = LGBMClassifier.__init__.__doc__
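    # splice documentation for the Dask-specific 'client' parameter into the inherited
    # docstring, just ahead of its '**kwargs' entry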
    _before_kwargs, _kwargs, _after_kwargs = _base_doc.partition("**kwargs")  # type: ignore
    __init__.__doc__ = f"""
        {_before_kwargs}client : dask.distributed.Client or None, optional (default=None)
        {' ':4}Dask client. If ``None``, ``distributed.default_client()`` will be used at runtime. The Dask client used by this class will not be saved if the model object is pickled.
        {_kwargs}{_after_kwargs}
        """

    def __getstate__(self) -> Dict[Any, Any]:
        return self._lgb_dask_getstate()

    def fit(  # type: ignore[override]
        self,
        X: _DaskMatrixLike,
        y: _DaskCollection,
        sample_weight: Optional[_DaskVectorLike] = None,
        init_score: Optional[_DaskCollection] = None,
        eval_set: Optional[List[Tuple[_DaskMatrixLike, _DaskCollection]]] = None,
        eval_names: Optional[List[str]] = None,
        eval_sample_weight: Optional[List[_DaskVectorLike]] = None,
        eval_class_weight: Optional[List[Union[dict, str]]] = None,
        eval_init_score: Optional[List[_DaskCollection]] = None,
        eval_metric: Optional[_LGBM_ScikitEvalMetricType] = None,
        **kwargs: Any,
    ) -> "DaskLGBMClassifier":
        """Docstring is inherited from the lightgbm.LGBMClassifier.fit."""
        self._lgb_dask_fit(
            model_factory=LGBMClassifier,
            X=X,
            y=y,
            sample_weight=sample_weight,
            init_score=init_score,
            eval_set=eval_set,
            eval_names=eval_names,
            eval_sample_weight=eval_sample_weight,
            eval_class_weight=eval_class_weight,
            eval_init_score=eval_init_score,
            eval_metric=eval_metric,
            **kwargs,
        )
        return self

    _base_doc = _lgbmmodel_doc_fit.format(
        X_shape="Dask Array or Dask DataFrame of shape = [n_samples, n_features]",
        y_shape="Dask Array, Dask DataFrame or Dask Series of shape = [n_samples]",
        sample_weight_shape="Dask Array or Dask Series of shape = [n_samples] or None, optional (default=None)",
        init_score_shape="Dask Array or Dask Series of shape = [n_samples] or shape = [n_samples * n_classes] (for multi-class task), or Dask Array or Dask DataFrame of shape = [n_samples, n_classes] (for multi-class task), or None, optional (default=None)",
        group_shape="Dask Array or Dask Series or None, optional (default=None)",
        eval_sample_weight_shape="list of Dask Array or Dask Series, or None, optional (default=None)",
        eval_init_score_shape="list of Dask Array, Dask Series or Dask DataFrame (for multi-class task), or None, optional (default=None)",
        eval_group_shape="list of Dask Array or Dask Series, or None, optional (default=None)",
    )

    # DaskLGBMClassifier does not support group, eval_group.
    _base_doc = _base_doc[: _base_doc.find("group :")] + _base_doc[_base_doc.find("eval_set :") :]

    _base_doc = _base_doc[: _base_doc.find("eval_group :")] + _base_doc[_base_doc.find("eval_metric :") :]

    # DaskLGBMClassifier support for callbacks and init_model is not tested
    fit.__doc__ = f"""{_base_doc[:_base_doc.find('callbacks :')]}**kwargs
        Other parameters passed through to ``LGBMClassifier.fit()``.

    Returns
    -------
    self : lightgbm.DaskLGBMClassifier
        Returns self.

    {_lgbmmodel_doc_custom_eval_note}
        """

    def predict(
        self,
        X: _DaskMatrixLike,  # type: ignore[override]
        raw_score: bool = False,
        start_iteration: int = 0,
        num_iteration: Optional[int] = None,
        pred_leaf: bool = False,
        pred_contrib: bool = False,
        validate_features: bool = False,
        **kwargs: Any,
    ) -> dask_Array:
        """Docstring is inherited from the lightgbm.LGBMClassifier.predict."""
        return _predict(
            model=self.to_local(),
            data=X,
            dtype=self.classes_.dtype,
            client=_get_dask_client(self.client),
            raw_score=raw_score,
            start_iteration=start_iteration,
            num_iteration=num_iteration,
            pred_leaf=pred_leaf,
            pred_contrib=pred_contrib,
            validate_features=validate_features,
            **kwargs,
        )

    predict.__doc__ = _lgbmmodel_doc_predict.format(
        description="Return the predicted value for each sample.",
        X_shape="Dask Array or Dask DataFrame of shape = [n_samples, n_features]",
        output_name="predicted_result",
        predicted_result_shape="Dask Array of shape = [n_samples] or shape = [n_samples, n_classes]",
        X_leaves_shape="Dask Array of shape = [n_samples, n_trees] or shape = [n_samples, n_trees * n_classes]",
        X_SHAP_values_shape="Dask Array of shape = [n_samples, n_features + 1] or shape = [n_samples, (n_features + 1) * n_classes] or (if multi-class and using sparse inputs) a list of ``n_classes`` Dask Arrays of shape = [n_samples, n_features + 1]",
    )

    def predict_proba(
        self,
        X: _DaskMatrixLike,  # type: ignore[override]
        raw_score: bool = False,
        start_iteration: int = 0,
        num_iteration: Optional[int] = None,
        pred_leaf: bool = False,
        pred_contrib: bool = False,
        validate_features: bool = False,
        **kwargs: Any,
    ) -> dask_Array:
        """Docstring is inherited from the lightgbm.LGBMClassifier.predict_proba."""
        return _predict(
            model=self.to_local(),
            data=X,
            pred_proba=True,
1285
            client=_get_dask_client(self.client),
1286
1287
1288
1289
1290
1291
            raw_score=raw_score,
            start_iteration=start_iteration,
            num_iteration=num_iteration,
            pred_leaf=pred_leaf,
            pred_contrib=pred_contrib,
            validate_features=validate_features,
1292
            **kwargs,
1293
1294
        )

    predict_proba.__doc__ = _lgbmmodel_doc_predict.format(
        description="Return the predicted probability for each class for each sample.",
        X_shape="Dask Array or Dask DataFrame of shape = [n_samples, n_features]",
        output_name="predicted_probability",
        predicted_result_shape="Dask Array of shape = [n_samples] or shape = [n_samples, n_classes]",
        X_leaves_shape="Dask Array of shape = [n_samples, n_trees] or shape = [n_samples, n_trees * n_classes]",
        X_SHAP_values_shape="Dask Array of shape = [n_samples, n_features + 1] or shape = [n_samples, (n_features + 1) * n_classes] or (if multi-class and using sparse inputs) a list of ``n_classes`` Dask Arrays of shape = [n_samples, n_features + 1]",
    )
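
    # predict() yields class labels while predict_proba() yields per-class
    # probabilities. A hedged sketch for a binary task (`clf` and `dX` as above):
    #
    #   labels = clf.predict(dX).compute()        # shape: (n_samples,)
    #   probs = clf.predict_proba(dX).compute()   # shape: (n_samples, 2)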

    def to_local(self) -> LGBMClassifier:
        """Create regular version of lightgbm.LGBMClassifier from the distributed version.

        Returns
        -------
        model : lightgbm.LGBMClassifier
            Local underlying model.
        """
        return self._lgb_dask_to_local(LGBMClassifier)
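
# A hedged end-to-end sketch for DaskLGBMClassifier (not part of the API;
# assumes dask.distributed is installed and the names below are illustrative):
#
#   from distributed import Client, LocalCluster
#   import dask.array as da
#
#   client = Client(LocalCluster(n_workers=2))
#   dX = da.random.random((1_000, 10), chunks=(100, 10))
#   dy = (da.random.random(1_000, chunks=100) > 0.5).astype(int)
#   clf = DaskLGBMClassifier(n_estimators=10, client=client).fit(dX, dy)
#   local_clf = clf.to_local()   # plain LGBMClassifier, safe to pickle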


class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel):
    """Distributed version of lightgbm.LGBMRegressor."""

    def __init__(
        self,
        boosting_type: str = "gbdt",
        num_leaves: int = 31,
        max_depth: int = -1,
        learning_rate: float = 0.1,
        n_estimators: int = 100,
        subsample_for_bin: int = 200000,
        objective: Optional[Union[str, _LGBM_ScikitCustomObjectiveFunction]] = None,
        class_weight: Optional[Union[dict, str]] = None,
        min_split_gain: float = 0.0,
        min_child_weight: float = 1e-3,
        min_child_samples: int = 20,
        subsample: float = 1.0,
        subsample_freq: int = 0,
        colsample_bytree: float = 1.0,
        reg_alpha: float = 0.0,
        reg_lambda: float = 0.0,
        random_state: Optional[Union[int, np.random.RandomState, "np.random.Generator"]] = None,
        n_jobs: Optional[int] = None,
        importance_type: str = "split",
        client: Optional[Client] = None,
        **kwargs: Any,
    ):
        """Docstring is inherited from the lightgbm.LGBMRegressor.__init__."""
        self.client = client
        super().__init__(
            boosting_type=boosting_type,
            num_leaves=num_leaves,
            max_depth=max_depth,
            learning_rate=learning_rate,
            n_estimators=n_estimators,
            subsample_for_bin=subsample_for_bin,
            objective=objective,
            class_weight=class_weight,
            min_split_gain=min_split_gain,
            min_child_weight=min_child_weight,
            min_child_samples=min_child_samples,
            subsample=subsample,
            subsample_freq=subsample_freq,
            colsample_bytree=colsample_bytree,
            reg_alpha=reg_alpha,
            reg_lambda=reg_lambda,
            random_state=random_state,
            n_jobs=n_jobs,
            importance_type=importance_type,
            **kwargs,
        )

    _base_doc = LGBMRegressor.__init__.__doc__
    _before_kwargs, _kwargs, _after_kwargs = _base_doc.partition("**kwargs")  # type: ignore
    __init__.__doc__ = f"""
        {_before_kwargs}client : dask.distributed.Client or None, optional (default=None)
        {' ':4}Dask client. If ``None``, ``distributed.default_client()`` will be used at runtime. The Dask client used by this class will not be saved if the model object is pickled.
        {_kwargs}{_after_kwargs}
        """

    def __getstate__(self) -> Dict[Any, Any]:
        return self._lgb_dask_getstate()
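
    # __getstate__ drops the Dask client before pickling, since clients are not
    # serializable. A hedged sketch of the round-trip (`reg` is an illustrative
    # fitted DaskLGBMRegressor, `client` an active Client):
    #
    #   import pickle
    #   restored = pickle.loads(pickle.dumps(reg))   # client is excluded
    #   restored.set_params(client=client)           # reattach, or rely on default_client()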

    def fit(  # type: ignore[override]
        self,
        X: _DaskMatrixLike,
        y: _DaskCollection,
        sample_weight: Optional[_DaskVectorLike] = None,
        init_score: Optional[_DaskVectorLike] = None,
        eval_set: Optional[List[Tuple[_DaskMatrixLike, _DaskCollection]]] = None,
        eval_names: Optional[List[str]] = None,
        eval_sample_weight: Optional[List[_DaskVectorLike]] = None,
        eval_init_score: Optional[List[_DaskVectorLike]] = None,
        eval_metric: Optional[_LGBM_ScikitEvalMetricType] = None,
        **kwargs: Any,
    ) -> "DaskLGBMRegressor":
        """Docstring is inherited from the lightgbm.LGBMRegressor.fit."""
        self._lgb_dask_fit(
            model_factory=LGBMRegressor,
            X=X,
            y=y,
            sample_weight=sample_weight,
            init_score=init_score,
            eval_set=eval_set,
            eval_names=eval_names,
            eval_sample_weight=eval_sample_weight,
            eval_init_score=eval_init_score,
            eval_metric=eval_metric,
            **kwargs,
        )
        return self

    _base_doc = _lgbmmodel_doc_fit.format(
        X_shape="Dask Array or Dask DataFrame of shape = [n_samples, n_features]",
        y_shape="Dask Array, Dask DataFrame or Dask Series of shape = [n_samples]",
        sample_weight_shape="Dask Array or Dask Series of shape = [n_samples] or None, optional (default=None)",
        init_score_shape="Dask Array or Dask Series of shape = [n_samples] or None, optional (default=None)",
        group_shape="Dask Array or Dask Series or None, optional (default=None)",
        eval_sample_weight_shape="list of Dask Array or Dask Series, or None, optional (default=None)",
        eval_init_score_shape="list of Dask Array or Dask Series, or None, optional (default=None)",
        eval_group_shape="list of Dask Array or Dask Series, or None, optional (default=None)",
    )

    # DaskLGBMRegressor does not support group, eval_class_weight, eval_group.
    _base_doc = _base_doc[: _base_doc.find("group :")] + _base_doc[_base_doc.find("eval_set :") :]

    _base_doc = _base_doc[: _base_doc.find("eval_class_weight :")] + _base_doc[_base_doc.find("eval_init_score :") :]

    _base_doc = _base_doc[: _base_doc.find("eval_group :")] + _base_doc[_base_doc.find("eval_metric :") :]

    # DaskLGBMRegressor support for callbacks and init_model is not tested
    fit.__doc__ = f"""{_base_doc[:_base_doc.find('callbacks :')]}**kwargs
        Other parameters passed through to ``LGBMRegressor.fit()``.

    Returns
    -------
    self : lightgbm.DaskLGBMRegressor
        Returns self.

    {_lgbmmodel_doc_custom_eval_note}
        """

    def predict(
        self,
        X: _DaskMatrixLike,  # type: ignore[override]
        raw_score: bool = False,
        start_iteration: int = 0,
        num_iteration: Optional[int] = None,
        pred_leaf: bool = False,
        pred_contrib: bool = False,
        validate_features: bool = False,
        **kwargs: Any,
    ) -> dask_Array:
        """Docstring is inherited from the lightgbm.LGBMRegressor.predict."""
        return _predict(
            model=self.to_local(),
            data=X,
            client=_get_dask_client(self.client),
            raw_score=raw_score,
            start_iteration=start_iteration,
            num_iteration=num_iteration,
            pred_leaf=pred_leaf,
            pred_contrib=pred_contrib,
            validate_features=validate_features,
            **kwargs,
        )

    predict.__doc__ = _lgbmmodel_doc_predict.format(
        description="Return the predicted value for each sample.",
        X_shape="Dask Array or Dask DataFrame of shape = [n_samples, n_features]",
        output_name="predicted_result",
        predicted_result_shape="Dask Array of shape = [n_samples]",
        X_leaves_shape="Dask Array of shape = [n_samples, n_trees]",
        X_SHAP_values_shape="Dask Array of shape = [n_samples, n_features + 1]",
    )

    def to_local(self) -> LGBMRegressor:
        """Create regular version of lightgbm.LGBMRegressor from the distributed version.

        Returns
        -------
        model : lightgbm.LGBMRegressor
            Local underlying model.
        """
        return self._lgb_dask_to_local(LGBMRegressor)
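
# A hedged end-to-end sketch for DaskLGBMRegressor (illustrative names only;
# assumes a running dask.distributed cluster):
#
#   from distributed import Client, LocalCluster
#   import dask.array as da
#
#   client = Client(LocalCluster(n_workers=2))
#   dX = da.random.random((1_000, 10), chunks=(100, 10))
#   dy = da.random.random(1_000, chunks=100)
#   reg = DaskLGBMRegressor(n_estimators=10, client=client).fit(dX, dy)
#   mse = ((reg.predict(dX) - dy) ** 2).mean().compute()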


class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel):
    """Distributed version of lightgbm.LGBMRanker."""

    def __init__(
        self,
        boosting_type: str = "gbdt",
        num_leaves: int = 31,
        max_depth: int = -1,
        learning_rate: float = 0.1,
        n_estimators: int = 100,
        subsample_for_bin: int = 200000,
        objective: Optional[Union[str, _LGBM_ScikitCustomObjectiveFunction]] = None,
        class_weight: Optional[Union[dict, str]] = None,
        min_split_gain: float = 0.0,
        min_child_weight: float = 1e-3,
        min_child_samples: int = 20,
        subsample: float = 1.0,
        subsample_freq: int = 0,
        colsample_bytree: float = 1.0,
        reg_alpha: float = 0.0,
        reg_lambda: float = 0.0,
        random_state: Optional[Union[int, np.random.RandomState, "np.random.Generator"]] = None,
        n_jobs: Optional[int] = None,
        importance_type: str = "split",
        client: Optional[Client] = None,
        **kwargs: Any,
    ):
        """Docstring is inherited from the lightgbm.LGBMRanker.__init__."""
        self.client = client
        super().__init__(
            boosting_type=boosting_type,
            num_leaves=num_leaves,
            max_depth=max_depth,
            learning_rate=learning_rate,
            n_estimators=n_estimators,
            subsample_for_bin=subsample_for_bin,
            objective=objective,
            class_weight=class_weight,
            min_split_gain=min_split_gain,
            min_child_weight=min_child_weight,
            min_child_samples=min_child_samples,
            subsample=subsample,
            subsample_freq=subsample_freq,
            colsample_bytree=colsample_bytree,
            reg_alpha=reg_alpha,
            reg_lambda=reg_lambda,
            random_state=random_state,
            n_jobs=n_jobs,
            importance_type=importance_type,
            **kwargs,
        )

    _base_doc = LGBMRanker.__init__.__doc__
    _before_kwargs, _kwargs, _after_kwargs = _base_doc.partition("**kwargs")  # type: ignore
    __init__.__doc__ = f"""
        {_before_kwargs}client : dask.distributed.Client or None, optional (default=None)
        {' ':4}Dask client. If ``None``, ``distributed.default_client()`` will be used at runtime. The Dask client used by this class will not be saved if the model object is pickled.
        {_kwargs}{_after_kwargs}
        """

    def __getstate__(self) -> Dict[Any, Any]:
        return self._lgb_dask_getstate()

    def fit(  # type: ignore[override]
        self,
        X: _DaskMatrixLike,
        y: _DaskCollection,
        sample_weight: Optional[_DaskVectorLike] = None,
        init_score: Optional[_DaskVectorLike] = None,
        group: Optional[_DaskVectorLike] = None,
        eval_set: Optional[List[Tuple[_DaskMatrixLike, _DaskCollection]]] = None,
        eval_names: Optional[List[str]] = None,
        eval_sample_weight: Optional[List[_DaskVectorLike]] = None,
        eval_init_score: Optional[List[_DaskVectorLike]] = None,
        eval_group: Optional[List[_DaskVectorLike]] = None,
        eval_metric: Optional[_LGBM_ScikitEvalMetricType] = None,
        eval_at: Union[List[int], Tuple[int, ...]] = (1, 2, 3, 4, 5),
        **kwargs: Any,
    ) -> "DaskLGBMRanker":
        """Docstring is inherited from the lightgbm.LGBMRanker.fit."""
        self._lgb_dask_fit(
            model_factory=LGBMRanker,
            X=X,
            y=y,
            sample_weight=sample_weight,
            init_score=init_score,
            group=group,
            eval_set=eval_set,
            eval_names=eval_names,
            eval_sample_weight=eval_sample_weight,
            eval_init_score=eval_init_score,
            eval_group=eval_group,
            eval_metric=eval_metric,
            eval_at=eval_at,
            **kwargs,
        )
        return self
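
    # For learning-to-rank, `group` holds the size of each query group, and rows
    # belonging to one query must be contiguous so that no query is split across
    # group boundaries. A hedged sketch (illustrative sizes):
    #
    #   import numpy as np
    #   import dask.array as da
    #
    #   # 3 queries with 10, 5 and 5 documents -> X and y have 20 rows in that order
    #   group = da.from_array(np.array([10, 5, 5]))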

    _base_doc = _lgbmmodel_doc_fit.format(
        X_shape="Dask Array or Dask DataFrame of shape = [n_samples, n_features]",
        y_shape="Dask Array, Dask DataFrame or Dask Series of shape = [n_samples]",
        sample_weight_shape="Dask Array or Dask Series of shape = [n_samples] or None, optional (default=None)",
        init_score_shape="Dask Array or Dask Series of shape = [n_samples] or None, optional (default=None)",
        group_shape="Dask Array or Dask Series or None, optional (default=None)",
        eval_sample_weight_shape="list of Dask Array or Dask Series, or None, optional (default=None)",
        eval_init_score_shape="list of Dask Array or Dask Series, or None, optional (default=None)",
        eval_group_shape="list of Dask Array or Dask Series, or None, optional (default=None)",
    )

    # DaskLGBMRanker does not support eval_class_weight or early stopping
    _base_doc = _base_doc[: _base_doc.find("eval_class_weight :")] + _base_doc[_base_doc.find("eval_init_score :") :]

    _base_doc = (
        _base_doc[: _base_doc.find("feature_name :")]
        + "eval_at : list or tuple of int, optional (default=(1, 2, 3, 4, 5))\n"
        + f"{' ':8}The evaluation positions of the specified metric.\n"
        + f"{' ':4}{_base_doc[_base_doc.find('feature_name :'):]}"
    )
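
    # `eval_at` gives the cutoff positions k at which the ranking metric is
    # reported for each validation set; e.g. with the default NDCG metric,
    # eval_at=(1, 3, 5) yields ndcg@1, ndcg@3 and ndcg@5 in the results.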

    # DaskLGBMRanker support for callbacks and init_model is not tested
    fit.__doc__ = f"""{_base_doc[:_base_doc.find('callbacks :')]}**kwargs
        Other parameters passed through to ``LGBMRanker.fit()``.

    Returns
    -------
    self : lightgbm.DaskLGBMRanker
        Returns self.

    {_lgbmmodel_doc_custom_eval_note}
        """

    def predict(
        self,
        X: _DaskMatrixLike,  # type: ignore[override]
        raw_score: bool = False,
        start_iteration: int = 0,
        num_iteration: Optional[int] = None,
        pred_leaf: bool = False,
        pred_contrib: bool = False,
        validate_features: bool = False,
        **kwargs: Any,
    ) -> dask_Array:
        """Docstring is inherited from the lightgbm.LGBMRanker.predict."""
        return _predict(
            model=self.to_local(),
            data=X,
            client=_get_dask_client(self.client),
            raw_score=raw_score,
            start_iteration=start_iteration,
            num_iteration=num_iteration,
            pred_leaf=pred_leaf,
            pred_contrib=pred_contrib,
            validate_features=validate_features,
            **kwargs,
        )

    predict.__doc__ = _lgbmmodel_doc_predict.format(
        description="Return the predicted value for each sample.",
        X_shape="Dask Array or Dask DataFrame of shape = [n_samples, n_features]",
        output_name="predicted_result",
        predicted_result_shape="Dask Array of shape = [n_samples]",
        X_leaves_shape="Dask Array of shape = [n_samples, n_trees]",
        X_SHAP_values_shape="Dask Array of shape = [n_samples, n_features + 1]",
    )

    def to_local(self) -> LGBMRanker:
        """Create regular version of lightgbm.LGBMRanker from the distributed version.

        Returns
        -------
        model : lightgbm.LGBMRanker
            Local underlying model.
        """
        return self._lgb_dask_to_local(LGBMRanker)
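
# A hedged end-to-end sketch for DaskLGBMRanker (illustrative names; assumes
# `client`, `dX`, `dy` and a `group` array like the one sketched in fit() above,
# all aligned with the row order of `dX`):
#
#   rnk = DaskLGBMRanker(n_estimators=10, client=client)
#   rnk.fit(dX, dy, group=group, eval_at=(1, 3, 5))
#   relevance = rnk.predict(dX).compute()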