# coding: utf-8
"""Wrapper for C API of LightGBM."""

# This import causes lib_lightgbm.{dll,dylib,so} to be loaded.
# It's intentionally done here, as early as possible, to avoid issues like
# "libgomp.so.1: cannot allocate memory in static TLS block" on aarch64 Linux.
#
# For details, see the "cannot allocate memory in static TLS block" entry in docs/FAQ.rst.
from .libpath import _LIB  # isort: skip

import abc
import ctypes
import inspect
import json
import warnings
from collections import OrderedDict
from copy import deepcopy
from enum import Enum
from functools import wraps
from os import SEEK_END, environ
from os.path import getsize
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, Iterator, List, Optional, Set, Tuple, Union

import numpy as np
import scipy.sparse

from .compat import (
    PANDAS_INSTALLED,
    PYARROW_INSTALLED,
    arrow_cffi,
    arrow_is_boolean,
    arrow_is_floating,
    arrow_is_integer,
    concat,
    dt_DataTable,
    pa_Array,
    pa_chunked_array,
    pa_ChunkedArray,
    pa_compute,
    pa_Table,
    pd_CategoricalDtype,
    pd_DataFrame,
    pd_Series,
)

if TYPE_CHECKING:
    from typing import Literal

    # typing.TypeGuard was only introduced in Python 3.10
    try:
        from typing import TypeGuard
    except ImportError:
        from typing_extensions import TypeGuard


__all__ = [
    "Booster",
    "Dataset",
    "LGBMDeprecationWarning",
    "LightGBMError",
    "register_logger",
    "Sequence",
]

_BoosterHandle = ctypes.c_void_p
_DatasetHandle = ctypes.c_void_p
_ctypes_int_ptr = Union[
    "ctypes._Pointer[ctypes.c_int32]",
    "ctypes._Pointer[ctypes.c_int64]",
]
_ctypes_int_array = Union[
    "ctypes.Array[ctypes._Pointer[ctypes.c_int32]]",
    "ctypes.Array[ctypes._Pointer[ctypes.c_int64]]",
]
_ctypes_float_ptr = Union[
    "ctypes._Pointer[ctypes.c_float]",
    "ctypes._Pointer[ctypes.c_double]",
]
_ctypes_float_array = Union[
    "ctypes.Array[ctypes._Pointer[ctypes.c_float]]",
    "ctypes.Array[ctypes._Pointer[ctypes.c_double]]",
]
_LGBM_EvalFunctionResultType = Tuple[str, float, bool]
_LGBM_BoosterBestScoreType = Dict[str, Dict[str, float]]
_LGBM_BoosterEvalMethodResultType = Tuple[str, str, float, bool]
_LGBM_BoosterEvalMethodResultWithStandardDeviationType = Tuple[str, str, float, bool, float]
_LGBM_CategoricalFeatureConfiguration = Union[List[str], List[int], "Literal['auto']"]
_LGBM_FeatureNameConfiguration = Union[List[str], "Literal['auto']"]
_LGBM_GroupType = Union[
    List[float],
    List[int],
    np.ndarray,
    pd_Series,
    pa_Array,
    pa_ChunkedArray,
]
_LGBM_PositionType = Union[
    np.ndarray,
    pd_Series,
]
_LGBM_InitScoreType = Union[
    List[float],
    List[List[float]],
    np.ndarray,
    pd_Series,
    pd_DataFrame,
    pa_Table,
    pa_Array,
    pa_ChunkedArray,
]
_LGBM_TrainDataType = Union[
    str,
    Path,
    np.ndarray,
    pd_DataFrame,
    dt_DataTable,
    scipy.sparse.spmatrix,
    "Sequence",
    List["Sequence"],
    List[np.ndarray],
    pa_Table,
]
_LGBM_LabelType = Union[
    List[float],
    List[int],
    np.ndarray,
    pd_Series,
    pd_DataFrame,
    pa_Array,
    pa_ChunkedArray,
]
_LGBM_PredictDataType = Union[
    str,
    Path,
    np.ndarray,
    pd_DataFrame,
    dt_DataTable,
    scipy.sparse.spmatrix,
    pa_Table,
]
_LGBM_WeightType = Union[
    List[float],
    List[int],
    np.ndarray,
    pd_Series,
    pa_Array,
    pa_ChunkedArray,
]
_LGBM_SetFieldType = Union[
    List[List[float]],
    List[List[int]],
    List[float],
    List[int],
    np.ndarray,
    pd_Series,
    pd_DataFrame,
    pa_Table,
    pa_Array,
    pa_ChunkedArray,
]

ZERO_THRESHOLD = 1e-35

_MULTICLASS_OBJECTIVES = {"multiclass", "multiclassova", "multiclass_ova", "ova", "ovr", "softmax"}


class LightGBMError(Exception):
    """Error thrown by LightGBM."""

    pass


def _is_zero(x: float) -> bool:
    return -ZERO_THRESHOLD <= x <= ZERO_THRESHOLD


def _get_sample_count(total_nrow: int, params: str) -> int:
    sample_cnt = ctypes.c_int(0)
    _safe_call(
        _LIB.LGBM_GetSampleCount(
            ctypes.c_int32(total_nrow),
            _c_str(params),
            ctypes.byref(sample_cnt),
        )
    )
    return sample_cnt.value


class _MissingType(Enum):
    NONE = "None"
    NAN = "NaN"
    ZERO = "Zero"


class _DummyLogger:
    def info(self, msg: str) -> None:
        print(msg)  # noqa: T201

    def warning(self, msg: str) -> None:
        warnings.warn(msg, stacklevel=3)


_LOGGER: Any = _DummyLogger()
_INFO_METHOD_NAME = "info"
_WARNING_METHOD_NAME = "warning"


def _has_method(logger: Any, method_name: str) -> bool:
    return callable(getattr(logger, method_name, None))


def register_logger(
    logger: Any,
    info_method_name: str = "info",
    warning_method_name: str = "warning",
) -> None:
    """Register custom logger.

    Parameters
    ----------
    logger : Any
        Custom logger.
    info_method_name : str, optional (default="info")
        Method used to log info messages.
    warning_method_name : str, optional (default="warning")
        Method used to log warning messages.
    """
    if not _has_method(logger, info_method_name) or not _has_method(logger, warning_method_name):
        raise TypeError(f"Logger must provide '{info_method_name}' and '{warning_method_name}' method")

    global _LOGGER, _INFO_METHOD_NAME, _WARNING_METHOD_NAME
    _LOGGER = logger
    _INFO_METHOD_NAME = info_method_name
    _WARNING_METHOD_NAME = warning_method_name
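
# Illustrative usage sketch (not part of the library's own code): any object
# whose configured methods are callable works here, e.g. a stdlib logging.Logger.
#
#     >>> import logging
#     >>> logging.basicConfig()
#     >>> register_logger(logging.getLogger("lightgbm"))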


def _normalize_native_string(func: Callable[[str], None]) -> Callable[[str], None]:
    """Join log messages from native library which come by chunks."""
    msg_normalized: List[str] = []

    @wraps(func)
    def wrapper(msg: str) -> None:
        nonlocal msg_normalized
        if msg.strip() == "":
            msg = "".join(msg_normalized)
            msg_normalized = []
            return func(msg)
        else:
            msg_normalized.append(msg)

    return wrapper
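
# Sketch of the chunk-joining behavior (hypothetical messages): chunks are
# buffered until a whitespace-only chunk arrives, then emitted as one message.
#
#     >>> received = []
#     >>> @_normalize_native_string
#     ... def emit(msg: str) -> None:
#     ...     received.append(msg)
#     >>> emit("[LightGBM] ")
#     >>> emit("[Info] finished")
#     >>> emit(" \n")
#     >>> received
#     ['[LightGBM] [Info] finished']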


def _log_info(msg: str) -> None:
    getattr(_LOGGER, _INFO_METHOD_NAME)(msg)


def _log_warning(msg: str) -> None:
    getattr(_LOGGER, _WARNING_METHOD_NAME)(msg)


@_normalize_native_string
def _log_native(msg: str) -> None:
    getattr(_LOGGER, _INFO_METHOD_NAME)(msg)


def _log_callback(msg: bytes) -> None:
    """Redirect logs from native library into Python."""
    _log_native(str(msg.decode("utf-8")))


# connect the Python logger to logging in lib_lightgbm
if not environ.get("LIGHTGBM_BUILD_DOC", False):
    _LIB.LGBM_GetLastError.restype = ctypes.c_char_p
    callback = ctypes.CFUNCTYPE(None, ctypes.c_char_p)
    _LIB.callback = callback(_log_callback)  # type: ignore[attr-defined]
    if _LIB.LGBM_RegisterLogCallback(_LIB.callback) != 0:
        raise LightGBMError(_LIB.LGBM_GetLastError().decode("utf-8"))


_NUMERIC_TYPES = (int, float, bool)


def _safe_call(ret: int) -> None:
    """Check the return value from C API call.

    Parameters
    ----------
    ret : int
        The return value from C API calls.
    """
    if ret != 0:
        raise LightGBMError(_LIB.LGBM_GetLastError().decode("utf-8"))


def _is_numeric(obj: Any) -> bool:
    """Check whether object is a number, including numpy numbers, etc."""
    try:
        float(obj)
        return True
    except (TypeError, ValueError):
        # TypeError: obj is not a string or a number
        # ValueError: invalid literal
        return False


def _is_numpy_1d_array(data: Any) -> bool:
    """Check whether data is a numpy 1-D array."""
    return isinstance(data, np.ndarray) and len(data.shape) == 1


def _is_numpy_column_array(data: Any) -> bool:
    """Check whether data is a column numpy array."""
    if not isinstance(data, np.ndarray):
        return False
    shape = data.shape
    return len(shape) == 2 and shape[1] == 1


def _cast_numpy_array_to_dtype(array: np.ndarray, dtype: "np.typing.DTypeLike") -> np.ndarray:
    """Cast numpy array to given dtype."""
    if array.dtype == dtype:
        return array
    return array.astype(dtype=dtype, copy=False)


def _is_1d_list(data: Any) -> bool:
    """Check whether data is a 1-D list."""
    return isinstance(data, list) and (not data or _is_numeric(data[0]))


def _is_list_of_numpy_arrays(data: Any) -> "TypeGuard[List[np.ndarray]]":
    return isinstance(data, list) and all(isinstance(x, np.ndarray) for x in data)


def _is_list_of_sequences(data: Any) -> "TypeGuard[List[Sequence]]":
    return isinstance(data, list) and all(isinstance(x, Sequence) for x in data)


def _is_1d_collection(data: Any) -> bool:
    """Check whether data is a 1-D collection."""
    return _is_numpy_1d_array(data) or _is_numpy_column_array(data) or _is_1d_list(data) or isinstance(data, pd_Series)


def _list_to_1d_numpy(
    data: Any,
    dtype: "np.typing.DTypeLike",
    name: str,
) -> np.ndarray:
    """Convert data to numpy 1-D array."""
    if _is_numpy_1d_array(data):
        return _cast_numpy_array_to_dtype(data, dtype)
    elif _is_numpy_column_array(data):
        _log_warning("Converting column-vector to 1d array")
        array = data.ravel()
        return _cast_numpy_array_to_dtype(array, dtype)
    elif _is_1d_list(data):
        return np.asarray(data, dtype=dtype)
    elif isinstance(data, pd_Series):
        _check_for_bad_pandas_dtypes(data.to_frame().dtypes)
        return np.asarray(data, dtype=dtype)  # SparseArray should be supported as well
    else:
        raise TypeError(
            f"Wrong type({type(data).__name__}) for {name}.\n" "It should be list, numpy 1-D array or pandas Series"
        )


def _is_numpy_2d_array(data: Any) -> bool:
    """Check whether data is a numpy 2-D array."""
    return isinstance(data, np.ndarray) and len(data.shape) == 2 and data.shape[1] > 1


def _is_2d_list(data: Any) -> bool:
    """Check whether data is a 2-D list."""
    return isinstance(data, list) and len(data) > 0 and _is_1d_list(data[0])


def _is_2d_collection(data: Any) -> bool:
    """Check whether data is a 2-D collection."""
    return _is_numpy_2d_array(data) or _is_2d_list(data) or isinstance(data, pd_DataFrame)


def _is_pyarrow_array(data: Any) -> "TypeGuard[Union[pa_Array, pa_ChunkedArray]]":
    """Check whether data is a PyArrow array."""
    return isinstance(data, (pa_Array, pa_ChunkedArray))


def _is_pyarrow_table(data: Any) -> bool:
    """Check whether data is a PyArrow table."""
    return isinstance(data, pa_Table)


class _ArrowCArray:
    """Simple wrapper around the C representation of an Arrow type."""

    n_chunks: int
    chunks: arrow_cffi.CData
    schema: arrow_cffi.CData

    def __init__(self, n_chunks: int, chunks: arrow_cffi.CData, schema: arrow_cffi.CData):
        self.n_chunks = n_chunks
        self.chunks = chunks
        self.schema = schema

    @property
    def chunks_ptr(self) -> int:
        """Returns the address of the pointer to the list of chunks making up the array."""
        return int(arrow_cffi.cast("uintptr_t", arrow_cffi.addressof(self.chunks[0])))

    @property
    def schema_ptr(self) -> int:
        """Returns the address of the pointer to the schema of the array."""
        return int(arrow_cffi.cast("uintptr_t", self.schema))


def _export_arrow_to_c(data: pa_Table) -> _ArrowCArray:
    """Export an Arrow type to its C representation."""
    # Obtain objects to export
    if isinstance(data, pa_Array):
        export_objects = [data]
    elif isinstance(data, pa_ChunkedArray):
        export_objects = data.chunks
    elif isinstance(data, pa_Table):
        export_objects = data.to_batches()
    else:
        raise ValueError(f"data of type '{type(data)}' cannot be exported to Arrow")

    # Prepare export
    chunks = arrow_cffi.new("struct ArrowArray[]", len(export_objects))
    schema = arrow_cffi.new("struct ArrowSchema*")

    # Export all objects
    for i, obj in enumerate(export_objects):
        chunk_ptr = int(arrow_cffi.cast("uintptr_t", arrow_cffi.addressof(chunks[i])))
        if i == 0:
            schema_ptr = int(arrow_cffi.cast("uintptr_t", schema))
            obj._export_to_c(chunk_ptr, schema_ptr)
        else:
            obj._export_to_c(chunk_ptr)

    return _ArrowCArray(len(chunks), chunks, schema)


def _data_to_2d_numpy(
    data: Any,
    dtype: "np.typing.DTypeLike",
    name: str,
) -> np.ndarray:
    """Convert data to numpy 2-D array."""
    if _is_numpy_2d_array(data):
        return _cast_numpy_array_to_dtype(data, dtype)
    if _is_2d_list(data):
        return np.array(data, dtype=dtype)
    if isinstance(data, pd_DataFrame):
        _check_for_bad_pandas_dtypes(data.dtypes)
        return _cast_numpy_array_to_dtype(data.values, dtype)
    raise TypeError(
        f"Wrong type({type(data).__name__}) for {name}.\n"
        "It should be list of lists, numpy 2-D array or pandas DataFrame"
    )


def _cfloat32_array_to_numpy(*, cptr: "ctypes._Pointer", length: int) -> np.ndarray:
    """Convert a ctypes float pointer array to a numpy array."""
    if isinstance(cptr, ctypes.POINTER(ctypes.c_float)):
        return np.ctypeslib.as_array(cptr, shape=(length,)).copy()
    else:
        raise RuntimeError("Expected float pointer")


def _cfloat64_array_to_numpy(*, cptr: "ctypes._Pointer", length: int) -> np.ndarray:
    """Convert a ctypes double pointer array to a numpy array."""
    if isinstance(cptr, ctypes.POINTER(ctypes.c_double)):
        return np.ctypeslib.as_array(cptr, shape=(length,)).copy()
    else:
        raise RuntimeError("Expected double pointer")


def _cint32_array_to_numpy(*, cptr: "ctypes._Pointer", length: int) -> np.ndarray:
    """Convert a ctypes int32 pointer array to a numpy array."""
    if isinstance(cptr, ctypes.POINTER(ctypes.c_int32)):
        return np.ctypeslib.as_array(cptr, shape=(length,)).copy()
    else:
        raise RuntimeError("Expected int32 pointer")


def _cint64_array_to_numpy(*, cptr: "ctypes._Pointer", length: int) -> np.ndarray:
    """Convert a ctypes int64 pointer array to a numpy array."""
    if isinstance(cptr, ctypes.POINTER(ctypes.c_int64)):
        return np.ctypeslib.as_array(cptr, shape=(length,)).copy()
    else:
        raise RuntimeError("Expected int64 pointer")


def _c_str(string: str) -> ctypes.c_char_p:
    """Convert a Python string to C string."""
    return ctypes.c_char_p(string.encode("utf-8"))


def _c_array(ctype: type, values: List[Any]) -> ctypes.Array:
    """Convert a Python array to C array."""
    return (ctype * len(values))(*values)  # type: ignore[operator]


def _json_default_with_numpy(obj: Any) -> Any:
    """Convert numpy classes to JSON serializable objects."""
    if isinstance(obj, (np.integer, np.floating, np.bool_)):
        return obj.item()
    elif isinstance(obj, np.ndarray):
        return obj.tolist()
    else:
        return obj
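
# Intended for use as the ``default`` hook of ``json.dumps``; a small sketch:
#
#     >>> json.dumps({"importance": np.array([1, 2])}, default=_json_default_with_numpy)
#     '{"importance": [1, 2]}'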


def _to_string(x: Union[int, float, str, List]) -> str:
    if isinstance(x, list):
        val_list = ",".join(str(val) for val in x)
        return f"[{val_list}]"
    else:
        return str(x)


def _param_dict_to_str(data: Optional[Dict[str, Any]]) -> str:
    """Convert Python dictionary to string, which is passed to C API."""
    if data is None or not data:
        return ""
    pairs = []
    for key, val in data.items():
        if isinstance(val, (list, tuple, set)) or _is_numpy_1d_array(val):
            pairs.append(f"{key}={','.join(map(_to_string, val))}")
        elif isinstance(val, (str, Path, _NUMERIC_TYPES)) or _is_numeric(val):
            pairs.append(f"{key}={val}")
        elif val is not None:
            raise TypeError(f"Unknown type of parameter:{key}, got:{type(val).__name__}")
    return " ".join(pairs)


class _TempFile:
    """Proxy class to work around errors on Windows."""

    def __enter__(self) -> "_TempFile":
        with NamedTemporaryFile(prefix="lightgbm_tmp_", delete=True) as f:
            self.name = f.name
            self.path = Path(self.name)
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        if self.path.is_file():
            self.path.unlink()


# DeprecationWarning is not shown by default, so let's create our own with higher level
# ref: https://peps.python.org/pep-0565/#additional-use-case-for-futurewarning
class LGBMDeprecationWarning(FutureWarning):
    """Custom deprecation warning."""

    pass


class _ConfigAliases:
    # lazy evaluation to allow import without dynamic library, e.g., for docs generation
    aliases = None

    @staticmethod
    def _get_all_param_aliases() -> Dict[str, List[str]]:
        buffer_len = 1 << 20
        tmp_out_len = ctypes.c_int64(0)
        string_buffer = ctypes.create_string_buffer(buffer_len)
        ptr_string_buffer = ctypes.c_char_p(ctypes.addressof(string_buffer))
        _safe_call(
            _LIB.LGBM_DumpParamAliases(
                ctypes.c_int64(buffer_len),
                ctypes.byref(tmp_out_len),
                ptr_string_buffer,
            )
        )
        actual_len = tmp_out_len.value
        # if buffer length is not long enough, re-allocate a buffer
        if actual_len > buffer_len:
            string_buffer = ctypes.create_string_buffer(actual_len)
            ptr_string_buffer = ctypes.c_char_p(ctypes.addressof(string_buffer))
            _safe_call(
                _LIB.LGBM_DumpParamAliases(
                    ctypes.c_int64(actual_len),
                    ctypes.byref(tmp_out_len),
                    ptr_string_buffer,
                )
            )
        return json.loads(
            string_buffer.value.decode("utf-8"), object_hook=lambda obj: {k: [k] + v for k, v in obj.items()}
        )

    @classmethod
    def get(cls, *args: str) -> Set[str]:
        if cls.aliases is None:
            cls.aliases = cls._get_all_param_aliases()
        ret = set()
        for i in args:
            ret.update(cls.get_sorted(i))
        return ret

    @classmethod
    def get_sorted(cls, name: str) -> List[str]:
        if cls.aliases is None:
            cls.aliases = cls._get_all_param_aliases()
        return cls.aliases.get(name, [name])

    @classmethod
    def get_by_alias(cls, *args: str) -> Set[str]:
        if cls.aliases is None:
            cls.aliases = cls._get_all_param_aliases()
        ret = set(args)
        for arg in args:
            for aliases in cls.aliases.values():
                if arg in aliases:
                    ret.update(aliases)
                    break
        return ret


def _choose_param_value(main_param_name: str, params: Dict[str, Any], default_value: Any) -> Dict[str, Any]:
    """Get a single parameter value, accounting for aliases.

    Parameters
    ----------
    main_param_name : str
        Name of the main parameter to get a value for. One of the keys of ``_ConfigAliases``.
    params : dict
        Dictionary of LightGBM parameters.
    default_value : Any
        Default value to use for the parameter, if none is found in ``params``.

    Returns
    -------
    params : dict
        A ``params`` dict with exactly one value for ``main_param_name``, and all aliases ``main_param_name`` removed.
        If both ``main_param_name`` and one or more aliases for it are found, the value of ``main_param_name`` will be preferred.
    """
    # avoid side effects on passed-in parameters
    params = deepcopy(params)

    aliases = _ConfigAliases.get_sorted(main_param_name)
    aliases = [a for a in aliases if a != main_param_name]

    # if main_param_name was provided, keep that value and remove all aliases
    if main_param_name in params.keys():
        for param in aliases:
            params.pop(param, None)
        return params

    # if main param name was not found, search for an alias
    for param in aliases:
        if param in params.keys():
            params[main_param_name] = params[param]
            break

    if main_param_name in params.keys():
        for param in aliases:
            params.pop(param, None)
        return params

    # neither main_param_name nor any of its aliases were found
    params[main_param_name] = default_value

    return params
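
# A short sketch of the precedence rules (again assuming "eta" is an alias of
# "learning_rate"): the main parameter name wins, then aliases, then the default.
#
#     >>> _choose_param_value("learning_rate", {"eta": 0.05}, default_value=0.1)
#     {'learning_rate': 0.05}
#     >>> _choose_param_value("learning_rate", {}, default_value=0.1)
#     {'learning_rate': 0.1}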


_MAX_INT32 = (1 << 31) - 1

"""Macro definition of data type in C API of LightGBM"""
_C_API_DTYPE_FLOAT32 = 0
_C_API_DTYPE_FLOAT64 = 1
_C_API_DTYPE_INT32 = 2
_C_API_DTYPE_INT64 = 3

"""Matrix is row major in Python"""
_C_API_IS_ROW_MAJOR = 1

"""Macro definition of prediction type in C API of LightGBM"""
_C_API_PREDICT_NORMAL = 0
_C_API_PREDICT_RAW_SCORE = 1
_C_API_PREDICT_LEAF_INDEX = 2
_C_API_PREDICT_CONTRIB = 3

"""Macro definition of sparse matrix type"""
_C_API_MATRIX_TYPE_CSR = 0
_C_API_MATRIX_TYPE_CSC = 1

"""Macro definition of feature importance type"""
_C_API_FEATURE_IMPORTANCE_SPLIT = 0
_C_API_FEATURE_IMPORTANCE_GAIN = 1

"""Data type of data field"""
_FIELD_TYPE_MAPPER = {
    "label": _C_API_DTYPE_FLOAT32,
    "weight": _C_API_DTYPE_FLOAT32,
    "init_score": _C_API_DTYPE_FLOAT64,
    "group": _C_API_DTYPE_INT32,
    "position": _C_API_DTYPE_INT32,
}

"""String name to int feature importance type mapper"""
_FEATURE_IMPORTANCE_TYPE_MAPPER = {
    "split": _C_API_FEATURE_IMPORTANCE_SPLIT,
    "gain": _C_API_FEATURE_IMPORTANCE_GAIN,
}


def _convert_from_sliced_object(data: np.ndarray) -> np.ndarray:
    """Fix the memory layout of a multi-dimensional sliced object."""
    if isinstance(data, np.ndarray) and isinstance(data.base, np.ndarray):
        if not data.flags.c_contiguous:
            _log_warning(
                "Usage of np.ndarray subset (sliced data) is not recommended "
                "because it will double the peak memory cost in LightGBM."
            )
            return np.copy(data)
    return data


def _c_float_array(data: np.ndarray) -> Tuple[_ctypes_float_ptr, int, np.ndarray]:
    """Get pointer of float numpy array / list."""
    if _is_1d_list(data):
        data = np.asarray(data)
    if _is_numpy_1d_array(data):
        data = _convert_from_sliced_object(data)
        assert data.flags.c_contiguous
        ptr_data: _ctypes_float_ptr
        if data.dtype == np.float32:
            ptr_data = data.ctypes.data_as(ctypes.POINTER(ctypes.c_float))
            type_data = _C_API_DTYPE_FLOAT32
        elif data.dtype == np.float64:
            ptr_data = data.ctypes.data_as(ctypes.POINTER(ctypes.c_double))
            type_data = _C_API_DTYPE_FLOAT64
        else:
            raise TypeError(f"Expected np.float32 or np.float64, got type({data.dtype})")
    else:
        raise TypeError(f"Unknown type({type(data).__name__})")
    return (ptr_data, type_data, data)  # also return `data` so any temporary copy is kept alive


def _c_int_array(data: np.ndarray) -> Tuple[_ctypes_int_ptr, int, np.ndarray]:
    """Get pointer of int numpy array / list."""
    if _is_1d_list(data):
        data = np.asarray(data)
    if _is_numpy_1d_array(data):
        data = _convert_from_sliced_object(data)
        assert data.flags.c_contiguous
        ptr_data: _ctypes_int_ptr
        if data.dtype == np.int32:
            ptr_data = data.ctypes.data_as(ctypes.POINTER(ctypes.c_int32))
            type_data = _C_API_DTYPE_INT32
        elif data.dtype == np.int64:
            ptr_data = data.ctypes.data_as(ctypes.POINTER(ctypes.c_int64))
            type_data = _C_API_DTYPE_INT64
        else:
            raise TypeError(f"Expected np.int32 or np.int64, got type({data.dtype})")
    else:
        raise TypeError(f"Unknown type({type(data).__name__})")
    return (ptr_data, type_data, data)  # also return `data` so any temporary copy is kept alive


def _is_allowed_numpy_dtype(dtype: type) -> bool:
    float128 = getattr(np, "float128", type(None))
    return issubclass(dtype, (np.integer, np.floating, np.bool_)) and not issubclass(dtype, (np.timedelta64, float128))


def _check_for_bad_pandas_dtypes(pandas_dtypes_series: pd_Series) -> None:
    bad_pandas_dtypes = [
        f"{column_name}: {pandas_dtype}"
        for column_name, pandas_dtype in pandas_dtypes_series.items()
        if not _is_allowed_numpy_dtype(pandas_dtype.type)
    ]
    if bad_pandas_dtypes:
        raise ValueError(
            "pandas dtypes must be int, float or bool.\n"
            f"Fields with bad pandas dtypes: {', '.join(bad_pandas_dtypes)}"
        )


def _pandas_to_numpy(
    data: pd_DataFrame,
    target_dtype: "np.typing.DTypeLike",
) -> np.ndarray:
    _check_for_bad_pandas_dtypes(data.dtypes)
    try:
        # most common case (no nullable dtypes)
        return data.to_numpy(dtype=target_dtype, copy=False)
    except TypeError:
        # 1.0 <= pd version < 1.1 and nullable dtypes, least common case
        # raises error because array is casted to type(pd.NA) and there's no na_value argument
        return data.astype(target_dtype, copy=False).values
    except ValueError:
        # data has nullable dtypes, but we can specify na_value argument and copy will be made
        return data.to_numpy(dtype=target_dtype, na_value=np.nan)


def _data_from_pandas(
    data: pd_DataFrame,
    feature_name: _LGBM_FeatureNameConfiguration,
    categorical_feature: _LGBM_CategoricalFeatureConfiguration,
    pandas_categorical: Optional[List[List]],
) -> Tuple[np.ndarray, List[str], Union[List[str], List[int]], List[List]]:
    if len(data.shape) != 2 or data.shape[0] < 1:
        raise ValueError("Input data must be 2-dimensional and non-empty.")

    # take shallow copy in case we modify categorical columns
    # whole column modifications don't change the original df
    data = data.copy(deep=False)

    # determine feature names
    if feature_name == "auto":
        feature_name = [str(col) for col in data.columns]

    # determine categorical features
    cat_cols = [col for col, dtype in zip(data.columns, data.dtypes) if isinstance(dtype, pd_CategoricalDtype)]
    cat_cols_not_ordered: List[str] = [col for col in cat_cols if not data[col].cat.ordered]
    if pandas_categorical is None:  # train dataset
        pandas_categorical = [list(data[col].cat.categories) for col in cat_cols]
    else:
        if len(cat_cols) != len(pandas_categorical):
            raise ValueError("train and valid dataset categorical_feature do not match.")
        for col, category in zip(cat_cols, pandas_categorical):
            if list(data[col].cat.categories) != list(category):
                data[col] = data[col].cat.set_categories(category)
    if len(cat_cols):  # cat_cols is list
        data[cat_cols] = data[cat_cols].apply(lambda x: x.cat.codes).replace({-1: np.nan})

    # use cat cols from DataFrame
    if categorical_feature == "auto":
        categorical_feature = cat_cols_not_ordered

    df_dtypes = [dtype.type for dtype in data.dtypes]
    # so that the target dtype considers floats
    df_dtypes.append(np.float32)
    target_dtype = np.result_type(*df_dtypes)

    return (
        _pandas_to_numpy(data, target_dtype=target_dtype),
        feature_name,
        categorical_feature,
        pandas_categorical,
    )


def _dump_pandas_categorical(
    pandas_categorical: Optional[List[List]],
    file_name: Optional[Union[str, Path]] = None,
) -> str:
    categorical_json = json.dumps(pandas_categorical, default=_json_default_with_numpy)
    pandas_str = f"\npandas_categorical:{categorical_json}\n"
    if file_name is not None:
        with open(file_name, "a") as f:
            f.write(pandas_str)
    return pandas_str


def _load_pandas_categorical(
    file_name: Optional[Union[str, Path]] = None,
    model_str: Optional[str] = None,
) -> Optional[List[List]]:
    pandas_key = "pandas_categorical:"
    offset = -len(pandas_key)
    if file_name is not None:
        max_offset = -getsize(file_name)
        with open(file_name, "rb") as f:
            while True:
                if offset < max_offset:
                    offset = max_offset
                f.seek(offset, SEEK_END)
                lines = f.readlines()
                if len(lines) >= 2:
                    break
                offset *= 2
        last_line = lines[-1].decode("utf-8").strip()
        if not last_line.startswith(pandas_key):
            last_line = lines[-2].decode("utf-8").strip()
    elif model_str is not None:
        idx = model_str.rfind("\n", 0, offset)
        last_line = model_str[idx:].strip()
    if last_line.startswith(pandas_key):
        return json.loads(last_line[len(pandas_key) :])
    else:
        return None


class Sequence(abc.ABC):
    """
    Generic data access interface.

    Object should support the following operations:

    .. code-block::

        # Get total row number.
        >>> len(seq)
        # Random access by row index. Used for data sampling.
        >>> seq[10]
        # Range data access. Used to read data in batch when constructing Dataset.
        >>> seq[0:100]
        # Optionally specify batch_size to control range data read size.
        >>> seq.batch_size

    - With random access, **data sampling does not need to go through all data**.
    - With range data access, there's **no need to read all data into memory, which reduces memory usage**.

    .. versionadded:: 3.3.0

    Attributes
    ----------
    batch_size : int
        Default size of a batch.
    """

    batch_size = 4096  # Defaults to read 4K rows in each batch.

    @abc.abstractmethod
    def __getitem__(self, idx: Union[int, slice, List[int]]) -> np.ndarray:
        """Return data for given row index.

        A basic implementation should look like this:

        .. code-block:: python

            if isinstance(idx, numbers.Integral):
                return self._get_one_line(idx)
            elif isinstance(idx, slice):
                return np.stack([self._get_one_line(i) for i in range(idx.start, idx.stop)])
            elif isinstance(idx, list):
                # Only required if using ``Dataset.subset()``.
                return np.array([self._get_one_line(i) for i in idx])
            else:
                raise TypeError(f"Sequence index must be integer, slice or list, got {type(idx).__name__}")

        Parameters
        ----------
        idx : int, slice[int], list[int]
            Item index.

        Returns
        -------
        result : numpy 1-D array or numpy 2-D array
            1-D array if idx is int, 2-D array if idx is slice or list.
        """
        raise NotImplementedError("Sub-classes of lightgbm.Sequence must implement __getitem__()")

    @abc.abstractmethod
    def __len__(self) -> int:
        """Return row count of this sequence."""
        raise NotImplementedError("Sub-classes of lightgbm.Sequence must implement __len__()")
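
# A minimal concrete Sequence backed by an in-memory numpy array (illustrative
# sketch only; real implementations typically wrap out-of-core storage):
#
#     >>> class NumpySequence(Sequence):
#     ...     def __init__(self, data, batch_size=4096):
#     ...         self.data = data
#     ...         self.batch_size = batch_size
#     ...     def __getitem__(self, idx):
#     ...         return self.data[idx]
#     ...     def __len__(self):
#     ...         return len(self.data)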


class _InnerPredictor:
    """_InnerPredictor of LightGBM.

    Not exposed to user.
    Used only for prediction, usually used for continued training.

    .. note::

        Can be converted from Booster, but cannot be converted to Booster.
    """

    def __init__(
        self,
        booster_handle: _BoosterHandle,
        pandas_categorical: Optional[List[List]],
        pred_parameter: Dict[str, Any],
        manage_handle: bool,
    ):
        """Initialize the _InnerPredictor.

        Parameters
        ----------
        booster_handle : object
            Handle of Booster.
        pandas_categorical : list of list, or None
            If provided, list of categories for ``pandas`` categorical columns.
            Where the ``i``th element of the list contains the categories for the ``i``th categorical feature.
        pred_parameter : dict
            Other parameters for the prediction.
        manage_handle : bool
            If ``True``, free the corresponding Booster on the C++ side when this Python object is deleted.
        """
        self._handle = booster_handle
        self.__is_manage_handle = manage_handle
        self.pandas_categorical = pandas_categorical
        self.pred_parameter = _param_dict_to_str(pred_parameter)

        out_num_class = ctypes.c_int(0)
        _safe_call(
            _LIB.LGBM_BoosterGetNumClasses(
                self._handle,
                ctypes.byref(out_num_class),
            )
        )
        self.num_class = out_num_class.value

    @classmethod
    def from_booster(
        cls,
        booster: "Booster",
        pred_parameter: Dict[str, Any],
    ) -> "_InnerPredictor":
        """Initialize an ``_InnerPredictor`` from a ``Booster``.

        Parameters
        ----------
        booster : Booster
            Booster.
        pred_parameter : dict
            Other parameters for the prediction.
        """
        out_cur_iter = ctypes.c_int(0)
        _safe_call(
            _LIB.LGBM_BoosterGetCurrentIteration(
                booster._handle,
                ctypes.byref(out_cur_iter),
            )
        )
        return cls(
            booster_handle=booster._handle,
            pandas_categorical=booster.pandas_categorical,
            pred_parameter=pred_parameter,
            manage_handle=False,
        )

    @classmethod
    def from_model_file(
        cls,
        model_file: Union[str, Path],
        pred_parameter: Dict[str, Any],
    ) -> "_InnerPredictor":
        """Initialize an ``_InnerPredictor`` from a text file containing a LightGBM model.

        Parameters
        ----------
        model_file : str or pathlib.Path
            Path to the model file.
        pred_parameter : dict
            Other parameters for the prediction.
        """
        booster_handle = ctypes.c_void_p()
        out_num_iterations = ctypes.c_int(0)
        _safe_call(
            _LIB.LGBM_BoosterCreateFromModelfile(
                _c_str(str(model_file)),
                ctypes.byref(out_num_iterations),
                ctypes.byref(booster_handle),
            )
        )
        return cls(
            booster_handle=booster_handle,
            pandas_categorical=_load_pandas_categorical(file_name=model_file),
            pred_parameter=pred_parameter,
            manage_handle=True,
        )

    def __del__(self) -> None:
        try:
            if self.__is_manage_handle:
                _safe_call(_LIB.LGBM_BoosterFree(self._handle))
        except AttributeError:
            pass

    def __getstate__(self) -> Dict[str, Any]:
        this = self.__dict__.copy()
        this.pop("handle", None)
        this.pop("_handle", None)
        return this

    def predict(
        self,
        data: _LGBM_PredictDataType,
        start_iteration: int = 0,
        num_iteration: int = -1,
        raw_score: bool = False,
        pred_leaf: bool = False,
        pred_contrib: bool = False,
        data_has_header: bool = False,
        validate_features: bool = False,
    ) -> Union[np.ndarray, scipy.sparse.spmatrix, List[scipy.sparse.spmatrix]]:
        """Predict logic.

        Parameters
        ----------
        data : str, pathlib.Path, numpy array, pandas DataFrame, pyarrow Table, H2O DataTable's Frame or scipy.sparse
            Data source for prediction.
            If str or pathlib.Path, it represents the path to a text file (CSV, TSV, or LibSVM).
        start_iteration : int, optional (default=0)
            Start index of the iteration to predict.
        num_iteration : int, optional (default=-1)
            Iteration used for prediction.
        raw_score : bool, optional (default=False)
            Whether to predict raw scores.
        pred_leaf : bool, optional (default=False)
            Whether to predict leaf index.
        pred_contrib : bool, optional (default=False)
            Whether to predict feature contributions.
        data_has_header : bool, optional (default=False)
            Whether data has header.
            Used only for txt data.
        validate_features : bool, optional (default=False)
            If True, ensure that the features used to predict match the ones used to train.
            Used only if data is pandas DataFrame.

            .. versionadded:: 4.0.0

        Returns
        -------
        result : numpy array, scipy.sparse or list of scipy.sparse
            Prediction result.
            Can be sparse or a list of sparse objects (each element represents predictions for one class) for feature contributions (when ``pred_contrib=True``).
        """
        if isinstance(data, Dataset):
            raise TypeError("Cannot use Dataset instance for prediction, please use raw data instead")
        elif isinstance(data, pd_DataFrame) and validate_features:
            data_names = [str(x) for x in data.columns]
            ptr_names = (ctypes.c_char_p * len(data_names))()
            ptr_names[:] = [x.encode("utf-8") for x in data_names]
            _safe_call(
                _LIB.LGBM_BoosterValidateFeatureNames(
                    self._handle,
                    ptr_names,
                    ctypes.c_int(len(data_names)),
                )
            )

        if isinstance(data, pd_DataFrame):
            data = _data_from_pandas(
                data=data,
                feature_name="auto",
                categorical_feature="auto",
                pandas_categorical=self.pandas_categorical,
            )[0]

        predict_type = _C_API_PREDICT_NORMAL
        if raw_score:
            predict_type = _C_API_PREDICT_RAW_SCORE
        if pred_leaf:
            predict_type = _C_API_PREDICT_LEAF_INDEX
        if pred_contrib:
            predict_type = _C_API_PREDICT_CONTRIB

        if isinstance(data, (str, Path)):
            with _TempFile() as f:
                _safe_call(
                    _LIB.LGBM_BoosterPredictForFile(
                        self._handle,
                        _c_str(str(data)),
                        ctypes.c_int(data_has_header),
                        ctypes.c_int(predict_type),
                        ctypes.c_int(start_iteration),
                        ctypes.c_int(num_iteration),
                        _c_str(self.pred_parameter),
                        _c_str(f.name),
                    )
                )
                preds = np.loadtxt(f.name, dtype=np.float64)
                nrow = preds.shape[0]
        elif isinstance(data, scipy.sparse.csr_matrix):
            preds, nrow = self.__pred_for_csr(
                csr=data,
                start_iteration=start_iteration,
                num_iteration=num_iteration,
                predict_type=predict_type,
            )
        elif isinstance(data, scipy.sparse.csc_matrix):
            preds, nrow = self.__pred_for_csc(
                csc=data,
                start_iteration=start_iteration,
                num_iteration=num_iteration,
                predict_type=predict_type,
            )
        elif isinstance(data, np.ndarray):
            preds, nrow = self.__pred_for_np2d(
                mat=data,
                start_iteration=start_iteration,
                num_iteration=num_iteration,
                predict_type=predict_type,
            )
        elif _is_pyarrow_table(data):
            preds, nrow = self.__pred_for_pyarrow_table(
                table=data,
                start_iteration=start_iteration,
                num_iteration=num_iteration,
                predict_type=predict_type,
            )
        elif isinstance(data, list):
            try:
                data = np.array(data)
            except BaseException as err:
                raise ValueError("Cannot convert data list to numpy array.") from err
            preds, nrow = self.__pred_for_np2d(
                mat=data,
                start_iteration=start_iteration,
                num_iteration=num_iteration,
                predict_type=predict_type,
            )
        elif isinstance(data, dt_DataTable):
            preds, nrow = self.__pred_for_np2d(
                mat=data.to_numpy(),
                start_iteration=start_iteration,
                num_iteration=num_iteration,
                predict_type=predict_type,
            )
        else:
            try:
                _log_warning("Converting data to scipy sparse matrix.")
                csr = scipy.sparse.csr_matrix(data)
            except BaseException as err:
                raise TypeError(f"Cannot predict data for type {type(data).__name__}") from err
            preds, nrow = self.__pred_for_csr(
                csr=csr,
                start_iteration=start_iteration,
                num_iteration=num_iteration,
                predict_type=predict_type,
            )
        if pred_leaf:
            preds = preds.astype(np.int32)
        is_sparse = isinstance(preds, (list, scipy.sparse.spmatrix))
        if not is_sparse and preds.size != nrow:
            if preds.size % nrow == 0:
                preds = preds.reshape(nrow, -1)
            else:
                raise ValueError(f"Length of predict result ({preds.size}) is not divisible by nrow ({nrow})")
        return preds

    def __get_num_preds(
        self,
        start_iteration: int,
        num_iteration: int,
        nrow: int,
        predict_type: int,
    ) -> int:
        """Get size of prediction result."""
        if nrow > _MAX_INT32:
            raise LightGBMError(
                "LightGBM cannot perform prediction for data "
                f"with number of rows greater than MAX_INT32 ({_MAX_INT32}).\n"
                "You can split your data into chunks "
                "and then concatenate predictions for them"
            )
        n_preds = ctypes.c_int64(0)
        _safe_call(
            _LIB.LGBM_BoosterCalcNumPredict(
                self._handle,
                ctypes.c_int(nrow),
                ctypes.c_int(predict_type),
                ctypes.c_int(start_iteration),
                ctypes.c_int(num_iteration),
                ctypes.byref(n_preds),
            )
        )
        return n_preds.value

    def __inner_predict_np2d(
        self,
        mat: np.ndarray,
        start_iteration: int,
        num_iteration: int,
        predict_type: int,
        preds: Optional[np.ndarray],
    ) -> Tuple[np.ndarray, int]:
        if mat.dtype == np.float32 or mat.dtype == np.float64:
            data = np.asarray(mat.reshape(mat.size), dtype=mat.dtype)
        else:  # change non-float data to float data, need to copy
            data = np.array(mat.reshape(mat.size), dtype=np.float32)
        ptr_data, type_ptr_data, _ = _c_float_array(data)
        n_preds = self.__get_num_preds(
            start_iteration=start_iteration,
            num_iteration=num_iteration,
            nrow=mat.shape[0],
            predict_type=predict_type,
        )
        if preds is None:
            preds = np.empty(n_preds, dtype=np.float64)
        elif len(preds.shape) != 1 or len(preds) != n_preds:
            raise ValueError("Wrong length of pre-allocated predict array")
        out_num_preds = ctypes.c_int64(0)
        _safe_call(
            _LIB.LGBM_BoosterPredictForMat(
                self._handle,
                ptr_data,
                ctypes.c_int(type_ptr_data),
                ctypes.c_int32(mat.shape[0]),
                ctypes.c_int32(mat.shape[1]),
                ctypes.c_int(_C_API_IS_ROW_MAJOR),
                ctypes.c_int(predict_type),
                ctypes.c_int(start_iteration),
                ctypes.c_int(num_iteration),
                _c_str(self.pred_parameter),
                ctypes.byref(out_num_preds),
                preds.ctypes.data_as(ctypes.POINTER(ctypes.c_double)),
            )
        )
        if n_preds != out_num_preds.value:
            raise ValueError("Wrong length for predict results")
        return preds, mat.shape[0]

    def __pred_for_np2d(
        self,
        mat: np.ndarray,
        start_iteration: int,
        num_iteration: int,
        predict_type: int,
    ) -> Tuple[np.ndarray, int]:
        """Predict for a 2-D numpy matrix."""
        if len(mat.shape) != 2:
            raise ValueError("Input numpy.ndarray or list must be 2 dimensional")

        nrow = mat.shape[0]
        if nrow > _MAX_INT32:
            sections = np.arange(start=_MAX_INT32, stop=nrow, step=_MAX_INT32)
            # __get_num_preds() cannot work with nrow > MAX_INT32, so calculate overall number of predictions piecemeal
            n_preds = [
                self.__get_num_preds(start_iteration, num_iteration, i, predict_type)
                for i in np.diff([0] + list(sections) + [nrow])
            ]
            n_preds_sections = np.array([0] + n_preds, dtype=np.intp).cumsum()
            preds = np.empty(sum(n_preds), dtype=np.float64)
            for chunk, (start_idx_pred, end_idx_pred) in zip(
                np.array_split(mat, sections), zip(n_preds_sections, n_preds_sections[1:])
            ):
                # avoid memory consumption by arrays concatenation operations
                self.__inner_predict_np2d(
                    mat=chunk,
                    start_iteration=start_iteration,
                    num_iteration=num_iteration,
                    predict_type=predict_type,
                    preds=preds[start_idx_pred:end_idx_pred],
                )
            return preds, nrow
        else:
            return self.__inner_predict_np2d(
                mat=mat,
                start_iteration=start_iteration,
                num_iteration=num_iteration,
                predict_type=predict_type,
                preds=None,
            )

    def __create_sparse_native(
        self,
        cs: Union[scipy.sparse.csc_matrix, scipy.sparse.csr_matrix],
        out_shape: np.ndarray,
        out_ptr_indptr: "ctypes._Pointer",
        out_ptr_indices: "ctypes._Pointer",
        out_ptr_data: "ctypes._Pointer",
        indptr_type: int,
        data_type: int,
        is_csr: bool,
    ) -> Union[List[scipy.sparse.csc_matrix], List[scipy.sparse.csr_matrix]]:
        # create numpy array from output arrays
        data_indices_len = out_shape[0]
        indptr_len = out_shape[1]
        if indptr_type == _C_API_DTYPE_INT32:
            out_indptr = _cint32_array_to_numpy(cptr=out_ptr_indptr, length=indptr_len)
        elif indptr_type == _C_API_DTYPE_INT64:
            out_indptr = _cint64_array_to_numpy(cptr=out_ptr_indptr, length=indptr_len)
        else:
            raise TypeError("Expected int32 or int64 type for indptr")
        if data_type == _C_API_DTYPE_FLOAT32:
            out_data = _cfloat32_array_to_numpy(cptr=out_ptr_data, length=data_indices_len)
        elif data_type == _C_API_DTYPE_FLOAT64:
            out_data = _cfloat64_array_to_numpy(cptr=out_ptr_data, length=data_indices_len)
        else:
            raise TypeError("Expected float32 or float64 type for data")
        out_indices = _cint32_array_to_numpy(cptr=out_ptr_indices, length=data_indices_len)
        # break up indptr based on number of rows (note more than one matrix in multiclass case)
        per_class_indptr_shape = cs.indptr.shape[0]
        # for CSC there is extra column added
        if not is_csr:
            per_class_indptr_shape += 1
        out_indptr_arrays = np.split(out_indptr, out_indptr.shape[0] / per_class_indptr_shape)
        # reformat output into a csr or csc matrix or list of csr or csc matrices
        cs_output_matrices = []
        offset = 0
        for cs_indptr in out_indptr_arrays:
            matrix_indptr_len = cs_indptr[cs_indptr.shape[0] - 1]
            cs_indices = out_indices[offset + cs_indptr[0] : offset + matrix_indptr_len]
            cs_data = out_data[offset + cs_indptr[0] : offset + matrix_indptr_len]
            offset += matrix_indptr_len
            # same shape as input csr or csc matrix except extra column for expected value
            cs_shape = [cs.shape[0], cs.shape[1] + 1]
            # note: make sure we copy data as it will be deallocated next
            if is_csr:
                cs_output_matrices.append(scipy.sparse.csr_matrix((cs_data, cs_indices, cs_indptr), cs_shape))
            else:
                cs_output_matrices.append(scipy.sparse.csc_matrix((cs_data, cs_indices, cs_indptr), cs_shape))
        # free the temporary native indptr, indices, and data
1394
1395
1396
1397
1398
1399
1400
1401
1402
        _safe_call(
            _LIB.LGBM_BoosterFreePredictSparse(
                out_ptr_indptr,
                out_ptr_indices,
                out_ptr_data,
                ctypes.c_int(indptr_type),
                ctypes.c_int(data_type),
            )
        )
1403
1404
1405
1406
        if len(cs_output_matrices) == 1:
            return cs_output_matrices[0]
        return cs_output_matrices
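
    # Illustrative sketch (comment only, not part of the library API): for a
    # multiclass contribution output, the native side returns one concatenated
    # indptr, and np.split() above recovers the per-class arrays. With 2 input
    # rows and 3 classes (hypothetical values):
    #
    #     out_indptr = np.array([0, 2, 5, 0, 1, 3, 0, 2, 4])
    #     np.split(out_indptr, 3)
    #     # -> [array([0, 2, 5]), array([0, 1, 3]), array([0, 2, 4])]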

    def __inner_predict_csr(
        self,
        csr: scipy.sparse.csr_matrix,
        start_iteration: int,
        num_iteration: int,
        predict_type: int,
        preds: Optional[np.ndarray],
    ) -> Tuple[np.ndarray, int]:
        nrow = len(csr.indptr) - 1
        n_preds = self.__get_num_preds(
            start_iteration=start_iteration,
            num_iteration=num_iteration,
            nrow=nrow,
            predict_type=predict_type,
        )
        if preds is None:
            preds = np.empty(n_preds, dtype=np.float64)
        elif len(preds.shape) != 1 or len(preds) != n_preds:
            raise ValueError("Wrong length of pre-allocated predict array")
        out_num_preds = ctypes.c_int64(0)

        ptr_indptr, type_ptr_indptr, _ = _c_int_array(csr.indptr)
        ptr_data, type_ptr_data, _ = _c_float_array(csr.data)

        assert csr.shape[1] <= _MAX_INT32
        csr_indices = csr.indices.astype(np.int32, copy=False)

        _safe_call(
            _LIB.LGBM_BoosterPredictForCSR(
                self._handle,
                ptr_indptr,
                ctypes.c_int(type_ptr_indptr),
                csr_indices.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)),
                ptr_data,
                ctypes.c_int(type_ptr_data),
                ctypes.c_int64(len(csr.indptr)),
                ctypes.c_int64(len(csr.data)),
                ctypes.c_int64(csr.shape[1]),
                ctypes.c_int(predict_type),
                ctypes.c_int(start_iteration),
                ctypes.c_int(num_iteration),
                _c_str(self.pred_parameter),
                ctypes.byref(out_num_preds),
                preds.ctypes.data_as(ctypes.POINTER(ctypes.c_double)),
            )
        )
        if n_preds != out_num_preds.value:
            raise ValueError("Wrong length for predict results")
        return preds, nrow

    def __inner_predict_csr_sparse(
        self,
        csr: scipy.sparse.csr_matrix,
        start_iteration: int,
        num_iteration: int,
        predict_type: int,
    ) -> Tuple[Union[List[scipy.sparse.csc_matrix], List[scipy.sparse.csr_matrix]], int]:
        ptr_indptr, type_ptr_indptr, __ = _c_int_array(csr.indptr)
        ptr_data, type_ptr_data, _ = _c_float_array(csr.data)
        csr_indices = csr.indices.astype(np.int32, copy=False)
        matrix_type = _C_API_MATRIX_TYPE_CSR
        out_ptr_indptr: _ctypes_int_ptr
        if type_ptr_indptr == _C_API_DTYPE_INT32:
            out_ptr_indptr = ctypes.POINTER(ctypes.c_int32)()
        else:
            out_ptr_indptr = ctypes.POINTER(ctypes.c_int64)()
        out_ptr_indices = ctypes.POINTER(ctypes.c_int32)()
        out_ptr_data: _ctypes_float_ptr
        if type_ptr_data == _C_API_DTYPE_FLOAT32:
            out_ptr_data = ctypes.POINTER(ctypes.c_float)()
        else:
            out_ptr_data = ctypes.POINTER(ctypes.c_double)()
        out_shape = np.empty(2, dtype=np.int64)
        _safe_call(
            _LIB.LGBM_BoosterPredictSparseOutput(
                self._handle,
                ptr_indptr,
                ctypes.c_int(type_ptr_indptr),
                csr_indices.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)),
                ptr_data,
                ctypes.c_int(type_ptr_data),
                ctypes.c_int64(len(csr.indptr)),
                ctypes.c_int64(len(csr.data)),
                ctypes.c_int64(csr.shape[1]),
                ctypes.c_int(predict_type),
                ctypes.c_int(start_iteration),
                ctypes.c_int(num_iteration),
                _c_str(self.pred_parameter),
                ctypes.c_int(matrix_type),
                out_shape.ctypes.data_as(ctypes.POINTER(ctypes.c_int64)),
                ctypes.byref(out_ptr_indptr),
                ctypes.byref(out_ptr_indices),
                ctypes.byref(out_ptr_data),
            )
        )
        matrices = self.__create_sparse_native(
            cs=csr,
            out_shape=out_shape,
            out_ptr_indptr=out_ptr_indptr,
            out_ptr_indices=out_ptr_indices,
            out_ptr_data=out_ptr_data,
            indptr_type=type_ptr_indptr,
            data_type=type_ptr_data,
            is_csr=True,
        )
        nrow = len(csr.indptr) - 1
        return matrices, nrow

    def __pred_for_csr(
        self,
        csr: scipy.sparse.csr_matrix,
        start_iteration: int,
        num_iteration: int,
        predict_type: int,
    ) -> Tuple[np.ndarray, int]:
        """Predict for CSR data."""
        if predict_type == _C_API_PREDICT_CONTRIB:
            return self.__inner_predict_csr_sparse(
                csr=csr,
                start_iteration=start_iteration,
                num_iteration=num_iteration,
                predict_type=predict_type,
            )
        nrow = len(csr.indptr) - 1
        if nrow > _MAX_INT32:
            sections = [0] + list(np.arange(start=_MAX_INT32, stop=nrow, step=_MAX_INT32)) + [nrow]
            # __get_num_preds() cannot work with nrow > MAX_INT32, so calculate overall number of predictions piecemeal
            n_preds = [self.__get_num_preds(start_iteration, num_iteration, i, predict_type) for i in np.diff(sections)]
            n_preds_sections = np.array([0] + n_preds, dtype=np.intp).cumsum()
            preds = np.empty(sum(n_preds), dtype=np.float64)
            for (start_idx, end_idx), (start_idx_pred, end_idx_pred) in zip(
                zip(sections, sections[1:]), zip(n_preds_sections, n_preds_sections[1:])
            ):
                # avoid memory consumption by arrays concatenation operations
                self.__inner_predict_csr(
                    csr=csr[start_idx:end_idx],
                    start_iteration=start_iteration,
                    num_iteration=num_iteration,
                    predict_type=predict_type,
                    preds=preds[start_idx_pred:end_idx_pred],
                )
            return preds, nrow
        else:
            return self.__inner_predict_csr(
                csr=csr,
                start_iteration=start_iteration,
                num_iteration=num_iteration,
                predict_type=predict_type,
                preds=None,
            )
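
    # Illustrative sketch of the row-chunking above (comment only; uses a
    # hypothetical small limit instead of _MAX_INT32): for limit=3 and nrow=7,
    #
    #     sections = [0] + list(np.arange(start=3, stop=7, step=3)) + [7]   # [0, 3, 6, 7]
    #     np.diff(sections)                                                 # [3, 3, 1] rows per chunk
    #
    # so each chunk stays within the int32-indexable range and its predictions
    # are written into a pre-allocated slice of ``preds``.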

    def __inner_predict_sparse_csc(
        self,
        csc: scipy.sparse.csc_matrix,
        start_iteration: int,
        num_iteration: int,
        predict_type: int,
    ) -> Tuple[Union[List[scipy.sparse.csc_matrix], List[scipy.sparse.csr_matrix]], int]:
        ptr_indptr, type_ptr_indptr, __ = _c_int_array(csc.indptr)
        ptr_data, type_ptr_data, _ = _c_float_array(csc.data)
        csc_indices = csc.indices.astype(np.int32, copy=False)
        matrix_type = _C_API_MATRIX_TYPE_CSC
        out_ptr_indptr: _ctypes_int_ptr
        if type_ptr_indptr == _C_API_DTYPE_INT32:
            out_ptr_indptr = ctypes.POINTER(ctypes.c_int32)()
        else:
            out_ptr_indptr = ctypes.POINTER(ctypes.c_int64)()
        out_ptr_indices = ctypes.POINTER(ctypes.c_int32)()
        out_ptr_data: _ctypes_float_ptr
        if type_ptr_data == _C_API_DTYPE_FLOAT32:
            out_ptr_data = ctypes.POINTER(ctypes.c_float)()
        else:
            out_ptr_data = ctypes.POINTER(ctypes.c_double)()
        out_shape = np.empty(2, dtype=np.int64)
        _safe_call(
            _LIB.LGBM_BoosterPredictSparseOutput(
                self._handle,
                ptr_indptr,
                ctypes.c_int(type_ptr_indptr),
                csc_indices.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)),
                ptr_data,
                ctypes.c_int(type_ptr_data),
                ctypes.c_int64(len(csc.indptr)),
                ctypes.c_int64(len(csc.data)),
                ctypes.c_int64(csc.shape[0]),
                ctypes.c_int(predict_type),
                ctypes.c_int(start_iteration),
                ctypes.c_int(num_iteration),
                _c_str(self.pred_parameter),
                ctypes.c_int(matrix_type),
                out_shape.ctypes.data_as(ctypes.POINTER(ctypes.c_int64)),
                ctypes.byref(out_ptr_indptr),
                ctypes.byref(out_ptr_indices),
                ctypes.byref(out_ptr_data),
            )
        )
        matrices = self.__create_sparse_native(
            cs=csc,
            out_shape=out_shape,
            out_ptr_indptr=out_ptr_indptr,
            out_ptr_indices=out_ptr_indices,
            out_ptr_data=out_ptr_data,
            indptr_type=type_ptr_indptr,
            data_type=type_ptr_data,
            is_csr=False,
        )
        nrow = csc.shape[0]
        return matrices, nrow

    def __pred_for_csc(
        self,
        csc: scipy.sparse.csc_matrix,
        start_iteration: int,
        num_iteration: int,
        predict_type: int,
    ) -> Tuple[np.ndarray, int]:
        """Predict for CSC data."""
        nrow = csc.shape[0]
        if nrow > _MAX_INT32:
            return self.__pred_for_csr(
                csr=csc.tocsr(),
                start_iteration=start_iteration,
                num_iteration=num_iteration,
                predict_type=predict_type,
            )
        if predict_type == _C_API_PREDICT_CONTRIB:
            return self.__inner_predict_sparse_csc(
                csc=csc,
                start_iteration=start_iteration,
                num_iteration=num_iteration,
                predict_type=predict_type,
            )
        n_preds = self.__get_num_preds(
            start_iteration=start_iteration,
            num_iteration=num_iteration,
            nrow=nrow,
            predict_type=predict_type,
        )
        preds = np.empty(n_preds, dtype=np.float64)
        out_num_preds = ctypes.c_int64(0)

        ptr_indptr, type_ptr_indptr, __ = _c_int_array(csc.indptr)
        ptr_data, type_ptr_data, _ = _c_float_array(csc.data)

        assert csc.shape[0] <= _MAX_INT32
        csc_indices = csc.indices.astype(np.int32, copy=False)

        _safe_call(
            _LIB.LGBM_BoosterPredictForCSC(
                self._handle,
                ptr_indptr,
                ctypes.c_int(type_ptr_indptr),
                csc_indices.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)),
                ptr_data,
                ctypes.c_int(type_ptr_data),
                ctypes.c_int64(len(csc.indptr)),
                ctypes.c_int64(len(csc.data)),
                ctypes.c_int64(csc.shape[0]),
                ctypes.c_int(predict_type),
                ctypes.c_int(start_iteration),
                ctypes.c_int(num_iteration),
                _c_str(self.pred_parameter),
                ctypes.byref(out_num_preds),
                preds.ctypes.data_as(ctypes.POINTER(ctypes.c_double)),
            )
        )
        if n_preds != out_num_preds.value:
            raise ValueError("Wrong length for predict results")
        return preds, nrow

    def __pred_for_pyarrow_table(
        self,
        table: pa_Table,
        start_iteration: int,
        num_iteration: int,
        predict_type: int,
    ) -> Tuple[np.ndarray, int]:
        """Predict for a PyArrow table."""
        if not PYARROW_INSTALLED:
            raise LightGBMError("Cannot predict from Arrow without `pyarrow` installed.")

        # Check that the input is valid: we only handle numbers (for now)
        if not all(arrow_is_integer(t) or arrow_is_floating(t) or arrow_is_boolean(t) for t in table.schema.types):
            raise ValueError("Arrow table may only have integer, floating point, or boolean datatypes")

        # Prepare prediction output array
        n_preds = self.__get_num_preds(
            start_iteration=start_iteration,
            num_iteration=num_iteration,
            nrow=table.num_rows,
            predict_type=predict_type,
        )
        preds = np.empty(n_preds, dtype=np.float64)
        out_num_preds = ctypes.c_int64(0)

        # Export Arrow table to C and run prediction
        c_array = _export_arrow_to_c(table)
        _safe_call(
            _LIB.LGBM_BoosterPredictForArrow(
                self._handle,
                ctypes.c_int64(c_array.n_chunks),
                ctypes.c_void_p(c_array.chunks_ptr),
                ctypes.c_void_p(c_array.schema_ptr),
                ctypes.c_int(predict_type),
                ctypes.c_int(start_iteration),
                ctypes.c_int(num_iteration),
                _c_str(self.pred_parameter),
                ctypes.byref(out_num_preds),
                preds.ctypes.data_as(ctypes.POINTER(ctypes.c_double)),
            )
        )
        if n_preds != out_num_preds.value:
            raise ValueError("Wrong length for predict results")
        return preds, table.num_rows

    def current_iteration(self) -> int:
        """Get the index of the current iteration.

        Returns
        -------
        cur_iter : int
            The index of the current iteration.
        """
        out_cur_iter = ctypes.c_int(0)
        _safe_call(
            _LIB.LGBM_BoosterGetCurrentIteration(
                self._handle,
                ctypes.byref(out_cur_iter),
            )
        )
        return out_cur_iter.value
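
    # Illustrative sketch of the ctypes out-parameter pattern used above
    # (comment only; ``some_c_function`` is a hypothetical stand-in):
    #
    #     out = ctypes.c_int(0)
    #     some_c_function(ctypes.byref(out))   # the C side writes through the pointer
    #     value = out.value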


class Dataset:
    """
    Dataset in LightGBM.

    LightGBM does not train on raw data.
    It discretizes continuous features into histogram bins, tries to combine categorical features,
    and automatically handles missing and infinite values.

    This class handles that preprocessing, and holds that alternative representation of the input data.
    """

    def __init__(
        self,
        data: _LGBM_TrainDataType,
        label: Optional[_LGBM_LabelType] = None,
        reference: Optional["Dataset"] = None,
        weight: Optional[_LGBM_WeightType] = None,
        group: Optional[_LGBM_GroupType] = None,
        init_score: Optional[_LGBM_InitScoreType] = None,
        feature_name: _LGBM_FeatureNameConfiguration = "auto",
        categorical_feature: _LGBM_CategoricalFeatureConfiguration = "auto",
        params: Optional[Dict[str, Any]] = None,
        free_raw_data: bool = True,
        position: Optional[_LGBM_PositionType] = None,
    ):
        """Initialize Dataset.

        Parameters
        ----------
        data : str, pathlib.Path, numpy array, pandas DataFrame, H2O DataTable's Frame, scipy.sparse, Sequence, list of Sequence, list of numpy array or pyarrow Table
            Data source of Dataset.
            If str or pathlib.Path, it represents the path to a text file (CSV, TSV, or LibSVM) or a LightGBM Dataset binary file.
        label : list, numpy 1-D array, pandas Series / one-column DataFrame, pyarrow Array, pyarrow ChunkedArray or None, optional (default=None)
            Label of the data.
        reference : Dataset or None, optional (default=None)
            If this is a Dataset for validation, training data should be used as reference.
        weight : list, numpy 1-D array, pandas Series, pyarrow Array, pyarrow ChunkedArray or None, optional (default=None)
            Weight for each instance. Weights should be non-negative.
        group : list, numpy 1-D array, pandas Series, pyarrow Array, pyarrow ChunkedArray or None, optional (default=None)
            Group/query data.
            Only used in the learning-to-rank task.
            sum(group) = n_samples.
            For example, if you have a 100-document dataset with ``group = [10, 20, 40, 10, 10, 10]``, that means that you have 6 groups,
            where the first 10 records are in the first group, records 11-30 are in the second group, records 31-70 are in the third group, etc.
        init_score : list, list of lists (for multi-class task), numpy array, pandas Series, pandas DataFrame (for multi-class task), pyarrow Array, pyarrow ChunkedArray, pyarrow Table (for multi-class task) or None, optional (default=None)
            Init score for Dataset.
        feature_name : list of str, or 'auto', optional (default="auto")
            Feature names.
            If 'auto' and data is pandas DataFrame or pyarrow Table, data column names are used.
        categorical_feature : list of str or int, or 'auto', optional (default="auto")
            Categorical features.
            If list of int, interpreted as indices.
            If list of str, interpreted as feature names (need to specify ``feature_name`` as well).
            If 'auto' and data is pandas DataFrame, pandas unordered categorical columns are used.
            All values in categorical features will be cast to int32 and thus should be less than int32 max value (2147483647).
            Large values could be memory consuming. Consider using consecutive integers starting from zero.
            All negative values in categorical features will be treated as missing values.
            The output cannot be monotonically constrained with respect to a categorical feature.
            Floating point numbers in categorical features will be rounded towards 0.
        params : dict or None, optional (default=None)
            Other parameters for Dataset.
        free_raw_data : bool, optional (default=True)
            If True, raw data is freed after constructing inner Dataset.
        position : numpy 1-D array, pandas Series or None, optional (default=None)
            Position of items used in unbiased learning-to-rank task.
        """
        self._handle: Optional[_DatasetHandle] = None
        self.data = data
        self.label = label
        self.reference = reference
        self.weight = weight
        self.group = group
        self.position = position
        self.init_score = init_score
        self.feature_name: _LGBM_FeatureNameConfiguration = feature_name
        self.categorical_feature: _LGBM_CategoricalFeatureConfiguration = categorical_feature
        self.params = deepcopy(params)
        self.free_raw_data = free_raw_data
        self.used_indices: Optional[List[int]] = None
        self._need_slice = True
        self._predictor: Optional[_InnerPredictor] = None
        self.pandas_categorical: Optional[List[List]] = None
        self._params_back_up: Optional[Dict[str, Any]] = None
        self.version = 0
        self._start_row = 0  # Used when pushing rows one by one.

    def __del__(self) -> None:
        try:
            self._free_handle()
        except AttributeError:
            pass

    def _create_sample_indices(self, total_nrow: int) -> np.ndarray:
        """Get an array of randomly chosen indices from this ``Dataset``.

        Indices are sampled without replacement.

        Parameters
        ----------
        total_nrow : int
            Total number of rows to sample from.
            If this value is greater than the value of parameter ``bin_construct_sample_cnt``, only ``bin_construct_sample_cnt`` indices will be used.
            If Dataset has multiple input data, this should be the sum of rows of every file.

        Returns
        -------
        indices : numpy array
            Indices for sampled data.
        """
        param_str = _param_dict_to_str(self.get_params())
        sample_cnt = _get_sample_count(total_nrow, param_str)
        indices = np.empty(sample_cnt, dtype=np.int32)
        ptr_data, _, _ = _c_int_array(indices)
        actual_sample_cnt = ctypes.c_int32(0)

        _safe_call(
            _LIB.LGBM_SampleIndices(
                ctypes.c_int32(total_nrow),
                _c_str(param_str),
                ptr_data,
                ctypes.byref(actual_sample_cnt),
            )
        )
        assert sample_cnt == actual_sample_cnt.value
        return indices
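
    # Illustrative sketch (comment only; assumes a Dataset ``d`` with default
    # parameters, where LightGBM's bin_construct_sample_cnt defaults to 200000):
    #
    #     idx = d._create_sample_indices(1_000_000)
    #     idx.shape   # -> (200000,): int32 indices sampled without replacement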

    def _init_from_ref_dataset(
        self,
        total_nrow: int,
        ref_dataset: _DatasetHandle,
    ) -> "Dataset":
        """Create dataset from a reference dataset.

        Parameters
        ----------
        total_nrow : int
            Number of rows expected to add to dataset.
        ref_dataset : object
            Handle of reference dataset to extract metadata from.

        Returns
        -------
        self : Dataset
            Constructed Dataset object.
        """
        self._handle = ctypes.c_void_p()
        _safe_call(
            _LIB.LGBM_DatasetCreateByReference(
                ref_dataset,
                ctypes.c_int64(total_nrow),
                ctypes.byref(self._handle),
            )
        )
        return self

    def _init_from_sample(
        self,
        sample_data: List[np.ndarray],
        sample_indices: List[np.ndarray],
        sample_cnt: int,
        total_nrow: int,
    ) -> "Dataset":
        """Create Dataset from sampled data structures.

        Parameters
        ----------
        sample_data : list of numpy array
            Sample data for each column.
        sample_indices : list of numpy array
            Sample data row index for each column.
        sample_cnt : int
            Number of samples.
        total_nrow : int
            Total number of rows for all input files.

        Returns
        -------
        self : Dataset
            Constructed Dataset object.
        """
        ncol = len(sample_indices)
        assert len(sample_data) == ncol, "#sample data column != #column indices"

        for i in range(ncol):
            if sample_data[i].dtype != np.double:
                raise ValueError(f"sample_data[{i}] type {sample_data[i].dtype} is not double")
            if sample_indices[i].dtype != np.int32:
                raise ValueError(f"sample_indices[{i}] type {sample_indices[i].dtype} is not int32")

        # c type: double**
        # each double* element points to start of each column of sample data.
        sample_col_ptr: _ctypes_float_array = (ctypes.POINTER(ctypes.c_double) * ncol)()
        # c type int**
        # each int* points to start of indices for each column
        indices_col_ptr: _ctypes_int_array = (ctypes.POINTER(ctypes.c_int32) * ncol)()
        for i in range(ncol):
            sample_col_ptr[i] = _c_float_array(sample_data[i])[0]
            indices_col_ptr[i] = _c_int_array(sample_indices[i])[0]

        num_per_col = np.array([len(d) for d in sample_indices], dtype=np.int32)
        num_per_col_ptr, _, _ = _c_int_array(num_per_col)

        self._handle = ctypes.c_void_p()
        params_str = _param_dict_to_str(self.get_params())
        _safe_call(
            _LIB.LGBM_DatasetCreateFromSampledColumn(
                ctypes.cast(sample_col_ptr, ctypes.POINTER(ctypes.POINTER(ctypes.c_double))),
                ctypes.cast(indices_col_ptr, ctypes.POINTER(ctypes.POINTER(ctypes.c_int32))),
                ctypes.c_int32(ncol),
                num_per_col_ptr,
                ctypes.c_int32(sample_cnt),
                ctypes.c_int32(total_nrow),
                ctypes.c_int64(total_nrow),
                _c_str(params_str),
                ctypes.byref(self._handle),
            )
        )
        return self
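
    # Illustrative sketch of the sampled-column format (comment only,
    # hypothetical values): column i keeps only its non-zero sampled values
    # plus the sampled rows they came from.
    #
    #     sample_data    = [np.array([1.5, 2.0]), np.array([4.0])]                        # float64
    #     sample_indices = [np.array([0, 3], dtype=np.int32), np.array([1], dtype=np.int32)]
    #     d._init_from_sample(sample_data, sample_indices, sample_cnt=4, total_nrow=100)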

    def _push_rows(self, data: np.ndarray) -> "Dataset":
        """Add rows to Dataset.

        Parameters
        ----------
        data : numpy 1-D array
            New data to add to the Dataset.

        Returns
        -------
        self : Dataset
            Dataset object.
        """
        nrow, ncol = data.shape
        data = data.reshape(data.size)
        data_ptr, data_type, _ = _c_float_array(data)

        _safe_call(
            _LIB.LGBM_DatasetPushRows(
                self._handle,
                data_ptr,
                data_type,
                ctypes.c_int32(nrow),
                ctypes.c_int32(ncol),
                ctypes.c_int32(self._start_row),
            )
        )
        self._start_row += nrow
        return self
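
    # Illustrative sketch (comment only; assumes a Dataset ``d`` pre-sized for
    # 4 rows, e.g. via a reference or sampled construction, and float batches):
    #
    #     d._push_rows(np.array([[1.0, 2.0], [3.0, 4.0]]))   # fills rows 0-1
    #     d._push_rows(np.array([[5.0, 6.0], [7.0, 8.0]]))   # fills rows 2-3; _start_row advances to 4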

    def get_params(self) -> Dict[str, Any]:
        """Get the used parameters in the Dataset.

        Returns
        -------
        params : dict
            The used parameters in this Dataset object.
        """
        if self.params is not None:
            # no min_data, nthreads and verbose in this function
            dataset_params = _ConfigAliases.get(
                "bin_construct_sample_cnt",
                "categorical_feature",
                "data_random_seed",
                "enable_bundle",
                "feature_pre_filter",
                "forcedbins_filename",
                "group_column",
                "header",
                "ignore_column",
                "is_enable_sparse",
                "label_column",
                "linear_tree",
                "max_bin",
                "max_bin_by_feature",
                "min_data_in_bin",
                "pre_partition",
                "precise_float_parser",
                "two_round",
                "use_missing",
                "weight_column",
                "zero_as_missing",
            )
            return {k: v for k, v in self.params.items() if k in dataset_params}
        else:
            return {}
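
    # Illustrative sketch (comment only; ``X`` is hypothetical): Booster-only
    # keys are filtered out, and only Dataset-relevant keys survive.
    #
    #     d = Dataset(X, params={"max_bin": 63, "learning_rate": 0.1})
    #     d.get_params()   # -> {"max_bin": 63}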

    def _free_handle(self) -> "Dataset":
        if self._handle is not None:
            _safe_call(_LIB.LGBM_DatasetFree(self._handle))
            self._handle = None
        self._need_slice = True
        if self.used_indices is not None:
            self.data = None
        return self

    def _set_init_score_by_predictor(
        self,
        predictor: Optional[_InnerPredictor],
        data: _LGBM_TrainDataType,
        used_indices: Optional[Union[List[int], np.ndarray]],
    ) -> "Dataset":
        data_has_header = False
        if isinstance(data, (str, Path)) and self.params is not None:
            # check data has header or not
            data_has_header = any(self.params.get(alias, False) for alias in _ConfigAliases.get("header"))
        num_data = self.num_data()
        if predictor is not None:
            init_score: Union[np.ndarray, scipy.sparse.spmatrix] = predictor.predict(
                data=data,
                raw_score=True,
                data_has_header=data_has_header,
            )
            init_score = init_score.ravel()
            if used_indices is not None:
                assert not self._need_slice
                if isinstance(data, (str, Path)):
                    sub_init_score = np.empty(num_data * predictor.num_class, dtype=np.float64)
                    assert num_data == len(used_indices)
                    for i in range(len(used_indices)):
                        for j in range(predictor.num_class):
                            sub_init_score[i * predictor.num_class + j] = init_score[
                                used_indices[i] * predictor.num_class + j
                            ]
                    init_score = sub_init_score
            if predictor.num_class > 1:
                # need to regroup init_score
                new_init_score = np.empty(init_score.size, dtype=np.float64)
                for i in range(num_data):
                    for j in range(predictor.num_class):
                        new_init_score[j * num_data + i] = init_score[i * predictor.num_class + j]
                init_score = new_init_score
        elif self.init_score is not None:
            init_score = np.full_like(self.init_score, fill_value=0.0, dtype=np.float64)
        else:
            return self
        self.set_init_score(init_score)
        return self
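
    # Illustrative sketch of the multiclass regrouping above (comment only):
    # for num_data=2 rows and num_class=3, the predictor returns row-major
    # scores [r0c0, r0c1, r0c2, r1c0, r1c1, r1c2], while set_init_score()
    # expects the class-major layout [r0c0, r1c0, r0c1, r1c1, r0c2, r1c2].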

    def _lazy_init(
        self,
        data: Optional[_LGBM_TrainDataType],
        label: Optional[_LGBM_LabelType],
        reference: Optional["Dataset"],
        weight: Optional[_LGBM_WeightType],
        group: Optional[_LGBM_GroupType],
        init_score: Optional[_LGBM_InitScoreType],
        predictor: Optional[_InnerPredictor],
        feature_name: _LGBM_FeatureNameConfiguration,
        categorical_feature: _LGBM_CategoricalFeatureConfiguration,
        params: Optional[Dict[str, Any]],
        position: Optional[_LGBM_PositionType],
    ) -> "Dataset":
        if data is None:
            self._handle = None
            return self
        if reference is not None:
            self.pandas_categorical = reference.pandas_categorical
            categorical_feature = reference.categorical_feature
        if isinstance(data, pd_DataFrame):
            data, feature_name, categorical_feature, self.pandas_categorical = _data_from_pandas(
                data=data,
                feature_name=feature_name,
                categorical_feature=categorical_feature,
                pandas_categorical=self.pandas_categorical,
            )

        # process for args
        params = {} if params is None else params
        args_names = inspect.signature(self.__class__._lazy_init).parameters.keys()
        for key in params.keys():
            if key in args_names:
                _log_warning(
                    f"{key} keyword has been found in `params` and will be ignored.\n"
                    f"Please use {key} argument of the Dataset constructor to pass this parameter."
                )
        # get categorical features
        if isinstance(categorical_feature, list):
            categorical_indices = set()
            feature_dict = {}
            if isinstance(feature_name, list):
                feature_dict = {name: i for i, name in enumerate(feature_name)}
            for name in categorical_feature:
                if isinstance(name, str) and name in feature_dict:
                    categorical_indices.add(feature_dict[name])
                elif isinstance(name, int):
                    categorical_indices.add(name)
                else:
                    raise TypeError(f"Wrong type({type(name).__name__}) or unknown name({name}) in categorical_feature")
            if categorical_indices:
                for cat_alias in _ConfigAliases.get("categorical_feature"):
                    if cat_alias in params:
                        # If the params[cat_alias] is equal to categorical_indices, do not report the warning.
                        if not (isinstance(params[cat_alias], list) and set(params[cat_alias]) == categorical_indices):
                            _log_warning(f"{cat_alias} in param dict is overridden.")
                        params.pop(cat_alias, None)
                params["categorical_column"] = sorted(categorical_indices)

        params_str = _param_dict_to_str(params)
        self.params = params
        # process for reference dataset
        ref_dataset = None
        if isinstance(reference, Dataset):
            ref_dataset = reference.construct()._handle
        elif reference is not None:
            raise TypeError("Reference dataset should be None or dataset instance")
        # start construct data
        if isinstance(data, (str, Path)):
            self._handle = ctypes.c_void_p()
            _safe_call(
                _LIB.LGBM_DatasetCreateFromFile(
                    _c_str(str(data)),
                    _c_str(params_str),
                    ref_dataset,
                    ctypes.byref(self._handle),
                )
            )
        elif isinstance(data, scipy.sparse.csr_matrix):
            self.__init_from_csr(data, params_str, ref_dataset)
        elif isinstance(data, scipy.sparse.csc_matrix):
            self.__init_from_csc(data, params_str, ref_dataset)
        elif isinstance(data, np.ndarray):
            self.__init_from_np2d(data, params_str, ref_dataset)
        elif _is_pyarrow_table(data):
            self.__init_from_pyarrow_table(data, params_str, ref_dataset)
            feature_name = data.column_names
        elif isinstance(data, list) and len(data) > 0:
            if _is_list_of_numpy_arrays(data):
                self.__init_from_list_np2d(data, params_str, ref_dataset)
            elif _is_list_of_sequences(data):
                self.__init_from_seqs(data, ref_dataset)
            else:
                raise TypeError("Data list can only be of ndarray or Sequence")
        elif isinstance(data, Sequence):
            self.__init_from_seqs([data], ref_dataset)
        elif isinstance(data, dt_DataTable):
            self.__init_from_np2d(data.to_numpy(), params_str, ref_dataset)
        else:
            try:
                csr = scipy.sparse.csr_matrix(data)
                self.__init_from_csr(csr, params_str, ref_dataset)
            except BaseException as err:
                raise TypeError(f"Cannot initialize Dataset from {type(data).__name__}") from err
        if label is not None:
            self.set_label(label)
        if self.get_label() is None:
            raise ValueError("Label should not be None")
        if weight is not None:
            self.set_weight(weight)
        if group is not None:
            self.set_group(group)
        if position is not None:
            self.set_position(position)
        if isinstance(predictor, _InnerPredictor):
            if self._predictor is None and init_score is not None:
                _log_warning("The init_score will be overridden by the prediction of init_model.")
            self._set_init_score_by_predictor(predictor=predictor, data=data, used_indices=None)
        elif init_score is not None:
            self.set_init_score(init_score)
        elif predictor is not None:
            raise TypeError(f"Wrong predictor type {type(predictor).__name__}")
        # set feature names
        return self.set_feature_name(feature_name)

    @staticmethod
    def _yield_row_from_seqlist(seqs: List[Sequence], indices: Iterable[int]) -> Iterator[np.ndarray]:
        offset = 0
        seq_id = 0
        seq = seqs[seq_id]
        for row_id in indices:
            assert row_id >= offset, "sample indices are expected to be monotonic"
            while row_id >= offset + len(seq):
                offset += len(seq)
                seq_id += 1
                seq = seqs[seq_id]
            id_in_seq = row_id - offset
            row = seq[id_in_seq]
            yield row if row.flags["OWNDATA"] else row.copy()

    def __sample(self, seqs: List[Sequence], total_nrow: int) -> Tuple[List[np.ndarray], List[np.ndarray]]:
        """Sample data from seqs.

        Mimics behavior in c_api.cpp:LGBM_DatasetCreateFromMats()

        Returns
        -------
            sampled_rows, sampled_row_indices
        """
        indices = self._create_sample_indices(total_nrow)

        # Select sampled rows, transpose to column order.
        sampled = np.array(list(self._yield_row_from_seqlist(seqs, indices)))
        sampled = sampled.T

        filtered = []
        filtered_idx = []
        sampled_row_range = np.arange(len(indices), dtype=np.int32)
        for col in sampled:
            col_predicate = (np.abs(col) > ZERO_THRESHOLD) | np.isnan(col)
            filtered_col = col[col_predicate]
            filtered_row_idx = sampled_row_range[col_predicate]

            filtered.append(filtered_col)
            filtered_idx.append(filtered_row_idx)

        return filtered, filtered_idx
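
    # Illustrative sketch of the per-column filter above (comment only;
    # ZERO_THRESHOLD is this module's near-zero cutoff, 1e-35 at the time of
    # writing):
    #
    #     col = np.array([0.0, 1.5, np.nan, 1e-40])
    #     keep = (np.abs(col) > ZERO_THRESHOLD) | np.isnan(col)   # [False, True, True, False]
    #     # zeros and sub-threshold magnitudes are dropped; NaNs are kept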

    def __init_from_seqs(
        self,
        seqs: List[Sequence],
        ref_dataset: Optional[_DatasetHandle],
    ) -> "Dataset":
        """
        Initialize data from list of Sequence objects.

        Sequence: Generic Data Access Object
            Supports random access and access by batch if properly defined by user

        Data scheme uniformity are trusted, not checked
        """
        total_nrow = sum(len(seq) for seq in seqs)

        # create validation dataset from ref_dataset
        if ref_dataset is not None:
            self._init_from_ref_dataset(total_nrow, ref_dataset)
        else:
            param_str = _param_dict_to_str(self.get_params())
            sample_cnt = _get_sample_count(total_nrow, param_str)

            sample_data, col_indices = self.__sample(seqs, total_nrow)
            self._init_from_sample(sample_data, col_indices, sample_cnt, total_nrow)

        for seq in seqs:
            nrow = len(seq)
            batch_size = getattr(seq, "batch_size", None) or Sequence.batch_size
            for start in range(0, nrow, batch_size):
                end = min(start + batch_size, nrow)
                self._push_rows(seq[start:end])
        return self

    def __init_from_np2d(
        self,
        mat: np.ndarray,
        params_str: str,
        ref_dataset: Optional[_DatasetHandle],
    ) -> "Dataset":
        """Initialize data from a 2-D numpy matrix."""
        if len(mat.shape) != 2:
            raise ValueError("Input numpy.ndarray must be 2 dimensional")

        self._handle = ctypes.c_void_p()
        if mat.dtype == np.float32 or mat.dtype == np.float64:
            data = np.asarray(mat.reshape(mat.size), dtype=mat.dtype)
        else:  # change non-float data to float data, need to copy
            data = np.asarray(mat.reshape(mat.size), dtype=np.float32)

        ptr_data, type_ptr_data, _ = _c_float_array(data)
        _safe_call(
            _LIB.LGBM_DatasetCreateFromMat(
                ptr_data,
                ctypes.c_int(type_ptr_data),
                ctypes.c_int32(mat.shape[0]),
                ctypes.c_int32(mat.shape[1]),
                ctypes.c_int(_C_API_IS_ROW_MAJOR),
                _c_str(params_str),
                ref_dataset,
                ctypes.byref(self._handle),
            )
        )
        return self

    def __init_from_list_np2d(
        self,
        mats: List[np.ndarray],
        params_str: str,
        ref_dataset: Optional[_DatasetHandle],
    ) -> "Dataset":
        """Initialize data from a list of 2-D numpy matrices."""
        ncol = mats[0].shape[1]
        nrow = np.empty((len(mats),), np.int32)
        ptr_data: _ctypes_float_array
        if mats[0].dtype == np.float64:
            ptr_data = (ctypes.POINTER(ctypes.c_double) * len(mats))()
        else:
            ptr_data = (ctypes.POINTER(ctypes.c_float) * len(mats))()

        holders = []
        type_ptr_data = -1

        for i, mat in enumerate(mats):
            if len(mat.shape) != 2:
                raise ValueError("Input numpy.ndarray must be 2 dimensional")

            if mat.shape[1] != ncol:
                raise ValueError("Input arrays must have same number of columns")

            nrow[i] = mat.shape[0]

            if mat.dtype == np.float32 or mat.dtype == np.float64:
                mats[i] = np.asarray(mat.reshape(mat.size), dtype=mat.dtype)
            else:  # change non-float data to float data, need to copy
                mats[i] = np.array(mat.reshape(mat.size), dtype=np.float32)

            chunk_ptr_data, chunk_type_ptr_data, holder = _c_float_array(mats[i])
            if type_ptr_data != -1 and chunk_type_ptr_data != type_ptr_data:
                raise ValueError("Input chunks must have same type")
            ptr_data[i] = chunk_ptr_data
            type_ptr_data = chunk_type_ptr_data
            holders.append(holder)

        self._handle = ctypes.c_void_p()
        _safe_call(
            _LIB.LGBM_DatasetCreateFromMats(
                ctypes.c_int32(len(mats)),
                ctypes.cast(ptr_data, ctypes.POINTER(ctypes.POINTER(ctypes.c_double))),
                ctypes.c_int(type_ptr_data),
                nrow.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)),
                ctypes.c_int32(ncol),
                ctypes.c_int(_C_API_IS_ROW_MAJOR),
                _c_str(params_str),
                ref_dataset,
                ctypes.byref(self._handle),
            )
        )
        return self

    def __init_from_csr(
        self,
        csr: scipy.sparse.csr_matrix,
        params_str: str,
        ref_dataset: Optional[_DatasetHandle],
    ) -> "Dataset":
        """Initialize data from a CSR matrix."""
        if len(csr.indices) != len(csr.data):
            raise ValueError(f"Length mismatch: {len(csr.indices)} vs {len(csr.data)}")
        self._handle = ctypes.c_void_p()

        ptr_indptr, type_ptr_indptr, __ = _c_int_array(csr.indptr)
        ptr_data, type_ptr_data, _ = _c_float_array(csr.data)

        assert csr.shape[1] <= _MAX_INT32
        csr_indices = csr.indices.astype(np.int32, copy=False)

        _safe_call(
            _LIB.LGBM_DatasetCreateFromCSR(
                ptr_indptr,
                ctypes.c_int(type_ptr_indptr),
                csr_indices.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)),
                ptr_data,
                ctypes.c_int(type_ptr_data),
                ctypes.c_int64(len(csr.indptr)),
                ctypes.c_int64(len(csr.data)),
                ctypes.c_int64(csr.shape[1]),
                _c_str(params_str),
                ref_dataset,
                ctypes.byref(self._handle),
            )
        )
        return self

    def __init_from_csc(
        self,
        csc: scipy.sparse.csc_matrix,
        params_str: str,
        ref_dataset: Optional[_DatasetHandle],
    ) -> "Dataset":
        """Initialize data from a CSC matrix."""
        if len(csc.indices) != len(csc.data):
            raise ValueError(f"Length mismatch: {len(csc.indices)} vs {len(csc.data)}")
        self._handle = ctypes.c_void_p()

        ptr_indptr, type_ptr_indptr, __ = _c_int_array(csc.indptr)
        ptr_data, type_ptr_data, _ = _c_float_array(csc.data)

        assert csc.shape[0] <= _MAX_INT32
        csc_indices = csc.indices.astype(np.int32, copy=False)

        _safe_call(
            _LIB.LGBM_DatasetCreateFromCSC(
                ptr_indptr,
                ctypes.c_int(type_ptr_indptr),
                csc_indices.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)),
                ptr_data,
                ctypes.c_int(type_ptr_data),
                ctypes.c_int64(len(csc.indptr)),
                ctypes.c_int64(len(csc.data)),
                ctypes.c_int64(csc.shape[0]),
                _c_str(params_str),
                ref_dataset,
                ctypes.byref(self._handle),
            )
        )
        return self

    def __init_from_pyarrow_table(
        self,
        table: pa_Table,
        params_str: str,
        ref_dataset: Optional[_DatasetHandle],
    ) -> "Dataset":
        """Initialize data from a PyArrow table."""
        if not PYARROW_INSTALLED:
            raise LightGBMError("Cannot init dataframe from Arrow without `pyarrow` installed.")

        # Check that the input is valid: we only handle numbers (for now)
        if not all(arrow_is_integer(t) or arrow_is_floating(t) or arrow_is_boolean(t) for t in table.schema.types):
            raise ValueError("Arrow table may only have integer, floating point, or boolean datatypes")

        # Export Arrow table to C
        c_array = _export_arrow_to_c(table)
        self._handle = ctypes.c_void_p()
        _safe_call(
            _LIB.LGBM_DatasetCreateFromArrow(
                ctypes.c_int64(c_array.n_chunks),
                ctypes.c_void_p(c_array.chunks_ptr),
                ctypes.c_void_p(c_array.schema_ptr),
                _c_str(params_str),
                ref_dataset,
                ctypes.byref(self._handle),
            )
        )
        return self

    @staticmethod
    def _compare_params_for_warning(
        params: Dict[str, Any],
        other_params: Dict[str, Any],
        ignore_keys: Set[str],
    ) -> bool:
        """Compare two dictionaries with params ignoring some keys.

        It is used only for generating warnings.

        Parameters
        ----------
        params : dict
            One dictionary with parameters to compare.
        other_params : dict
            Another dictionary with parameters to compare.
        ignore_keys : set
            Keys that should be ignored when comparing the two dictionaries.

        Returns
        -------
        compare_result : bool
            Returns whether two dictionaries with params are equal.
        """
        for k in other_params:
            if k not in ignore_keys:
                if k not in params or params[k] != other_params[k]:
                    return False
        for k in params:
            if k not in ignore_keys:
                if k not in other_params or params[k] != other_params[k]:
                    return False
        return True
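
    # Illustrative sketch (comment only):
    #
    #     Dataset._compare_params_for_warning(
    #         params={"max_bin": 255, "categorical_column": [0]},
    #         other_params={"max_bin": 255},
    #         ignore_keys={"categorical_column"},
    #     )   # -> True: the dictionaries differ only in an ignored key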

    def construct(self) -> "Dataset":
        """Lazy init.

        Returns
        -------
        self : Dataset
            Constructed Dataset object.
        """
        if self._handle is None:
            if self.reference is not None:
                reference_params = self.reference.get_params()
                params = self.get_params()
                if params != reference_params:
                    if not self._compare_params_for_warning(
                        params=params,
                        other_params=reference_params,
                        ignore_keys=_ConfigAliases.get("categorical_feature"),
                    ):
                        _log_warning("Overriding the parameters from Reference Dataset.")
                    self._update_params(reference_params)
                if self.used_indices is None:
                    # create valid
                    self._lazy_init(
                        data=self.data,
                        label=self.label,
                        reference=self.reference,
                        weight=self.weight,
                        group=self.group,
                        position=self.position,
                        init_score=self.init_score,
                        predictor=self._predictor,
                        feature_name=self.feature_name,
                        categorical_feature="auto",
                        params=self.params,
                    )
                else:
                    # construct subset
                    used_indices = _list_to_1d_numpy(self.used_indices, dtype=np.int32, name="used_indices")
                    assert used_indices.flags.c_contiguous
                    if self.reference.group is not None:
                        group_info = np.array(self.reference.group).astype(np.int32, copy=False)
                        _, self.group = np.unique(
                            np.repeat(range(len(group_info)), repeats=group_info)[self.used_indices], return_counts=True
                        )
                    self._handle = ctypes.c_void_p()
                    params_str = _param_dict_to_str(self.params)
                    _safe_call(
                        _LIB.LGBM_DatasetGetSubset(
                            self.reference.construct()._handle,
                            used_indices.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)),
                            ctypes.c_int32(used_indices.shape[0]),
                            _c_str(params_str),
                            ctypes.byref(self._handle),
                        )
                    )
                    if not self.free_raw_data:
                        self.get_data()
                    if self.group is not None:
                        self.set_group(self.group)
                    if self.position is not None:
                        self.set_position(self.position)
                    if self.get_label() is None:
                        raise ValueError("Label should not be None.")
                    if (
                        isinstance(self._predictor, _InnerPredictor)
                        and self._predictor is not self.reference._predictor
                    ):
                        self.get_data()
                        self._set_init_score_by_predictor(
                            predictor=self._predictor, data=self.data, used_indices=used_indices
                        )
            else:
                # create train
                self._lazy_init(
                    data=self.data,
                    label=self.label,
                    reference=None,
                    weight=self.weight,
                    group=self.group,
                    init_score=self.init_score,
                    predictor=self._predictor,
                    feature_name=self.feature_name,
                    categorical_feature=self.categorical_feature,
                    params=self.params,
                    position=self.position,
                )
            if self.free_raw_data:
                self.data = None
            self.feature_name = self.get_feature_name()
        return self
wxchan's avatar
wxchan committed
2586

    def create_valid(
        self,
        data: _LGBM_TrainDataType,
        label: Optional[_LGBM_LabelType] = None,
        weight: Optional[_LGBM_WeightType] = None,
        group: Optional[_LGBM_GroupType] = None,
        init_score: Optional[_LGBM_InitScoreType] = None,
        params: Optional[Dict[str, Any]] = None,
        position: Optional[_LGBM_PositionType] = None,
    ) -> "Dataset":
        """Create validation data aligned with the current Dataset.

        Parameters
        ----------
        data : str, pathlib.Path, numpy array, pandas DataFrame, H2O DataTable's Frame, scipy.sparse, Sequence, list of Sequence or list of numpy array
            Data source of Dataset.
            If str or pathlib.Path, it represents the path to a text file (CSV, TSV, or LibSVM) or a LightGBM Dataset binary file.
        label : list, numpy 1-D array, pandas Series / one-column DataFrame, pyarrow Array, pyarrow ChunkedArray or None, optional (default=None)
            Label of the data.
        weight : list, numpy 1-D array, pandas Series, pyarrow Array, pyarrow ChunkedArray or None, optional (default=None)
            Weight for each instance. Weights should be non-negative.
        group : list, numpy 1-D array, pandas Series, pyarrow Array, pyarrow ChunkedArray or None, optional (default=None)
            Group/query data.
            Only used in the learning-to-rank task.
            sum(group) = n_samples.
            For example, if you have a 100-document dataset with ``group = [10, 20, 40, 10, 10, 10]``, that means that you have 6 groups,
            where the first 10 records are in the first group, records 11-30 are in the second group, records 31-70 are in the third group, etc.
        init_score : list, list of lists (for multi-class task), numpy array, pandas Series, pandas DataFrame (for multi-class task), pyarrow Array, pyarrow ChunkedArray, pyarrow Table (for multi-class task) or None, optional (default=None)
            Init score for Dataset.
        params : dict or None, optional (default=None)
            Other parameters for validation Dataset.
        position : numpy 1-D array, pandas Series or None, optional (default=None)
            Position of items used in unbiased learning-to-rank task.

        Returns
        -------
        valid : Dataset
            Validation Dataset with reference to self.
        """
        ret = Dataset(
            data,
            label=label,
            reference=self,
            weight=weight,
            group=group,
            position=position,
            init_score=init_score,
            params=params,
            free_raw_data=self.free_raw_data,
        )
        ret._predictor = self._predictor
        ret.pandas_categorical = self.pandas_categorical
        return ret

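    # Usage sketch for create_valid (hedged example, not executed as part of this
    # module; the arrays are hypothetical user data):
    #
    #   import numpy as np
    #   import lightgbm as lgb
    #
    #   X_train, y_train = np.random.rand(100, 5), np.random.rand(100)
    #   X_valid, y_valid = np.random.rand(20, 5), np.random.rand(20)
    #   train_ds = lgb.Dataset(X_train, label=y_train)
    #   valid_ds = train_ds.create_valid(X_valid, label=y_valid)
    #
    # The returned Dataset keeps ``reference=self``, so validation data are binned
    # with the same bin mappers as the training data.
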
    def subset(
        self,
        used_indices: List[int],
        params: Optional[Dict[str, Any]] = None,
    ) -> "Dataset":
        """Get subset of current Dataset.

        Parameters
        ----------
        used_indices : list of int
            Indices used to create the subset.
        params : dict or None, optional (default=None)
            These parameters will be passed to Dataset constructor.

        Returns
        -------
        subset : Dataset
            Subset of the current Dataset.
        """
        if params is None:
            params = self.params
        ret = Dataset(
            None,
            reference=self,
            feature_name=self.feature_name,
            categorical_feature=self.categorical_feature,
            params=params,
            free_raw_data=self.free_raw_data,
        )
        ret._predictor = self._predictor
        ret.pandas_categorical = self.pandas_categorical
        ret.used_indices = sorted(used_indices)
        return ret

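    # Usage sketch for subset (hedged example; ``full_ds`` is a hypothetical
    # Dataset):
    #
    #   first_half = full_ds.subset(used_indices=list(range(50)))
    #   first_half.construct()
    #
    # Construction is lazy: the selected rows are materialized later through
    # ``LGBM_DatasetGetSubset`` against the parent's constructed handle.
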
    def save_binary(self, filename: Union[str, Path]) -> "Dataset":
        """Save Dataset to a binary file.

        .. note::

            Please note that `init_score` is not saved in binary file.
            If you need it, please set it again after loading Dataset.

        Parameters
        ----------
        filename : str or pathlib.Path
            Name of the output file.

        Returns
        -------
        self : Dataset
            Returns self.
        """
        _safe_call(
            _LIB.LGBM_DatasetSaveBinary(
                self.construct()._handle,
                _c_str(str(filename)),
            )
        )
        return self

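    # Usage sketch for save_binary (hedged example; ``train_ds`` is a hypothetical
    # Dataset):
    #
    #   train_ds.save_binary("train.bin")
    #   reloaded = lgb.Dataset("train.bin").construct()
    #
    # Re-loading the binary file skips re-binning; note that init_score is not
    # stored in it and has to be set again after loading.
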
    def _update_params(self, params: Optional[Dict[str, Any]]) -> "Dataset":
        if not params:
            return self
        params = deepcopy(params)

        def update() -> None:
            if not self.params:
                self.params = params
            else:
                self._params_back_up = deepcopy(self.params)
                self.params.update(params)

        if self._handle is None:
            update()
        elif params is not None:
            ret = _LIB.LGBM_DatasetUpdateParamChecking(
                _c_str(_param_dict_to_str(self.params)),
                _c_str(_param_dict_to_str(params)),
            )
            if ret != 0:
                # could be updated if data is not freed
                if self.data is not None:
                    update()
                    self._free_handle()
                else:
                    raise LightGBMError(_LIB.LGBM_GetLastError().decode("utf-8"))
        return self

    def _reverse_update_params(self) -> "Dataset":
        if self._handle is None:
            self.params = deepcopy(self._params_back_up)
            self._params_back_up = None
        return self

    def set_field(
        self,
        field_name: str,
        data: Optional[_LGBM_SetFieldType],
    ) -> "Dataset":
        """Set property into the Dataset.

        Parameters
        ----------
        field_name : str
            The field name of the information.
        data : list, list of lists (for multi-class task), numpy array, pandas Series, pandas DataFrame (for multi-class task), pyarrow Array, pyarrow ChunkedArray or None
            The data to be set.

        Returns
        -------
        self : Dataset
            Dataset with set property.
        """
        if self._handle is None:
            raise Exception(f"Cannot set {field_name} before constructing the Dataset")
        if data is None:
            # set to None
            _safe_call(
                _LIB.LGBM_DatasetSetField(
                    self._handle,
                    _c_str(field_name),
                    None,
                    ctypes.c_int(0),
                    ctypes.c_int(_FIELD_TYPE_MAPPER[field_name]),
                )
            )
            return self

        # If the data is Arrow data, we can just pass it to C
        if _is_pyarrow_array(data) or _is_pyarrow_table(data):
            # If a table is being passed, we concatenate the columns. This is only valid for
            # 'init_score'.
            if _is_pyarrow_table(data):
                if field_name != "init_score":
                    raise ValueError(f"pyarrow tables are not supported for field '{field_name}'")
                data = pa_chunked_array(
                    [
                        chunk
                        for array in data.columns  # type: ignore
                        for chunk in array.chunks
                    ]
                )

            c_array = _export_arrow_to_c(data)
            _safe_call(
                _LIB.LGBM_DatasetSetFieldFromArrow(
                    self._handle,
                    _c_str(field_name),
                    ctypes.c_int64(c_array.n_chunks),
                    ctypes.c_void_p(c_array.chunks_ptr),
                    ctypes.c_void_p(c_array.schema_ptr),
                )
            )
            self.version += 1
            return self

        dtype: "np.typing.DTypeLike"
        if field_name == "init_score":
            dtype = np.float64
            if _is_1d_collection(data):
                data = _list_to_1d_numpy(data, dtype=dtype, name=field_name)
            elif _is_2d_collection(data):
                data = _data_to_2d_numpy(data, dtype=dtype, name=field_name)
                data = data.ravel(order="F")
            else:
                raise TypeError(
                    "init_score must be list, numpy 1-D array or pandas Series.\n"
                    "In multiclass classification init_score can also be a list of lists, numpy 2-D array or pandas DataFrame."
                )
        else:
            if field_name in {"group", "position"}:
                dtype = np.int32
            else:
                dtype = np.float32
            data = _list_to_1d_numpy(data, dtype=dtype, name=field_name)

        ptr_data: Union[_ctypes_float_ptr, _ctypes_int_ptr]
        if data.dtype == np.float32 or data.dtype == np.float64:
            ptr_data, type_data, _ = _c_float_array(data)
        elif data.dtype == np.int32:
            ptr_data, type_data, _ = _c_int_array(data)
        else:
            raise TypeError(f"Expected np.float32/64 or np.int32, met type({data.dtype})")
        if type_data != _FIELD_TYPE_MAPPER[field_name]:
            raise TypeError("Input type error for set_field")
        _safe_call(
            _LIB.LGBM_DatasetSetField(
                self._handle,
                _c_str(field_name),
                ptr_data,
                ctypes.c_int(len(data)),
                ctypes.c_int(type_data),
            )
        )
        self.version += 1
        return self

    def get_field(self, field_name: str) -> Optional[np.ndarray]:
        """Get property from the Dataset.

        Can only be run on a constructed Dataset.

        Unlike ``get_group()``, ``get_init_score()``, ``get_label()``, ``get_position()``, and ``get_weight()``,
        this method ignores any raw data passed into ``lgb.Dataset()`` on the Python side, and will only read
        data from the constructed C++ ``Dataset`` object.

        Parameters
        ----------
        field_name : str
            The field name of the information.

        Returns
        -------
        info : numpy array or None
            A numpy array with information from the Dataset.
        """
        if self._handle is None:
            raise Exception(f"Cannot get {field_name} before constructing the Dataset")
        tmp_out_len = ctypes.c_int(0)
        out_type = ctypes.c_int(0)
        ret = ctypes.POINTER(ctypes.c_void_p)()
        _safe_call(
            _LIB.LGBM_DatasetGetField(
                self._handle,
                _c_str(field_name),
                ctypes.byref(tmp_out_len),
                ctypes.byref(ret),
                ctypes.byref(out_type),
            )
        )
        if out_type.value != _FIELD_TYPE_MAPPER[field_name]:
            raise TypeError("Return type error for get_field")
        if tmp_out_len.value == 0:
            return None
        if out_type.value == _C_API_DTYPE_INT32:
            arr = _cint32_array_to_numpy(
                cptr=ctypes.cast(ret, ctypes.POINTER(ctypes.c_int32)),
                length=tmp_out_len.value,
            )
        elif out_type.value == _C_API_DTYPE_FLOAT32:
            arr = _cfloat32_array_to_numpy(
                cptr=ctypes.cast(ret, ctypes.POINTER(ctypes.c_float)),
                length=tmp_out_len.value,
            )
        elif out_type.value == _C_API_DTYPE_FLOAT64:
            arr = _cfloat64_array_to_numpy(
                cptr=ctypes.cast(ret, ctypes.POINTER(ctypes.c_double)),
                length=tmp_out_len.value,
            )
        else:
            raise TypeError("Unknown type")
        if field_name == "init_score":
            num_data = self.num_data()
            num_classes = arr.size // num_data
            if num_classes > 1:
                arr = arr.reshape((num_data, num_classes), order="F")
        return arr

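    # Round-trip sketch for set_field/get_field (hedged example; ``ds`` is a
    # hypothetical constructed Dataset with 100 rows):
    #
    #   import numpy as np
    #
    #   ds.set_field("label", np.zeros(100))
    #   y = ds.get_field("label")  # float32 numpy array read back from the C++ side
    #
    # get_field reads only from the constructed C++ Dataset, so it reflects any
    # conversions (e.g. dtype casts) applied on the native side.
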
    def set_categorical_feature(
        self,
        categorical_feature: _LGBM_CategoricalFeatureConfiguration,
    ) -> "Dataset":
        """Set categorical features.

        Parameters
        ----------
        categorical_feature : list of str or int, or 'auto'
            Names or indices of categorical features.

        Returns
        -------
        self : Dataset
            Dataset with set categorical features.
        """
        if self.categorical_feature == categorical_feature:
            return self
        if self.data is not None:
            if self.categorical_feature is None:
                self.categorical_feature = categorical_feature
                return self._free_handle()
            elif categorical_feature == "auto":
                return self
            else:
                if self.categorical_feature != "auto":
                    _log_warning(
                        "categorical_feature in Dataset is overridden.\n"
                        f"New categorical_feature is {list(categorical_feature)}"
                    )
                self.categorical_feature = categorical_feature
                return self._free_handle()
        else:
            raise LightGBMError(
                "Cannot set categorical feature after freed raw data, "
                "set free_raw_data=False when constructing Dataset to avoid this."
            )

    def _set_predictor(
        self,
        predictor: Optional[_InnerPredictor],
    ) -> "Dataset":
        """Set predictor for continued training.

        It is not recommended for users to call this function.
        Please use init_model argument in engine.train() or engine.cv() instead.
        """
        if predictor is None and self._predictor is None:
            return self
        elif isinstance(predictor, _InnerPredictor) and isinstance(self._predictor, _InnerPredictor):
            if (predictor == self._predictor) and (
                predictor.current_iteration() == self._predictor.current_iteration()
            ):
                return self
        if self._handle is None:
            self._predictor = predictor
        elif self.data is not None:
            self._predictor = predictor
            self._set_init_score_by_predictor(
                predictor=self._predictor,
                data=self.data,
                used_indices=None,
            )
        elif self.used_indices is not None and self.reference is not None and self.reference.data is not None:
            self._predictor = predictor
            self._set_init_score_by_predictor(
                predictor=self._predictor,
                data=self.reference.data,
                used_indices=self.used_indices,
            )
        else:
            raise LightGBMError(
                "Cannot set predictor after freed raw data, "
                "set free_raw_data=False when constructing Dataset to avoid this."
            )
        return self

    def set_reference(self, reference: "Dataset") -> "Dataset":
        """Set reference Dataset.

        Parameters
        ----------
        reference : Dataset
            Reference that is used as a template to construct the current Dataset.

        Returns
        -------
        self : Dataset
            Dataset with set reference.
        """
        self.set_categorical_feature(reference.categorical_feature).set_feature_name(
            reference.feature_name
        )._set_predictor(reference._predictor)
        # we're done if self and reference share a common upstream reference
        if self.get_ref_chain().intersection(reference.get_ref_chain()):
            return self
        if self.data is not None:
            self.reference = reference
            return self._free_handle()
        else:
            raise LightGBMError(
                "Cannot set reference after freed raw data, "
                "set free_raw_data=False when constructing Dataset to avoid this."
            )

    def set_feature_name(self, feature_name: _LGBM_FeatureNameConfiguration) -> "Dataset":
        """Set feature name.

        Parameters
        ----------
        feature_name : list of str
            Feature names.

        Returns
        -------
        self : Dataset
            Dataset with set feature name.
        """
        if feature_name != "auto":
            self.feature_name = feature_name
        if self._handle is not None and feature_name is not None and feature_name != "auto":
            if len(feature_name) != self.num_feature():
                raise ValueError(
                    f"Length of feature_name({len(feature_name)}) and num_feature({self.num_feature()}) don't match"
                )
            c_feature_name = [_c_str(name) for name in feature_name]
            _safe_call(
                _LIB.LGBM_DatasetSetFeatureNames(
                    self._handle,
                    _c_array(ctypes.c_char_p, c_feature_name),
                    ctypes.c_int(len(feature_name)),
                )
            )
        return self

    def set_label(self, label: Optional[_LGBM_LabelType]) -> "Dataset":
        """Set label of Dataset.

        Parameters
        ----------
        label : list, numpy 1-D array, pandas Series / one-column DataFrame, pyarrow Array, pyarrow ChunkedArray or None
            The label information to be set into Dataset.

        Returns
        -------
        self : Dataset
            Dataset with set label.
        """
        self.label = label
        if self._handle is not None:
            if isinstance(label, pd_DataFrame):
                if len(label.columns) > 1:
                    raise ValueError("DataFrame for label cannot have multiple columns")
                label_array = np.ravel(_pandas_to_numpy(label, target_dtype=np.float32))
            elif _is_pyarrow_array(label):
                label_array = label
            else:
                label_array = _list_to_1d_numpy(label, dtype=np.float32, name="label")
            self.set_field("label", label_array)
            self.label = self.get_field("label")  # original values can be modified at cpp side
        return self

    def set_weight(
        self,
        weight: Optional[_LGBM_WeightType],
    ) -> "Dataset":
        """Set weight of each instance.

        Parameters
        ----------
        weight : list, numpy 1-D array, pandas Series, pyarrow Array, pyarrow ChunkedArray or None
            Weight to be set for each data point. Weights should be non-negative.

        Returns
        -------
        self : Dataset
            Dataset with set weight.
        """
        # Check if the weight contains values other than one
        if weight is not None:
            if _is_pyarrow_array(weight):
                if pa_compute.all(pa_compute.equal(weight, 1)).as_py():
                    weight = None
            elif np.all(weight == 1):
                weight = None
        self.weight = weight

        # Set field
        if self._handle is not None and weight is not None:
            if not _is_pyarrow_array(weight):
                weight = _list_to_1d_numpy(weight, dtype=np.float32, name="weight")
            self.set_field("weight", weight)
            self.weight = self.get_field("weight")  # original values can be modified at cpp side
        return self

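    # Usage sketch for set_weight (hedged example): all-one weights are treated as
    # "no weights" and stored as None:
    #
    #   import numpy as np
    #
    #   ds.set_weight(np.ones(100))                 # equivalent to ds.set_weight(None)
    #   ds.set_weight(np.linspace(0.1, 1.0, 100))   # genuine per-row weights
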
    def set_init_score(
        self,
        init_score: Optional[_LGBM_InitScoreType],
    ) -> "Dataset":
        """Set init score of Booster to start from.

        Parameters
        ----------
        init_score : list, list of lists (for multi-class task), numpy array, pandas Series, pandas DataFrame (for multi-class task), pyarrow Array, pyarrow ChunkedArray, pyarrow Table (for multi-class task) or None
            Init score for Booster.

        Returns
        -------
        self : Dataset
            Dataset with set init score.
        """
        self.init_score = init_score
        if self._handle is not None and init_score is not None:
            self.set_field("init_score", init_score)
            self.init_score = self.get_field("init_score")  # original values can be modified at cpp side
        return self

    def set_group(
        self,
        group: Optional[_LGBM_GroupType],
    ) -> "Dataset":
        """Set group size of Dataset (used for ranking).

        Parameters
        ----------
        group : list, numpy 1-D array, pandas Series, pyarrow Array, pyarrow ChunkedArray or None
            Group/query data.
            Only used in the learning-to-rank task.
            sum(group) = n_samples.
            For example, if you have a 100-document dataset with ``group = [10, 20, 40, 10, 10, 10]``, that means that you have 6 groups,
            where the first 10 records are in the first group, records 11-30 are in the second group, records 31-70 are in the third group, etc.

        Returns
        -------
        self : Dataset
            Dataset with set group.
        """
        self.group = group
        if self._handle is not None and group is not None:
            if not _is_pyarrow_array(group):
                group = _list_to_1d_numpy(group, dtype=np.int32, name="group")
            self.set_field("group", group)
            # original values can be modified at cpp side
            constructed_group = self.get_field("group")
            if constructed_group is not None:
                self.group = np.diff(constructed_group)
        return self

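    # Group semantics sketch (hedged example): for a ranking Dataset with 6 rows
    # split into 2 queries of sizes 4 and 2, sum(group) must equal the number of
    # rows:
    #
    #   ds.set_group([4, 2])
    #
    # After construction, get_group() recovers the sizes from the boundary
    # representation via np.diff, e.g. boundaries [0, 4, 6] -> sizes [4, 2].
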
    def set_position(
        self,
        position: Optional[_LGBM_PositionType],
    ) -> "Dataset":
        """Set position of Dataset (used for ranking).

        Parameters
        ----------
        position : numpy 1-D array, pandas Series or None, optional (default=None)
            Position of items used in unbiased learning-to-rank task.

        Returns
        -------
        self : Dataset
            Dataset with set position.
        """
        self.position = position
        if self._handle is not None and position is not None:
            position = _list_to_1d_numpy(position, dtype=np.int32, name="position")
            self.set_field("position", position)
        return self

    def get_feature_name(self) -> List[str]:
        """Get the names of columns (features) in the Dataset.

        Returns
        -------
        feature_names : list of str
            The names of columns (features) in the Dataset.
        """
        if self._handle is None:
            raise LightGBMError("Cannot get feature_name before constructing the Dataset")
        num_feature = self.num_feature()
        tmp_out_len = ctypes.c_int(0)
        reserved_string_buffer_size = 255
        required_string_buffer_size = ctypes.c_size_t(0)
        string_buffers = [ctypes.create_string_buffer(reserved_string_buffer_size) for _ in range(num_feature)]
        ptr_string_buffers = (ctypes.c_char_p * num_feature)(*map(ctypes.addressof, string_buffers))  # type: ignore[misc]
        _safe_call(
            _LIB.LGBM_DatasetGetFeatureNames(
                self._handle,
                ctypes.c_int(num_feature),
                ctypes.byref(tmp_out_len),
                ctypes.c_size_t(reserved_string_buffer_size),
                ctypes.byref(required_string_buffer_size),
                ptr_string_buffers,
            )
        )
        if num_feature != tmp_out_len.value:
            raise ValueError("Length of feature names doesn't equal num_feature")
        actual_string_buffer_size = required_string_buffer_size.value
        # if buffer length is not long enough, reallocate buffers
        if reserved_string_buffer_size < actual_string_buffer_size:
            string_buffers = [ctypes.create_string_buffer(actual_string_buffer_size) for _ in range(num_feature)]
            ptr_string_buffers = (ctypes.c_char_p * num_feature)(*map(ctypes.addressof, string_buffers))  # type: ignore[misc]
            _safe_call(
                _LIB.LGBM_DatasetGetFeatureNames(
                    self._handle,
                    ctypes.c_int(num_feature),
                    ctypes.byref(tmp_out_len),
                    ctypes.c_size_t(actual_string_buffer_size),
                    ctypes.byref(required_string_buffer_size),
                    ptr_string_buffers,
                )
            )
        return [string_buffers[i].value.decode("utf-8") for i in range(num_feature)]

    def get_label(self) -> Optional[_LGBM_LabelType]:
        """Get the label of the Dataset.

        Returns
        -------
        label : list, numpy 1-D array, pandas Series / one-column DataFrame or None
            The label information from the Dataset.
            For a constructed ``Dataset``, this will only return a numpy array.
        """
        if self.label is None:
            self.label = self.get_field("label")
        return self.label

    def get_weight(self) -> Optional[_LGBM_WeightType]:
        """Get the weight of the Dataset.

        Returns
        -------
        weight : list, numpy 1-D array, pandas Series or None
            Weight for each data point from the Dataset. Weights should be non-negative.
            For a constructed ``Dataset``, this will only return ``None`` or a numpy array.
        """
        if self.weight is None:
            self.weight = self.get_field("weight")
        return self.weight

    def get_init_score(self) -> Optional[_LGBM_InitScoreType]:
        """Get the initial score of the Dataset.

        Returns
        -------
        init_score : list, list of lists (for multi-class task), numpy array, pandas Series, pandas DataFrame (for multi-class task), or None
            Init score of Booster.
            For a constructed ``Dataset``, this will only return ``None`` or a numpy array.
        """
        if self.init_score is None:
            self.init_score = self.get_field("init_score")
        return self.init_score

    def get_data(self) -> Optional[_LGBM_TrainDataType]:
        """Get the raw data of the Dataset.

        Returns
        -------
        data : str, pathlib.Path, numpy array, pandas DataFrame, H2O DataTable's Frame, scipy.sparse, Sequence, list of Sequence or list of numpy array or None
            Raw data used in the Dataset construction.
        """
        if self._handle is None:
            raise Exception("Cannot get data before constructing the Dataset")
        if self._need_slice and self.used_indices is not None and self.reference is not None:
            self.data = self.reference.data
            if self.data is not None:
                if isinstance(self.data, (np.ndarray, scipy.sparse.spmatrix)):
                    self.data = self.data[self.used_indices, :]
                elif isinstance(self.data, pd_DataFrame):
                    self.data = self.data.iloc[self.used_indices].copy()
                elif isinstance(self.data, dt_DataTable):
                    self.data = self.data[self.used_indices, :]
                elif isinstance(self.data, Sequence):
                    self.data = self.data[self.used_indices]
                elif _is_list_of_sequences(self.data) and len(self.data) > 0:
                    self.data = np.array(list(self._yield_row_from_seqlist(self.data, self.used_indices)))
                else:
                    _log_warning(
                        f"Cannot subset {type(self.data).__name__} type of raw data.\n"
                        "Returning original raw data"
                    )
            self._need_slice = False
        if self.data is None:
            raise LightGBMError(
                "Cannot call `get_data` after freed raw data, "
                "set free_raw_data=False when constructing Dataset to avoid this."
            )
        return self.data

    def get_group(self) -> Optional[_LGBM_GroupType]:
        """Get the group of the Dataset.

        Returns
        -------
        group : list, numpy 1-D array, pandas Series or None
            Group/query data.
            Only used in the learning-to-rank task.
            sum(group) = n_samples.
            For example, if you have a 100-document dataset with ``group = [10, 20, 40, 10, 10, 10]``, that means that you have 6 groups,
            where the first 10 records are in the first group, records 11-30 are in the second group, records 31-70 are in the third group, etc.
            For a constructed ``Dataset``, this will only return ``None`` or a numpy array.
        """
        if self.group is None:
            self.group = self.get_field("group")
            if self.group is not None:
                # group data from LightGBM are boundaries; convert them to group sizes
                self.group = np.diff(self.group)
        return self.group

    def get_position(self) -> Optional[_LGBM_PositionType]:
        """Get the position of the Dataset.

        Returns
        -------
        position : numpy 1-D array, pandas Series or None
            Position of items used in unbiased learning-to-rank task.
            For a constructed ``Dataset``, this will only return ``None`` or a numpy array.
        """
        if self.position is None:
            self.position = self.get_field("position")
        return self.position

    def num_data(self) -> int:
        """Get the number of rows in the Dataset.

        Returns
        -------
        number_of_rows : int
            The number of rows in the Dataset.
        """
        if self._handle is not None:
            ret = ctypes.c_int(0)
            _safe_call(
                _LIB.LGBM_DatasetGetNumData(
                    self._handle,
                    ctypes.byref(ret),
                )
            )
            return ret.value
        else:
            raise LightGBMError("Cannot get num_data before constructing the Dataset")

    def num_feature(self) -> int:
        """Get the number of columns (features) in the Dataset.

        Returns
        -------
        number_of_columns : int
            The number of columns (features) in the Dataset.
        """
        if self._handle is not None:
            ret = ctypes.c_int(0)
            _safe_call(
                _LIB.LGBM_DatasetGetNumFeature(
                    self._handle,
                    ctypes.byref(ret),
                )
            )
            return ret.value
        else:
            raise LightGBMError("Cannot get num_feature before constructing the Dataset")

    def feature_num_bin(self, feature: Union[int, str]) -> int:
        """Get the number of bins for a feature.

        .. versionadded:: 4.0.0

        Parameters
        ----------
        feature : int or str
            Index or name of the feature.

        Returns
        -------
        number_of_bins : int
            The number of constructed bins for the feature in the Dataset.
        """
        if self._handle is not None:
            if isinstance(feature, str):
                feature_index = self.feature_name.index(feature)
            else:
                feature_index = feature
            ret = ctypes.c_int(0)
            _safe_call(
                _LIB.LGBM_DatasetGetFeatureNumBin(
                    self._handle,
                    ctypes.c_int(feature_index),
                    ctypes.byref(ret),
                )
            )
            return ret.value
        else:
            raise LightGBMError("Cannot get feature_num_bin before constructing the Dataset")

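    # Usage sketch for feature_num_bin (hedged example; assumes the constructed
    # Dataset's first feature is named "f0"):
    #
    #   n_bins_by_index = ds.feature_num_bin(0)
    #   n_bins_by_name = ds.feature_num_bin("f0")  # same value, looked up by name
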
    def get_ref_chain(self, ref_limit: int = 100) -> Set["Dataset"]:
        """Get a chain of Dataset objects.

        Starts with this Dataset, then goes to its reference (if it exists),
        then to that reference's reference, etc.,
        until we hit ``ref_limit`` or a reference loop.

        Parameters
        ----------
        ref_limit : int, optional (default=100)
            The limit number of references.

        Returns
        -------
        ref_chain : set of Dataset
            Chain of references of the Datasets.
        """
        head = self
        ref_chain: Set[Dataset] = set()
        while len(ref_chain) < ref_limit:
            if isinstance(head, Dataset):
                ref_chain.add(head)
                if (head.reference is not None) and (head.reference not in ref_chain):
                    head = head.reference
                else:
                    break
            else:
                break
        return ref_chain

    def add_features_from(self, other: "Dataset") -> "Dataset":
        """Add features from other Dataset to the current Dataset.

        Both Datasets must be constructed before calling this method.

        Parameters
        ----------
        other : Dataset
            The Dataset to take features from.

        Returns
        -------
        self : Dataset
            Dataset with the new features added.
        """
        if self._handle is None or other._handle is None:
            raise ValueError("Both source and target Datasets must be constructed before adding features")
        _safe_call(
            _LIB.LGBM_DatasetAddFeaturesFrom(
                self._handle,
                other._handle,
            )
        )
        was_none = self.data is None
        old_self_data_type = type(self.data).__name__
        if other.data is None:
            self.data = None
        elif self.data is not None:
            if isinstance(self.data, np.ndarray):
                if isinstance(other.data, np.ndarray):
                    self.data = np.hstack((self.data, other.data))
                elif isinstance(other.data, scipy.sparse.spmatrix):
                    self.data = np.hstack((self.data, other.data.toarray()))
                elif isinstance(other.data, pd_DataFrame):
                    self.data = np.hstack((self.data, other.data.values))
                elif isinstance(other.data, dt_DataTable):
                    self.data = np.hstack((self.data, other.data.to_numpy()))
                else:
                    self.data = None
            elif isinstance(self.data, scipy.sparse.spmatrix):
                sparse_format = self.data.getformat()
                if isinstance(other.data, (np.ndarray, scipy.sparse.spmatrix)):
                    self.data = scipy.sparse.hstack((self.data, other.data), format=sparse_format)
                elif isinstance(other.data, pd_DataFrame):
                    self.data = scipy.sparse.hstack((self.data, other.data.values), format=sparse_format)
                elif isinstance(other.data, dt_DataTable):
                    self.data = scipy.sparse.hstack((self.data, other.data.to_numpy()), format=sparse_format)
                else:
                    self.data = None
            elif isinstance(self.data, pd_DataFrame):
                if not PANDAS_INSTALLED:
                    raise LightGBMError(
                        "Cannot add features to DataFrame type of raw data "
                        "without pandas installed. "
                        "Install pandas and restart your session."
                    )
                if isinstance(other.data, np.ndarray):
                    self.data = concat((self.data, pd_DataFrame(other.data)), axis=1, ignore_index=True)
                elif isinstance(other.data, scipy.sparse.spmatrix):
                    self.data = concat((self.data, pd_DataFrame(other.data.toarray())), axis=1, ignore_index=True)
                elif isinstance(other.data, pd_DataFrame):
                    self.data = concat((self.data, other.data), axis=1, ignore_index=True)
                elif isinstance(other.data, dt_DataTable):
                    self.data = concat((self.data, pd_DataFrame(other.data.to_numpy())), axis=1, ignore_index=True)
                else:
                    self.data = None
            elif isinstance(self.data, dt_DataTable):
                if isinstance(other.data, np.ndarray):
                    self.data = dt_DataTable(np.hstack((self.data.to_numpy(), other.data)))
                elif isinstance(other.data, scipy.sparse.spmatrix):
                    self.data = dt_DataTable(np.hstack((self.data.to_numpy(), other.data.toarray())))
                elif isinstance(other.data, pd_DataFrame):
                    self.data = dt_DataTable(np.hstack((self.data.to_numpy(), other.data.values)))
                elif isinstance(other.data, dt_DataTable):
                    self.data = dt_DataTable(np.hstack((self.data.to_numpy(), other.data.to_numpy())))
                else:
                    self.data = None
            else:
                self.data = None
        if self.data is None:
            err_msg = (
                f"Cannot add features from {type(other.data).__name__} type of raw data to "
                f"{old_self_data_type} type of raw data.\n"
            )
            err_msg += (
                "Set free_raw_data=False when constructing Dataset to avoid this" if was_none else "Freeing raw data"
            )
            _log_warning(err_msg)
        self.feature_name = self.get_feature_name()
        _log_warning(
            "Resetting categorical features.\n"
            "You can set new categorical features via ``set_categorical_feature`` method"
        )
        self.categorical_feature = "auto"
        self.pandas_categorical = None
        return self

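    # Usage sketch for add_features_from (hedged example; both Datasets must be
    # constructed and hold the same rows):
    #
    #   ds_a = lgb.Dataset(X_a, free_raw_data=False).construct()
    #   ds_b = lgb.Dataset(X_b, free_raw_data=False).construct()
    #   ds_a.add_features_from(ds_b)  # ds_a now holds the columns of both
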
    def _dump_text(self, filename: Union[str, Path]) -> "Dataset":
        """Save Dataset to a text file.

        This format cannot be loaded back in by LightGBM, but is useful for debugging purposes.

        Parameters
        ----------
        filename : str or pathlib.Path
            Name of the output file.

        Returns
        -------
        self : Dataset
            Returns self.
        """
        _safe_call(
            _LIB.LGBM_DatasetDumpText(
                self.construct()._handle,
                _c_str(str(filename)),
            )
        )
        return self


_LGBM_CustomObjectiveFunction = Callable[
    [np.ndarray, Dataset],
    Tuple[np.ndarray, np.ndarray],
]
_LGBM_CustomEvalFunction = Union[
    Callable[
        [np.ndarray, Dataset],
        _LGBM_EvalFunctionResultType,
    ],
    Callable[
        [np.ndarray, Dataset],
        List[_LGBM_EvalFunctionResultType],
    ],
]

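# A custom objective matching _LGBM_CustomObjectiveFunction returns per-row
# gradients and hessians. Hedged sketch of an L2-style objective (assumes raw
# scores in ``preds`` and a dense label array):
#
#   def l2_objective(preds: np.ndarray, train_data: Dataset):
#       labels = train_data.get_label()
#       grad = preds - labels          # d/dpred 0.5 * (pred - label)^2
#       hess = np.ones_like(preds)     # second derivative is constant
#       return grad, hess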

class Booster:
    """Booster in LightGBM."""

    def __init__(
        self,
        params: Optional[Dict[str, Any]] = None,
        train_set: Optional[Dataset] = None,
        model_file: Optional[Union[str, Path]] = None,
        model_str: Optional[str] = None,
    ):
        """Initialize the Booster.

        Parameters
        ----------
        params : dict or None, optional (default=None)
            Parameters for Booster.
        train_set : Dataset or None, optional (default=None)
            Training dataset.
        model_file : str, pathlib.Path or None, optional (default=None)
            Path to the model file.
        model_str : str or None, optional (default=None)
            Model will be loaded from this string.
        """
        self._handle = ctypes.c_void_p()
        self._network = False
        self.__need_reload_eval_info = True
        self._train_data_name = "training"
        self.__set_objective_to_none = False
        self.best_iteration = -1
        self.best_score: _LGBM_BoosterBestScoreType = {}
        params = {} if params is None else deepcopy(params)
        if train_set is not None:
            # Training task
            if not isinstance(train_set, Dataset):
                raise TypeError(f"Training data should be Dataset instance, met {type(train_set).__name__}")
            params = _choose_param_value(
                main_param_name="machines",
                params=params,
                default_value=None,
            )
            # if "machines" is given, assume user wants to do distributed learning, and set up network
            if params["machines"] is None:
                params.pop("machines", None)
            else:
                machines = params["machines"]
                if isinstance(machines, str):
                    num_machines_from_machine_list = len(machines.split(","))
                elif isinstance(machines, (list, set)):
                    num_machines_from_machine_list = len(machines)
                    machines = ",".join(machines)
                else:
                    raise ValueError("Invalid machines in params.")

                params = _choose_param_value(
                    main_param_name="num_machines",
                    params=params,
                    default_value=num_machines_from_machine_list,
                )
                params = _choose_param_value(
                    main_param_name="local_listen_port",
                    params=params,
                    default_value=12400,
                )
                self.set_network(
                    machines=machines,
                    local_listen_port=params["local_listen_port"],
                    listen_time_out=params.get("time_out", 120),
                    num_machines=params["num_machines"],
                )
            # construct booster object
            train_set.construct()
            # copy the parameters from train_set
            params.update(train_set.get_params())
            params_str = _param_dict_to_str(params)
            _safe_call(
                _LIB.LGBM_BoosterCreate(
                    train_set._handle,
                    _c_str(params_str),
                    ctypes.byref(self._handle),
                )
            )
            # save reference to data
            self.train_set = train_set
            self.valid_sets: List[Dataset] = []
            self.name_valid_sets: List[str] = []
            self.__num_dataset = 1
            self.__init_predictor = train_set._predictor
            if self.__init_predictor is not None:
                _safe_call(
                    _LIB.LGBM_BoosterMerge(
                        self._handle,
                        self.__init_predictor._handle,
                    )
                )
            out_num_class = ctypes.c_int(0)
            _safe_call(
                _LIB.LGBM_BoosterGetNumClasses(
                    self._handle,
                    ctypes.byref(out_num_class),
                )
            )
            self.__num_class = out_num_class.value
            # buffer for inner predict
            self.__inner_predict_buffer: List[Optional[np.ndarray]] = [None]
            self.__is_predicted_cur_iter = [False]
            self.__get_eval_info()
            self.pandas_categorical = train_set.pandas_categorical
            self.train_set_version = train_set.version
        elif model_file is not None:
            # Prediction task
            out_num_iterations = ctypes.c_int(0)
            _safe_call(
                _LIB.LGBM_BoosterCreateFromModelfile(
                    _c_str(str(model_file)),
                    ctypes.byref(out_num_iterations),
                    ctypes.byref(self._handle),
                )
            )
            out_num_class = ctypes.c_int(0)
            _safe_call(
                _LIB.LGBM_BoosterGetNumClasses(
                    self._handle,
                    ctypes.byref(out_num_class),
                )
            )
            self.__num_class = out_num_class.value
            self.pandas_categorical = _load_pandas_categorical(file_name=model_file)
            if params:
                _log_warning("Ignoring params argument, using parameters from model file.")
            params = self._get_loaded_param()
        elif model_str is not None:
            self.model_from_string(model_str)
        else:
            raise TypeError(
                "Need at least one training dataset or model file or model string " "to create Booster instance"
            )
        self.params = params

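    # Construction sketch (hedged example; exactly one of train_set, model_file,
    # or model_str should be provided):
    #
    #   bst = lgb.Booster(params={"objective": "binary"}, train_set=train_ds)
    #   bst_from_file = lgb.Booster(model_file="model.txt")
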
    def __del__(self) -> None:
        try:
            if self._network:
                self.free_network()
        except AttributeError:
            pass
        try:
            if self._handle is not None:
                _safe_call(_LIB.LGBM_BoosterFree(self._handle))
        except AttributeError:
            pass

    def __copy__(self) -> "Booster":
        return self.__deepcopy__(None)

    def __deepcopy__(self, *args: Any, **kwargs: Any) -> "Booster":
        model_str = self.model_to_string(num_iteration=-1)
        return Booster(model_str=model_str)

    def __getstate__(self) -> Dict[str, Any]:
        this = self.__dict__.copy()
        handle = this["_handle"]
        this.pop("train_set", None)
        this.pop("valid_sets", None)
        if handle is not None:
            this["_handle"] = self.model_to_string(num_iteration=-1)
        return this

    def __setstate__(self, state: Dict[str, Any]) -> None:
        model_str = state.get("_handle", state.get("handle", None))
        if model_str is not None:
            handle = ctypes.c_void_p()
            out_num_iterations = ctypes.c_int(0)
            _safe_call(
                _LIB.LGBM_BoosterLoadModelFromString(
                    _c_str(model_str),
                    ctypes.byref(out_num_iterations),
                    ctypes.byref(handle),
                )
            )
            state["_handle"] = handle
        self.__dict__.update(state)

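    # Illustrative sketch (not part of the library): because __getstate__ swaps the C
    # handle for the model string and drops the attached Datasets, a Booster survives
    # a pickle round trip but comes back without train_set/valid_sets, e.g.:
    #
    #   import pickle
    #   payload = pickle.dumps(bst)   # bst is an already-trained Booster
    #   bst2 = pickle.loads(payload)  # handle rebuilt via LGBM_BoosterLoadModelFromString
    #   bst2.predict(X)               # predicting works; bst2 carries no Datasets
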
    def _get_loaded_param(self) -> Dict[str, Any]:
        buffer_len = 1 << 20
        tmp_out_len = ctypes.c_int64(0)
        string_buffer = ctypes.create_string_buffer(buffer_len)
        ptr_string_buffer = ctypes.c_char_p(ctypes.addressof(string_buffer))
        _safe_call(
            _LIB.LGBM_BoosterGetLoadedParam(
                self._handle,
                ctypes.c_int64(buffer_len),
                ctypes.byref(tmp_out_len),
                ptr_string_buffer,
            )
        )
        actual_len = tmp_out_len.value
        # if buffer length is not long enough, re-allocate a buffer
        if actual_len > buffer_len:
            string_buffer = ctypes.create_string_buffer(actual_len)
            ptr_string_buffer = ctypes.c_char_p(ctypes.addressof(string_buffer))
            _safe_call(
                _LIB.LGBM_BoosterGetLoadedParam(
                    self._handle,
                    ctypes.c_int64(actual_len),
                    ctypes.byref(tmp_out_len),
                    ptr_string_buffer,
                )
            )
        return json.loads(string_buffer.value.decode("utf-8"))

    def free_dataset(self) -> "Booster":
        """Free Booster's Datasets.

        Returns
        -------
        self : Booster
            Booster without Datasets.
        """
        self.__dict__.pop("train_set", None)
        self.__dict__.pop("valid_sets", None)
        self.__num_dataset = 0
        return self

    def _free_buffer(self) -> "Booster":
        self.__inner_predict_buffer = []
        self.__is_predicted_cur_iter = []
        return self

    def set_network(
        self,
        machines: Union[List[str], Set[str], str],
        local_listen_port: int = 12400,
        listen_time_out: int = 120,
        num_machines: int = 1,
    ) -> "Booster":
        """Set the network configuration.

        Parameters
        ----------
        machines : list, set or str
            Names of machines.
        local_listen_port : int, optional (default=12400)
            TCP listen port for local machines.
        listen_time_out : int, optional (default=120)
            Socket time-out in minutes.
        num_machines : int, optional (default=1)
            The number of machines for distributed learning application.

        Returns
        -------
        self : Booster
            Booster with set network.
        """
        if isinstance(machines, (list, set)):
            machines = ",".join(machines)
        _safe_call(
            _LIB.LGBM_NetworkInit(
                _c_str(machines),
                ctypes.c_int(local_listen_port),
                ctypes.c_int(listen_time_out),
                ctypes.c_int(num_machines),
            )
        )
        self._network = True
        return self

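    # Illustrative sketch (assumed host names and ports, not part of the library):
    # wiring up two workers for distributed learning, then releasing the sockets:
    #
    #   bst.set_network(
    #       machines=["10.0.0.1:12400", "10.0.0.2:12400"],
    #       local_listen_port=12400,
    #       num_machines=2,
    #   )
    #   ...  # distributed training happens here
    #   bst.free_network()
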
    def free_network(self) -> "Booster":
        """Free Booster's network.

        Returns
        -------
        self : Booster
            Booster with freed network.
        """
        _safe_call(_LIB.LGBM_NetworkFree())
        self._network = False
        return self

    def trees_to_dataframe(self) -> pd_DataFrame:
        """Parse the fitted model and return in an easy-to-read pandas DataFrame.

        The returned DataFrame has the following columns.

            - ``tree_index`` : int64, which tree a node belongs to. 0-based, so a value of ``6``, for example, means "this node is in the 7th tree".
            - ``node_depth`` : int64, how far a node is from the root of the tree. The root node has a value of ``1``, its direct children are ``2``, etc.
            - ``node_index`` : str, unique identifier for a node.
            - ``left_child`` : str, ``node_index`` of the child node to the left of a split. ``None`` for leaf nodes.
            - ``right_child`` : str, ``node_index`` of the child node to the right of a split. ``None`` for leaf nodes.
            - ``parent_index`` : str, ``node_index`` of this node's parent. ``None`` for the root node.
            - ``split_feature`` : str, name of the feature used for splitting. ``None`` for leaf nodes.
            - ``split_gain`` : float64, gain from adding this split to the tree. ``NaN`` for leaf nodes.
            - ``threshold`` : float64, value of the feature used to decide which side of the split a record will go down. ``NaN`` for leaf nodes.
            - ``decision_type`` : str, logical operator describing how to compare a value to ``threshold``.
              For example, ``split_feature = "Column_10", threshold = 15, decision_type = "<="`` means that
              records where ``Column_10 <= 15`` follow the left side of the split, and the rest follow the right side. ``None`` for leaf nodes.
            - ``missing_direction`` : str, split direction that missing values should go to. ``None`` for leaf nodes.
            - ``missing_type`` : str, describes what types of values are treated as missing.
            - ``value`` : float64, predicted value for this leaf node, multiplied by the learning rate.
            - ``weight`` : float64 or int64, sum of Hessian (second-order derivative of objective), summed over observations that fall in this node.
            - ``count`` : int64, number of records in the training data that fall into this node.

        Returns
        -------
        result : pandas DataFrame
            Returns a pandas DataFrame of the parsed model.
        """
        if not PANDAS_INSTALLED:
            raise LightGBMError(
                "This method cannot be run without pandas installed. "
                "You must install pandas and restart your session to use this method."
            )

        if self.num_trees() == 0:
            raise LightGBMError("There are no trees in this Booster and thus nothing to parse")

        def _is_split_node(tree: Dict[str, Any]) -> bool:
            return "split_index" in tree.keys()

        def create_node_record(
            tree: Dict[str, Any],
            node_depth: int = 1,
            tree_index: Optional[int] = None,
            feature_names: Optional[List[str]] = None,
            parent_node: Optional[str] = None,
        ) -> Dict[str, Any]:
            def _get_node_index(
                tree: Dict[str, Any],
                tree_index: Optional[int],
            ) -> str:
                tree_num = f"{tree_index}-" if tree_index is not None else ""
                is_split = _is_split_node(tree)
                node_type = "S" if is_split else "L"
                # a single-node tree won't have `leaf_index`, so default to 0
                node_num = tree.get("split_index" if is_split else "leaf_index", 0)
                return f"{tree_num}{node_type}{node_num}"

            def _get_split_feature(
                tree: Dict[str, Any],
                feature_names: Optional[List[str]],
            ) -> Optional[str]:
                if _is_split_node(tree):
                    if feature_names is not None:
                        feature_name = feature_names[tree["split_feature"]]
                    else:
                        feature_name = tree["split_feature"]
                else:
                    feature_name = None
                return feature_name

            def _is_single_node_tree(tree: Dict[str, Any]) -> bool:
                return set(tree.keys()) == {"leaf_value"}

            # Create the node record, and populate universal data members
            node: Dict[str, Union[int, str, None]] = OrderedDict()
            node["tree_index"] = tree_index
            node["node_depth"] = node_depth
            node["node_index"] = _get_node_index(tree, tree_index)
            node["left_child"] = None
            node["right_child"] = None
            node["parent_index"] = parent_node
            node["split_feature"] = _get_split_feature(tree, feature_names)
            node["split_gain"] = None
            node["threshold"] = None
            node["decision_type"] = None
            node["missing_direction"] = None
            node["missing_type"] = None
            node["value"] = None
            node["weight"] = None
            node["count"] = None

            # Update values to reflect node type (leaf or split)
            if _is_split_node(tree):
                node["left_child"] = _get_node_index(tree["left_child"], tree_index)
                node["right_child"] = _get_node_index(tree["right_child"], tree_index)
                node["split_gain"] = tree["split_gain"]
                node["threshold"] = tree["threshold"]
                node["decision_type"] = tree["decision_type"]
                node["missing_direction"] = "left" if tree["default_left"] else "right"
                node["missing_type"] = tree["missing_type"]
                node["value"] = tree["internal_value"]
                node["weight"] = tree["internal_weight"]
                node["count"] = tree["internal_count"]
            else:
                node["value"] = tree["leaf_value"]
                if not _is_single_node_tree(tree):
                    node["weight"] = tree["leaf_weight"]
                    node["count"] = tree["leaf_count"]

            return node

        def tree_dict_to_node_list(
            tree: Dict[str, Any],
            node_depth: int = 1,
            tree_index: Optional[int] = None,
            feature_names: Optional[List[str]] = None,
            parent_node: Optional[str] = None,
        ) -> List[Dict[str, Any]]:
            node = create_node_record(
                tree=tree,
                node_depth=node_depth,
                tree_index=tree_index,
                feature_names=feature_names,
                parent_node=parent_node,
            )

            res = [node]

            if _is_split_node(tree):
                # traverse the next level of the tree
                children = ["left_child", "right_child"]
                for child in children:
                    subtree_list = tree_dict_to_node_list(
                        tree=tree[child],
                        node_depth=node_depth + 1,
                        tree_index=tree_index,
                        feature_names=feature_names,
                        parent_node=node["node_index"],
                    )
                    # "subtree_list" contains the node records (dicts) for this subtree;
                    # extend the result with those records.
                    res.extend(subtree_list)
            return res

        model_dict = self.dump_model()
        feature_names = model_dict["feature_names"]
        model_list = []
        for tree in model_dict["tree_info"]:
            model_list.extend(
                tree_dict_to_node_list(
                    tree=tree["tree_structure"], tree_index=tree["tree_index"], feature_names=feature_names
                )
            )

        return pd_DataFrame(model_list, columns=model_list[0].keys())

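    # Illustrative sketch (not part of the library): the flattened frame supports
    # ad-hoc model inspection with plain pandas, e.g. the most frequently used
    # split features:
    #
    #   df = bst.trees_to_dataframe()
    #   df["split_feature"].value_counts().head()
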
    def set_train_data_name(self, name: str) -> "Booster":
        """Set the name to the training Dataset.

        Parameters
        ----------
        name : str
            Name for the training Dataset.

        Returns
        -------
        self : Booster
            Booster with set training Dataset name.
        """
        self._train_data_name = name
        return self

    def add_valid(self, data: Dataset, name: str) -> "Booster":
        """Add validation data.

        Parameters
        ----------
        data : Dataset
            Validation data.
        name : str
            Name of validation data.

        Returns
        -------
        self : Booster
            Booster with set validation data.
        """
        if not isinstance(data, Dataset):
            raise TypeError(f"Validation data should be Dataset instance, met {type(data).__name__}")
        if data._predictor is not self.__init_predictor:
            raise LightGBMError("Add validation data failed, you should use the same predictor for these data")
        _safe_call(
            _LIB.LGBM_BoosterAddValidData(
                self._handle,
                data.construct()._handle,
            )
        )
        self.valid_sets.append(data)
        self.name_valid_sets.append(name)
        self.__num_dataset += 1
        self.__inner_predict_buffer.append(None)
        self.__is_predicted_cur_iter.append(False)
        return self

    def reset_parameter(self, params: Dict[str, Any]) -> "Booster":
        """Reset parameters of Booster.

        Parameters
        ----------
        params : dict
            New parameters for Booster.

        Returns
        -------
        self : Booster
            Booster with new parameters.
        """
        params_str = _param_dict_to_str(params)
        if params_str:
            _safe_call(
                _LIB.LGBM_BoosterResetParameter(
                    self._handle,
                    _c_str(params_str),
                )
            )
        self.params.update(params)
        return self

    def update(
        self,
        train_set: Optional[Dataset] = None,
        fobj: Optional[_LGBM_CustomObjectiveFunction] = None,
    ) -> bool:
        """Update Booster for one iteration.

        Parameters
        ----------
        train_set : Dataset or None, optional (default=None)
            Training data.
            If None, last training data is used.
        fobj : callable or None, optional (default=None)
            Customized objective function.
            Should accept two parameters: preds, train_data,
            and return (grad, hess).

                preds : numpy 1-D array or numpy 2-D array (for multi-class task)
                    The predicted values.
                    Predicted values are returned before any transformation,
                    e.g. they are raw margin instead of probability of positive class for binary task.
                train_data : Dataset
                    The training dataset.
                grad : numpy 1-D array or numpy 2-D array (for multi-class task)
                    The value of the first order derivative (gradient) of the loss
                    with respect to the elements of preds for each sample point.
                hess : numpy 1-D array or numpy 2-D array (for multi-class task)
                    The value of the second order derivative (Hessian) of the loss
                    with respect to the elements of preds for each sample point.

            For multi-class task, preds are numpy 2-D array of shape = [n_samples, n_classes],
            and grad and hess should be returned in the same format.

        Returns
        -------
        is_finished : bool
            Whether the update was successfully finished.
        """
        # need to reset the training data
        if train_set is None and self.train_set_version != self.train_set.version:
            train_set = self.train_set
            is_the_same_train_set = False
        else:
            is_the_same_train_set = train_set is self.train_set and self.train_set_version == train_set.version
        if train_set is not None and not is_the_same_train_set:
            if not isinstance(train_set, Dataset):
                raise TypeError(f"Training data should be Dataset instance, met {type(train_set).__name__}")
            if train_set._predictor is not self.__init_predictor:
                raise LightGBMError("Replace training data failed, you should use the same predictor for these data")
            self.train_set = train_set
            _safe_call(
                _LIB.LGBM_BoosterResetTrainingData(
                    self._handle,
                    self.train_set.construct()._handle,
                )
            )
            self.__inner_predict_buffer[0] = None
            self.train_set_version = self.train_set.version
        is_finished = ctypes.c_int(0)
        if fobj is None:
            if self.__set_objective_to_none:
                raise LightGBMError("Cannot update due to null objective function.")
            _safe_call(
                _LIB.LGBM_BoosterUpdateOneIter(
                    self._handle,
                    ctypes.byref(is_finished),
                )
            )
            self.__is_predicted_cur_iter = [False for _ in range(self.__num_dataset)]
            return is_finished.value == 1
        else:
            if not self.__set_objective_to_none:
                self.reset_parameter({"objective": "none"}).__set_objective_to_none = True
            grad, hess = fobj(self.__inner_predict(0), self.train_set)
            return self.__boost(grad, hess)

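    # Illustrative sketch (not part of the library): a manual training loop with a
    # custom squared-error objective following the (preds, train_data) -> (grad, hess)
    # contract documented above:
    #
    #   def l2_obj(preds, train_data):
    #       grad = preds - train_data.get_label()
    #       hess = np.ones_like(preds)
    #       return grad, hess
    #
    #   for _ in range(10):
    #       bst.update(fobj=l2_obj)
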
    def __boost(
        self,
        grad: np.ndarray,
        hess: np.ndarray,
    ) -> bool:
        """Boost Booster for one iteration with customized gradient statistics.

        .. note::

            Score is returned before any transformation,
            e.g. it is raw margin instead of probability of positive class for binary task.
            For multi-class task, the score is a numpy 2-D array of shape = [n_samples, n_classes],
            and grad and hess should be returned in the same format.

        Parameters
        ----------
        grad : numpy 1-D array or numpy 2-D array (for multi-class task)
            The value of the first order derivative (gradient) of the loss
            with respect to the elements of score for each sample point.
        hess : numpy 1-D array or numpy 2-D array (for multi-class task)
            The value of the second order derivative (Hessian) of the loss
            with respect to the elements of score for each sample point.

        Returns
        -------
        is_finished : bool
            Whether the boost was successfully finished.
        """
        if self.__num_class > 1:
            grad = grad.ravel(order="F")
            hess = hess.ravel(order="F")
        grad = _list_to_1d_numpy(grad, dtype=np.float32, name="gradient")
        hess = _list_to_1d_numpy(hess, dtype=np.float32, name="hessian")
        assert grad.flags.c_contiguous
        assert hess.flags.c_contiguous
        if len(grad) != len(hess):
            raise ValueError(f"Lengths of gradient ({len(grad)}) and Hessian ({len(hess)}) don't match")
        num_train_data = self.train_set.num_data()
        if len(grad) != num_train_data * self.__num_class:
            raise ValueError(
                f"Lengths of gradient ({len(grad)}) and Hessian ({len(hess)}) "
                f"don't match training data length ({num_train_data}) * "
                f"number of models per one iteration ({self.__num_class})"
            )
        is_finished = ctypes.c_int(0)
        _safe_call(
            _LIB.LGBM_BoosterUpdateOneIterCustom(
                self._handle,
                grad.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
                hess.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
                ctypes.byref(is_finished),
            )
        )
        self.__is_predicted_cur_iter = [False for _ in range(self.__num_dataset)]
        return is_finished.value == 1

    def rollback_one_iter(self) -> "Booster":
        """Rollback one iteration.

        Returns
        -------
        self : Booster
            Booster with rolled back one iteration.
        """
        _safe_call(_LIB.LGBM_BoosterRollbackOneIter(self._handle))
        self.__is_predicted_cur_iter = [False for _ in range(self.__num_dataset)]
        return self

    def current_iteration(self) -> int:
        """Get the index of the current iteration.

        Returns
        -------
        cur_iter : int
            The index of the current iteration.
        """
        out_cur_iter = ctypes.c_int(0)
        _safe_call(
            _LIB.LGBM_BoosterGetCurrentIteration(
                self._handle,
                ctypes.byref(out_cur_iter),
            )
        )
        return out_cur_iter.value

    def num_model_per_iteration(self) -> int:
        """Get number of models per iteration.

        Returns
        -------
        model_per_iter : int
            The number of models per iteration.
        """
        model_per_iter = ctypes.c_int(0)
        _safe_call(
            _LIB.LGBM_BoosterNumModelPerIteration(
                self._handle,
                ctypes.byref(model_per_iter),
            )
        )
        return model_per_iter.value

    def num_trees(self) -> int:
        """Get number of weak sub-models.

        Returns
        -------
        num_trees : int
            The number of weak sub-models.
        """
        num_trees = ctypes.c_int(0)
        _safe_call(
            _LIB.LGBM_BoosterNumberOfTotalModel(
                self._handle,
                ctypes.byref(num_trees),
            )
        )
        return num_trees.value

    def upper_bound(self) -> float:
        """Get upper bound value of a model.

        Returns
        -------
        upper_bound : float
            Upper bound value of the model.
        """
        ret = ctypes.c_double(0)
        _safe_call(
            _LIB.LGBM_BoosterGetUpperBoundValue(
                self._handle,
                ctypes.byref(ret),
            )
        )
        return ret.value

    def lower_bound(self) -> float:
        """Get lower bound value of a model.

        Returns
        -------
        lower_bound : float
            Lower bound value of the model.
        """
        ret = ctypes.c_double(0)
        _safe_call(
            _LIB.LGBM_BoosterGetLowerBoundValue(
                self._handle,
                ctypes.byref(ret),
            )
        )
        return ret.value

    def eval(
        self,
        data: Dataset,
        name: str,
        feval: Optional[Union[_LGBM_CustomEvalFunction, List[_LGBM_CustomEvalFunction]]] = None,
    ) -> List[_LGBM_BoosterEvalMethodResultType]:
        """Evaluate for data.

        Parameters
        ----------
        data : Dataset
            Data for evaluation.
        name : str
            Name of the data.
        feval : callable, list of callable, or None, optional (default=None)
            Customized evaluation function.
            Each evaluation function should accept two parameters: preds, eval_data,
            and return (eval_name, eval_result, is_higher_better) or list of such tuples.

                preds : numpy 1-D array or numpy 2-D array (for multi-class task)
                    The predicted values.
                    For multi-class task, preds are numpy 2-D array of shape = [n_samples, n_classes].
                    If custom objective function is used, predicted values are returned before any transformation,
                    e.g. they are raw margin instead of probability of positive class for binary task in this case.
                eval_data : Dataset
                    A ``Dataset`` to evaluate.
                eval_name : str
                    The name of evaluation function (without whitespace).
                eval_result : float
                    The eval result.
                is_higher_better : bool
                    Is eval result higher better, e.g. AUC is ``is_higher_better``.

        Returns
        -------
        result : list
            List with (dataset_name, eval_name, eval_result, is_higher_better) tuples.
        """
        if not isinstance(data, Dataset):
            raise TypeError("Can only eval for Dataset instance")
        data_idx = -1
        if data is self.train_set:
            data_idx = 0
        else:
            for i in range(len(self.valid_sets)):
                if data is self.valid_sets[i]:
                    data_idx = i + 1
                    break
        # need to push new valid data
        if data_idx == -1:
            self.add_valid(data, name)
            data_idx = self.__num_dataset - 1

        return self.__inner_eval(name, data_idx, feval)

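    # Illustrative sketch (not part of the library): a custom metric following the
    # (preds, eval_data) -> (eval_name, eval_result, is_higher_better) contract:
    #
    #   def mae(preds, eval_data):
    #       return "mae", float(np.mean(np.abs(preds - eval_data.get_label()))), False
    #
    #   bst.eval(valid_set, "valid", feval=mae)
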
    def eval_train(
        self,
        feval: Optional[Union[_LGBM_CustomEvalFunction, List[_LGBM_CustomEvalFunction]]] = None,
    ) -> List[_LGBM_BoosterEvalMethodResultType]:
        """Evaluate for training data.

        Parameters
        ----------
        feval : callable, list of callable, or None, optional (default=None)
            Customized evaluation function.
            Each evaluation function should accept two parameters: preds, eval_data,
            and return (eval_name, eval_result, is_higher_better) or list of such tuples.

                preds : numpy 1-D array or numpy 2-D array (for multi-class task)
                    The predicted values.
                    For multi-class task, preds are numpy 2-D array of shape = [n_samples, n_classes].
                    If custom objective function is used, predicted values are returned before any transformation,
                    e.g. they are raw margin instead of probability of positive class for binary task in this case.
                eval_data : Dataset
                    The training dataset.
                eval_name : str
                    The name of evaluation function (without whitespace).
                eval_result : float
                    The eval result.
                is_higher_better : bool
                    Is eval result higher better, e.g. AUC is ``is_higher_better``.

        Returns
        -------
        result : list
            List with (train_dataset_name, eval_name, eval_result, is_higher_better) tuples.
        """
        return self.__inner_eval(self._train_data_name, 0, feval)

    def eval_valid(
        self,
        feval: Optional[Union[_LGBM_CustomEvalFunction, List[_LGBM_CustomEvalFunction]]] = None,
    ) -> List[_LGBM_BoosterEvalMethodResultType]:
        """Evaluate for validation data.

        Parameters
        ----------
        feval : callable, list of callable, or None, optional (default=None)
            Customized evaluation function.
            Each evaluation function should accept two parameters: preds, eval_data,
            and return (eval_name, eval_result, is_higher_better) or list of such tuples.

                preds : numpy 1-D array or numpy 2-D array (for multi-class task)
                    The predicted values.
                    For multi-class task, preds are numpy 2-D array of shape = [n_samples, n_classes].
                    If custom objective function is used, predicted values are returned before any transformation,
                    e.g. they are raw margin instead of probability of positive class for binary task in this case.
                eval_data : Dataset
                    The validation dataset.
                eval_name : str
                    The name of evaluation function (without whitespace).
                eval_result : float
                    The eval result.
                is_higher_better : bool
                    Is eval result higher better, e.g. AUC is ``is_higher_better``.

        Returns
        -------
        result : list
            List with (validation_dataset_name, eval_name, eval_result, is_higher_better) tuples.
        """
        return [
            item
            for i in range(1, self.__num_dataset)
            for item in self.__inner_eval(self.name_valid_sets[i - 1], i, feval)
        ]

    def save_model(
        self,
        filename: Union[str, Path],
        num_iteration: Optional[int] = None,
        start_iteration: int = 0,
        importance_type: str = "split",
    ) -> "Booster":
        """Save Booster to file.

        Parameters
        ----------
        filename : str or pathlib.Path
            Filename to save Booster.
        num_iteration : int or None, optional (default=None)
            Index of the iteration that should be saved.
            If None, if the best iteration exists, it is saved; otherwise, all iterations are saved.
            If <= 0, all iterations are saved.
        start_iteration : int, optional (default=0)
            Start index of the iteration that should be saved.
        importance_type : str, optional (default="split")
            What type of feature importance should be saved.
            If "split", result contains numbers of times the feature is used in a model.
            If "gain", result contains total gains of splits which use the feature.

        Returns
        -------
        self : Booster
            Returns self.
        """
        if num_iteration is None:
            num_iteration = self.best_iteration
        importance_type_int = _FEATURE_IMPORTANCE_TYPE_MAPPER[importance_type]
        _safe_call(
            _LIB.LGBM_BoosterSaveModel(
                self._handle,
                ctypes.c_int(start_iteration),
                ctypes.c_int(num_iteration),
                ctypes.c_int(importance_type_int),
                _c_str(str(filename)),
            )
        )
        _dump_pandas_categorical(self.pandas_categorical, filename)
        return self

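    # Illustrative sketch (assumed file name, not part of the library): persisting a
    # model and loading it back for prediction:
    #
    #   bst.save_model("model.txt")                   # best_iteration is used if it exists
    #   bst_loaded = Booster(model_file="model.txt")
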
    def shuffle_models(
        self,
        start_iteration: int = 0,
        end_iteration: int = -1,
    ) -> "Booster":
        """Shuffle models.

        Parameters
        ----------
        start_iteration : int, optional (default=0)
            The first iteration that will be shuffled.
        end_iteration : int, optional (default=-1)
            The last iteration that will be shuffled.
            If <= 0, the last available iteration is used.

        Returns
        -------
        self : Booster
            Booster with shuffled models.
        """
        _safe_call(
            _LIB.LGBM_BoosterShuffleModels(
                self._handle,
                ctypes.c_int(start_iteration),
                ctypes.c_int(end_iteration),
            )
        )
        return self

    def model_from_string(self, model_str: str) -> "Booster":
        """Load Booster from a string.

        Parameters
        ----------
        model_str : str
            Model will be loaded from this string.

        Returns
        -------
        self : Booster
            Loaded Booster object.
        """
        # ensure that the existing Booster is freed before replacing it
        # with a new one created from the given string
        _safe_call(_LIB.LGBM_BoosterFree(self._handle))
        self._free_buffer()
        self._handle = ctypes.c_void_p()
        out_num_iterations = ctypes.c_int(0)
        _safe_call(
            _LIB.LGBM_BoosterLoadModelFromString(
                _c_str(model_str),
                ctypes.byref(out_num_iterations),
                ctypes.byref(self._handle),
            )
        )
        out_num_class = ctypes.c_int(0)
        _safe_call(
            _LIB.LGBM_BoosterGetNumClasses(
                self._handle,
                ctypes.byref(out_num_class),
            )
        )
        self.__num_class = out_num_class.value
        self.pandas_categorical = _load_pandas_categorical(model_str=model_str)
        return self

    def model_to_string(
        self,
        num_iteration: Optional[int] = None,
        start_iteration: int = 0,
        importance_type: str = "split",
    ) -> str:
        """Save Booster to string.

        Parameters
        ----------
        num_iteration : int or None, optional (default=None)
            Index of the iteration that should be saved.
            If None, if the best iteration exists, it is saved; otherwise, all iterations are saved.
            If <= 0, all iterations are saved.
        start_iteration : int, optional (default=0)
            Start index of the iteration that should be saved.
        importance_type : str, optional (default="split")
            What type of feature importance should be saved.
            If "split", result contains numbers of times the feature is used in a model.
            If "gain", result contains total gains of splits which use the feature.

        Returns
        -------
        str_repr : str
            String representation of Booster.
        """
        if num_iteration is None:
            num_iteration = self.best_iteration
        importance_type_int = _FEATURE_IMPORTANCE_TYPE_MAPPER[importance_type]
        buffer_len = 1 << 20
        tmp_out_len = ctypes.c_int64(0)
        string_buffer = ctypes.create_string_buffer(buffer_len)
        ptr_string_buffer = ctypes.c_char_p(ctypes.addressof(string_buffer))
        _safe_call(
            _LIB.LGBM_BoosterSaveModelToString(
                self._handle,
                ctypes.c_int(start_iteration),
                ctypes.c_int(num_iteration),
                ctypes.c_int(importance_type_int),
                ctypes.c_int64(buffer_len),
                ctypes.byref(tmp_out_len),
                ptr_string_buffer,
            )
        )
        actual_len = tmp_out_len.value
        # if buffer length is not long enough, re-allocate a buffer
        if actual_len > buffer_len:
            string_buffer = ctypes.create_string_buffer(actual_len)
            ptr_string_buffer = ctypes.c_char_p(ctypes.addressof(string_buffer))
            _safe_call(
                _LIB.LGBM_BoosterSaveModelToString(
                    self._handle,
                    ctypes.c_int(start_iteration),
                    ctypes.c_int(num_iteration),
                    ctypes.c_int(importance_type_int),
                    ctypes.c_int64(actual_len),
                    ctypes.byref(tmp_out_len),
                    ptr_string_buffer,
                )
            )
        ret = string_buffer.value.decode("utf-8")
        ret += _dump_pandas_categorical(self.pandas_categorical)
        return ret

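    # Illustrative sketch (not part of the library): model_to_string() and
    # model_from_string() form an in-memory round trip, e.g. for shipping a model
    # over the network without touching disk:
    #
    #   payload = bst.model_to_string()
    #   bst2 = Booster(model_str=payload)
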
    def dump_model(
        self,
        num_iteration: Optional[int] = None,
        start_iteration: int = 0,
        importance_type: str = "split",
        object_hook: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
    ) -> Dict[str, Any]:
        """Dump Booster to JSON format.

        Parameters
        ----------
        num_iteration : int or None, optional (default=None)
            Index of the iteration that should be dumped.
            If None, if the best iteration exists, it is dumped; otherwise, all iterations are dumped.
            If <= 0, all iterations are dumped.
        start_iteration : int, optional (default=0)
            Start index of the iteration that should be dumped.
        importance_type : str, optional (default="split")
            What type of feature importance should be dumped.
            If "split", result contains numbers of times the feature is used in a model.
            If "gain", result contains total gains of splits which use the feature.
        object_hook : callable or None, optional (default=None)
            If not None, ``object_hook`` is a function called while parsing the JSON
            string returned by the C API. It may be used to alter the JSON or to store
            specific values while the structure is being built, which avoids walking
            through the structure again and saves a significant amount of time when
            the number of trees is huge.
            Signature is ``def object_hook(node: dict) -> dict``.
            None is equivalent to ``lambda node: node``.
            See the documentation of ``json.loads()`` for further details.

        Returns
        -------
        json_repr : dict
            JSON format of Booster.
        """
        if num_iteration is None:
            num_iteration = self.best_iteration
        importance_type_int = _FEATURE_IMPORTANCE_TYPE_MAPPER[importance_type]
        buffer_len = 1 << 20
        tmp_out_len = ctypes.c_int64(0)
        string_buffer = ctypes.create_string_buffer(buffer_len)
        ptr_string_buffer = ctypes.c_char_p(ctypes.addressof(string_buffer))
        _safe_call(
            _LIB.LGBM_BoosterDumpModel(
                self._handle,
                ctypes.c_int(start_iteration),
                ctypes.c_int(num_iteration),
                ctypes.c_int(importance_type_int),
                ctypes.c_int64(buffer_len),
                ctypes.byref(tmp_out_len),
                ptr_string_buffer,
            )
        )
        actual_len = tmp_out_len.value
        # if buffer length is not long enough, reallocate a buffer
        if actual_len > buffer_len:
            string_buffer = ctypes.create_string_buffer(actual_len)
            ptr_string_buffer = ctypes.c_char_p(ctypes.addressof(string_buffer))
            _safe_call(
                _LIB.LGBM_BoosterDumpModel(
                    self._handle,
                    ctypes.c_int(start_iteration),
                    ctypes.c_int(num_iteration),
                    ctypes.c_int(importance_type_int),
                    ctypes.c_int64(actual_len),
                    ctypes.byref(tmp_out_len),
                    ptr_string_buffer,
                )
            )
        ret = json.loads(string_buffer.value.decode("utf-8"), object_hook=object_hook)
        ret["pandas_categorical"] = json.loads(
            json.dumps(
                self.pandas_categorical,
                default=_json_default_with_numpy,
            )
        )
        return ret

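    # Illustrative sketch (not part of the library): using object_hook to collect
    # every split feature in a single parsing pass instead of re-walking the dump:
    #
    #   used_features = set()
    #
    #   def hook(node):
    #       if "split_feature" in node:
    #           used_features.add(node["split_feature"])
    #       return node
    #
    #   model_json = bst.dump_model(object_hook=hook)
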
    def predict(
        self,
        data: _LGBM_PredictDataType,
        start_iteration: int = 0,
        num_iteration: Optional[int] = None,
        raw_score: bool = False,
        pred_leaf: bool = False,
        pred_contrib: bool = False,
        data_has_header: bool = False,
        validate_features: bool = False,
        **kwargs: Any,
    ) -> Union[np.ndarray, scipy.sparse.spmatrix, List[scipy.sparse.spmatrix]]:
        """Make a prediction.

        Parameters
        ----------
        data : str, pathlib.Path, numpy array, pandas DataFrame, pyarrow Table, H2O DataTable's Frame or scipy.sparse
            Data source for prediction.
            If str or pathlib.Path, it represents the path to a text file (CSV, TSV, or LibSVM).
        start_iteration : int, optional (default=0)
            Start index of the iteration to predict.
            If <= 0, starts from the first iteration.
        num_iteration : int or None, optional (default=None)
            Total number of iterations used in the prediction.
            If None, if the best iteration exists and start_iteration <= 0, the best iteration is used;
            otherwise, all iterations from ``start_iteration`` are used (no limits).
            If <= 0, all iterations from ``start_iteration`` are used (no limits).
        raw_score : bool, optional (default=False)
            Whether to predict raw scores.
        pred_leaf : bool, optional (default=False)
            Whether to predict leaf index.
        pred_contrib : bool, optional (default=False)
            Whether to predict feature contributions.

            .. note::

                If you want to get more explanations for your model's predictions using SHAP values,
                like SHAP interaction values,
                you can install the shap package (https://github.com/slundberg/shap).
                Note that unlike the shap package, with ``pred_contrib`` we return a matrix with an extra
                column, where the last column is the expected value.

        data_has_header : bool, optional (default=False)
            Whether the data has header.
            Used only if data is str.
        validate_features : bool, optional (default=False)
            If True, ensure that the features used to predict match the ones used to train.
            Used only if data is pandas DataFrame.
        **kwargs
            Other parameters for the prediction.

        Returns
        -------
        result : numpy array, scipy.sparse or list of scipy.sparse
            Prediction result.
            Can be sparse or a list of sparse objects (each element represents predictions for one class) for feature contributions (when ``pred_contrib=True``).
        """
        predictor = _InnerPredictor.from_booster(
            booster=self,
            pred_parameter=deepcopy(kwargs),
        )
        if num_iteration is None:
            if start_iteration <= 0:
                num_iteration = self.best_iteration
            else:
                num_iteration = -1
        return predictor.predict(
            data=data,
            start_iteration=start_iteration,
            num_iteration=num_iteration,
            raw_score=raw_score,
            pred_leaf=pred_leaf,
            pred_contrib=pred_contrib,
            data_has_header=data_has_header,
            validate_features=validate_features,
        )

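    # Illustrative sketch (binary/regression case, not part of the library): with
    # pred_contrib=True the output gains one extra column holding the expected value,
    # and each row of contributions sums to the raw score for that sample:
    #
    #   contribs = bst.predict(X, pred_contrib=True)   # shape: (n_samples, n_features + 1)
    #   raw = bst.predict(X, raw_score=True)
    #   np.testing.assert_allclose(contribs.sum(axis=1), raw)
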
    def refit(
        self,
        data: _LGBM_TrainDataType,
        label: _LGBM_LabelType,
        decay_rate: float = 0.9,
        reference: Optional[Dataset] = None,
        weight: Optional[_LGBM_WeightType] = None,
        group: Optional[_LGBM_GroupType] = None,
        init_score: Optional[_LGBM_InitScoreType] = None,
        feature_name: _LGBM_FeatureNameConfiguration = "auto",
        categorical_feature: _LGBM_CategoricalFeatureConfiguration = "auto",
        dataset_params: Optional[Dict[str, Any]] = None,
        free_raw_data: bool = True,
        validate_features: bool = False,
        **kwargs: Any,
    ) -> "Booster":
        """Refit the existing Booster by new data.

        Parameters
        ----------
        data : str, pathlib.Path, numpy array, pandas DataFrame, H2O DataTable's Frame, scipy.sparse, Sequence, list of Sequence or list of numpy array
            Data source for refit.
            If str or pathlib.Path, it represents the path to a text file (CSV, TSV, or LibSVM).
        label : list, numpy 1-D array, pandas Series / one-column DataFrame, pyarrow Array or pyarrow ChunkedArray
            Label for refit.
        decay_rate : float, optional (default=0.9)
            Decay rate of refit,
            will use ``leaf_output = decay_rate * old_leaf_output + (1.0 - decay_rate) * new_leaf_output`` to refit trees.
        reference : Dataset or None, optional (default=None)
            Reference for ``data``.

            .. versionadded:: 4.0.0

        weight : list, numpy 1-D array, pandas Series, pyarrow Array, pyarrow ChunkedArray or None, optional (default=None)
            Weight for each ``data`` instance. Weights should be non-negative.

            .. versionadded:: 4.0.0

        group : list, numpy 1-D array, pandas Series, pyarrow Array, pyarrow ChunkedArray or None, optional (default=None)
            Group/query size for ``data``.
            Only used in the learning-to-rank task.
            sum(group) = n_samples.
            For example, if you have a 100-document dataset with ``group = [10, 20, 40, 10, 10, 10]``, that means that you have 6 groups,
            where the first 10 records are in the first group, records 11-30 are in the second group, records 31-70 are in the third group, etc.

            .. versionadded:: 4.0.0

        init_score : list, list of lists (for multi-class task), numpy array, pandas Series, pandas DataFrame (for multi-class task), pyarrow Array, pyarrow ChunkedArray, pyarrow Table (for multi-class task) or None, optional (default=None)
            Init score for ``data``.

            .. versionadded:: 4.0.0

        feature_name : list of str, or 'auto', optional (default="auto")
            Feature names for ``data``.
            If 'auto' and data is pandas DataFrame, data columns names are used.

            .. versionadded:: 4.0.0

        categorical_feature : list of str or int, or 'auto', optional (default="auto")
            Categorical features for ``data``.
            If list of int, interpreted as indices.
            If list of str, interpreted as feature names (need to specify ``feature_name`` as well).
            If 'auto' and data is pandas DataFrame, pandas unordered categorical columns are used.
            All values in categorical features will be cast to int32 and thus should be less than int32 max value (2147483647).
            Large values could be memory consuming. Consider using consecutive integers starting from zero.
            All negative values in categorical features will be treated as missing values.
            The output cannot be monotonically constrained with respect to a categorical feature.
            Floating point numbers in categorical features will be rounded towards 0.

            .. versionadded:: 4.0.0

        dataset_params : dict or None, optional (default=None)
            Other parameters for Dataset ``data``.

            .. versionadded:: 4.0.0

        free_raw_data : bool, optional (default=True)
            If True, raw data is freed after constructing inner Dataset for ``data``.

            .. versionadded:: 4.0.0

        validate_features : bool, optional (default=False)
            If True, ensure that the features used to refit the model match the original ones.
            Used only if data is pandas DataFrame.

            .. versionadded:: 4.0.0

        **kwargs
            Other parameters for refit.
            These parameters will be passed to ``predict`` method.

        Returns
        -------
        result : Booster
            Refitted Booster.
        """
        if self.__set_objective_to_none:
            raise LightGBMError("Cannot refit due to null objective function.")
        if dataset_params is None:
            dataset_params = {}
        predictor = _InnerPredictor.from_booster(booster=self, pred_parameter=deepcopy(kwargs))
        leaf_preds: np.ndarray = predictor.predict(  # type: ignore[assignment]
            data=data,
            start_iteration=-1,
            pred_leaf=True,
            validate_features=validate_features,
        )
        nrow, ncol = leaf_preds.shape
        out_is_linear = ctypes.c_int(0)
        _safe_call(
            _LIB.LGBM_BoosterGetLinear(
                self._handle,
                ctypes.byref(out_is_linear),
            )
        )
        new_params = _choose_param_value(
            main_param_name="linear_tree",
            params=self.params,
            default_value=None,
        )
        new_params["linear_tree"] = bool(out_is_linear.value)
        new_params.update(dataset_params)
        train_set = Dataset(
            data=data,
            label=label,
            reference=reference,
            weight=weight,
            group=group,
            init_score=init_score,
            feature_name=feature_name,
            categorical_feature=categorical_feature,
            params=new_params,
            free_raw_data=free_raw_data,
        )
        new_params["refit_decay_rate"] = decay_rate
        new_booster = Booster(new_params, train_set)
        # Copy models
        _safe_call(
            _LIB.LGBM_BoosterMerge(
                new_booster._handle,
                predictor._handle,
            )
        )
        leaf_preds = leaf_preds.reshape(-1)
        ptr_data, _, _ = _c_int_array(leaf_preds)
        _safe_call(
            _LIB.LGBM_BoosterRefit(
                new_booster._handle,
                ptr_data,
                ctypes.c_int32(nrow),
                ctypes.c_int32(ncol),
            )
        )
        new_booster._network = self._network
        return new_booster
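
    # A minimal usage sketch for ``refit`` (hypothetical data shapes and
    # parameter values; assumes ``bst`` is an already-trained Booster whose
    # objective is not None):
    #
    #     import numpy as np
    #     rng = np.random.default_rng(0)
    #     new_X = rng.random((100, bst.num_feature()))
    #     new_y = rng.integers(0, 2, size=100)
    #     # Tree structures are kept; only leaf outputs are re-estimated on
    #     # the new data and blended with the old ones via ``decay_rate``.
    #     refitted = bst.refit(data=new_X, label=new_y, decay_rate=0.9)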

    def get_leaf_output(self, tree_id: int, leaf_id: int) -> float:
        """Get the output of a leaf.

        Parameters
        ----------
        tree_id : int
            The index of the tree.
        leaf_id : int
            The index of the leaf in the tree.

        Returns
        -------
        result : float
            The output of the leaf.
        """
        ret = ctypes.c_double(0)
        _safe_call(
            _LIB.LGBM_BoosterGetLeafValue(
                self._handle,
                ctypes.c_int(tree_id),
                ctypes.c_int(leaf_id),
                ctypes.byref(ret),
            )
        )
        return ret.value
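
    # Sketch: reading a single leaf's raw output (hypothetical indices;
    # assumes ``bst`` is a trained Booster):
    #
    #     # Contribution of leaf 3 of the first tree, on the raw-score scale.
    #     value = bst.get_leaf_output(tree_id=0, leaf_id=3)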

    def set_leaf_output(
        self,
        tree_id: int,
        leaf_id: int,
        value: float,
    ) -> "Booster":
        """Set the output of a leaf.

        .. versionadded:: 4.0.0

        Parameters
        ----------
        tree_id : int
            The index of the tree.
        leaf_id : int
            The index of the leaf in the tree.
        value : float
            Value to set as the output of the leaf.

        Returns
        -------
        self : Booster
            Booster with the leaf output set.
        """
        _safe_call(
            _LIB.LGBM_BoosterSetLeafValue(
                self._handle,
                ctypes.c_int(tree_id),
                ctypes.c_int(leaf_id),
                ctypes.c_double(value),
            )
        )
        return self
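
    # Sketch: halving a single leaf's contribution via the paired getter and
    # setter (hypothetical indices; assumes ``bst`` is a trained Booster):
    #
    #     old_value = bst.get_leaf_output(tree_id=0, leaf_id=3)
    #     bst.set_leaf_output(tree_id=0, leaf_id=3, value=old_value * 0.5)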

    def num_feature(self) -> int:
        """Get number of features.

        Returns
        -------
        num_feature : int
            The number of features.
        """
        out_num_feature = ctypes.c_int(0)
        _safe_call(
            _LIB.LGBM_BoosterGetNumFeature(
                self._handle,
                ctypes.byref(out_num_feature),
            )
        )
        return out_num_feature.value

    def feature_name(self) -> List[str]:
        """Get names of features.

        Returns
        -------
        result : list of str
            List with names of features.
        """
        num_feature = self.num_feature()
        # Get names of features
        tmp_out_len = ctypes.c_int(0)
        reserved_string_buffer_size = 255
        required_string_buffer_size = ctypes.c_size_t(0)
        string_buffers = [ctypes.create_string_buffer(reserved_string_buffer_size) for _ in range(num_feature)]
        ptr_string_buffers = (ctypes.c_char_p * num_feature)(*map(ctypes.addressof, string_buffers))  # type: ignore[misc]
        _safe_call(
            _LIB.LGBM_BoosterGetFeatureNames(
                self._handle,
                ctypes.c_int(num_feature),
                ctypes.byref(tmp_out_len),
                ctypes.c_size_t(reserved_string_buffer_size),
                ctypes.byref(required_string_buffer_size),
                ptr_string_buffers,
            )
        )
        if num_feature != tmp_out_len.value:
            raise ValueError("Length of feature names doesn't equal with num_feature")
        actual_string_buffer_size = required_string_buffer_size.value
        # if buffer length is not long enough, reallocate buffers
        if reserved_string_buffer_size < actual_string_buffer_size:
            string_buffers = [ctypes.create_string_buffer(actual_string_buffer_size) for _ in range(num_feature)]
            ptr_string_buffers = (ctypes.c_char_p * num_feature)(*map(ctypes.addressof, string_buffers))  # type: ignore[misc]
            _safe_call(
                _LIB.LGBM_BoosterGetFeatureNames(
                    self._handle,
                    ctypes.c_int(num_feature),
                    ctypes.byref(tmp_out_len),
                    ctypes.c_size_t(actual_string_buffer_size),
                    ctypes.byref(required_string_buffer_size),
                    ptr_string_buffers,
                )
            )
        return [string_buffers[i].value.decode("utf-8") for i in range(num_feature)]
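
    # Sketch: ``feature_name`` always returns one name per feature, so it
    # pairs naturally with ``num_feature`` (assumes ``bst`` is a trained
    # Booster):
    #
    #     names = bst.feature_name()
    #     assert len(names) == bst.num_feature()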

    def feature_importance(
        self,
        importance_type: str = "split",
        iteration: Optional[int] = None,
    ) -> np.ndarray:
        """Get feature importances.

        Parameters
        ----------
        importance_type : str, optional (default="split")
            How the importance is calculated.
            If "split", result contains numbers of times the feature is used in a model.
            If "gain", result contains total gains of splits which use the feature.
        iteration : int or None, optional (default=None)
            Limit number of iterations in the feature importance calculation.
            If None, if the best iteration exists, it is used; otherwise, all trees are used.
            If <= 0, all trees are used (no limits).

        Returns
        -------
        result : numpy array
            Array with feature importances.
        """
        if iteration is None:
            iteration = self.best_iteration
        importance_type_int = _FEATURE_IMPORTANCE_TYPE_MAPPER[importance_type]
        result = np.empty(self.num_feature(), dtype=np.float64)
        _safe_call(
            _LIB.LGBM_BoosterFeatureImportance(
                self._handle,
                ctypes.c_int(iteration),
                ctypes.c_int(importance_type_int),
                result.ctypes.data_as(ctypes.POINTER(ctypes.c_double)),
            )
        )
        if importance_type_int == _C_API_FEATURE_IMPORTANCE_SPLIT:
            return result.astype(np.int32)
        else:
            return result
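
    # Sketch: ranking the most important features by total split gain
    # (assumes ``bst`` is a trained Booster):
    #
    #     import numpy as np
    #     gains = bst.feature_importance(importance_type="gain")
    #     names = bst.feature_name()
    #     for idx in np.argsort(gains)[::-1][:5]:
    #         print(names[idx], gains[idx])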

    def get_split_value_histogram(
        self,
        feature: Union[int, str],
        bins: Optional[Union[int, str]] = None,
        xgboost_style: bool = False,
    ) -> Union[Tuple[np.ndarray, np.ndarray], np.ndarray, pd_DataFrame]:
        """Get split value histogram for the specified feature.

        Parameters
        ----------
        feature : int or str
            The feature name or index the histogram is calculated for.
            If int, interpreted as index.
            If str, interpreted as name.

            .. warning::

                Categorical features are not supported.

        bins : int, str or None, optional (default=None)
            The maximum number of bins.
            If None, or if int and greater than the number of unique split values and ``xgboost_style=True``,
            the number of bins equals the number of unique split values.
            If str, it should be one from the list of the supported values by ``numpy.histogram()`` function.
        xgboost_style : bool, optional (default=False)
            Whether the returned result should be in the same form as it is in XGBoost.
            If False, the returned value is tuple of 2 numpy arrays as it is in ``numpy.histogram()`` function.
            If True, the returned value is matrix, in which the first column is the right edges of non-empty bins
            and the second one is the histogram values.

        Returns
        -------
        result_tuple : tuple of 2 numpy arrays
            If ``xgboost_style=False``, the values of the histogram of used splitting values for the specified feature
            and the bin edges.
        result_array_like : numpy array or pandas DataFrame (if pandas is installed)
            If ``xgboost_style=True``, the histogram of used splitting values for the specified feature.
        """

        def add(root: Dict[str, Any]) -> None:
            """Recursively add thresholds."""
            if "split_index" in root:  # non-leaf
                if feature_names is not None and isinstance(feature, str):
                    split_feature = feature_names[root["split_feature"]]
                else:
                    split_feature = root["split_feature"]
                if split_feature == feature:
                    if isinstance(root["threshold"], str):
                        raise LightGBMError("Cannot compute split value histogram for the categorical feature")
                    else:
                        values.append(root["threshold"])
                add(root["left_child"])
                add(root["right_child"])

        model = self.dump_model()
        feature_names = model.get("feature_names")
        tree_infos = model["tree_info"]
        values: List[float] = []
        for tree_info in tree_infos:
            add(tree_info["tree_structure"])

        if bins is None or (isinstance(bins, int) and xgboost_style):
            n_unique = len(np.unique(values))
            bins = max(min(n_unique, bins) if bins is not None else n_unique, 1)
        hist, bin_edges = np.histogram(values, bins=bins)
        if xgboost_style:
            ret = np.column_stack((bin_edges[1:], hist))
            ret = ret[ret[:, 1] > 0]
            if PANDAS_INSTALLED:
                return pd_DataFrame(ret, columns=["SplitValue", "Count"])
            else:
                return ret
        else:
            return hist, bin_edges
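
    # Sketch: inspecting where the model splits on one feature (hypothetical
    # feature name; assumes ``bst`` is a trained Booster):
    #
    #     hist, bin_edges = bst.get_split_value_histogram(feature="f0", bins=10)
    #     # hist[i] counts split thresholds in [bin_edges[i], bin_edges[i + 1]).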

    def __inner_eval(
        self,
        data_name: str,
        data_idx: int,
        feval: Optional[Union[_LGBM_CustomEvalFunction, List[_LGBM_CustomEvalFunction]]],
    ) -> List[_LGBM_BoosterEvalMethodResultType]:
        """Evaluate training or validation data."""
        if data_idx >= self.__num_dataset:
            raise ValueError("Data_idx should be smaller than number of dataset")
        self.__get_eval_info()
        ret = []
        if self.__num_inner_eval > 0:
            result = np.empty(self.__num_inner_eval, dtype=np.float64)
            tmp_out_len = ctypes.c_int(0)
            _safe_call(
                _LIB.LGBM_BoosterGetEval(
                    self._handle,
                    ctypes.c_int(data_idx),
                    ctypes.byref(tmp_out_len),
                    result.ctypes.data_as(ctypes.POINTER(ctypes.c_double)),
                )
            )
            if tmp_out_len.value != self.__num_inner_eval:
                raise ValueError("Wrong length of eval results")
            for i in range(self.__num_inner_eval):
                ret.append((data_name, self.__name_inner_eval[i], result[i], self.__higher_better_inner_eval[i]))
        if callable(feval):
            feval = [feval]
        if feval is not None:
            if data_idx == 0:
                cur_data = self.train_set
            else:
                cur_data = self.valid_sets[data_idx - 1]
            for eval_function in feval:
                if eval_function is None:
                    continue
                feval_ret = eval_function(self.__inner_predict(data_idx), cur_data)
                if isinstance(feval_ret, list):
                    for eval_name, val, is_higher_better in feval_ret:
                        ret.append((data_name, eval_name, val, is_higher_better))
                else:
                    eval_name, val, is_higher_better = feval_ret
                    ret.append((data_name, eval_name, val, is_higher_better))
        return ret
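
    # Sketch: a custom metric compatible with the ``feval`` handling above.
    # Each eval function receives the raw predictions and the corresponding
    # Dataset, and returns ``(eval_name, value, is_higher_better)``
    # (hypothetical metric, suitable for regression or binary objectives):
    #
    #     def mae(preds, eval_data):
    #         labels = eval_data.get_label()
    #         return "mae", float(np.mean(np.abs(preds - labels))), False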

    def __inner_predict(self, data_idx: int) -> np.ndarray:
        """Predict for training and validation dataset."""
        if data_idx >= self.__num_dataset:
            raise ValueError("Data_idx should be smaller than number of dataset")
        if self.__inner_predict_buffer[data_idx] is None:
            if data_idx == 0:
                n_preds = self.train_set.num_data() * self.__num_class
            else:
                n_preds = self.valid_sets[data_idx - 1].num_data() * self.__num_class
            self.__inner_predict_buffer[data_idx] = np.empty(n_preds, dtype=np.float64)
        # avoid predicting multiple times within one iteration
        if not self.__is_predicted_cur_iter[data_idx]:
            tmp_out_len = ctypes.c_int64(0)
            data_ptr = self.__inner_predict_buffer[data_idx].ctypes.data_as(ctypes.POINTER(ctypes.c_double))  # type: ignore[union-attr]
            _safe_call(
                _LIB.LGBM_BoosterGetPredict(
                    self._handle,
                    ctypes.c_int(data_idx),
                    ctypes.byref(tmp_out_len),
                    data_ptr,
                )
            )
            if tmp_out_len.value != len(self.__inner_predict_buffer[data_idx]):  # type: ignore[arg-type]
                raise ValueError(f"Wrong length of predict results for data {data_idx}")
            self.__is_predicted_cur_iter[data_idx] = True
        result: np.ndarray = self.__inner_predict_buffer[data_idx]  # type: ignore[assignment]
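        # The flat buffer stores all predictions for class 0 first, then
        # class 1, and so on; ``order="F"`` regroups it into the
        # (num_data, num_class) shape callers expect.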
        if self.__num_class > 1:
            num_data = result.size // self.__num_class
            result = result.reshape(num_data, self.__num_class, order="F")
        return result

    def __get_eval_info(self) -> None:
        """Get inner evaluation count and names."""
        if self.__need_reload_eval_info:
            self.__need_reload_eval_info = False
            out_num_eval = ctypes.c_int(0)
            # Get number of inner evals
            _safe_call(
                _LIB.LGBM_BoosterGetEvalCounts(
                    self._handle,
                    ctypes.byref(out_num_eval),
                )
            )
            self.__num_inner_eval = out_num_eval.value
            if self.__num_inner_eval > 0:
                # Get names of eval metrics
                tmp_out_len = ctypes.c_int(0)
                reserved_string_buffer_size = 255
                required_string_buffer_size = ctypes.c_size_t(0)
                string_buffers = [
                    ctypes.create_string_buffer(reserved_string_buffer_size) for _ in range(self.__num_inner_eval)
                ]
                ptr_string_buffers = (ctypes.c_char_p * self.__num_inner_eval)(*map(ctypes.addressof, string_buffers))  # type: ignore[misc]
                _safe_call(
                    _LIB.LGBM_BoosterGetEvalNames(
                        self._handle,
                        ctypes.c_int(self.__num_inner_eval),
                        ctypes.byref(tmp_out_len),
                        ctypes.c_size_t(reserved_string_buffer_size),
                        ctypes.byref(required_string_buffer_size),
                        ptr_string_buffers,
                    )
                )
                if self.__num_inner_eval != tmp_out_len.value:
                    raise ValueError("Length of eval names doesn't equal with num_evals")
                actual_string_buffer_size = required_string_buffer_size.value
                # if buffer length is not long enough, reallocate buffers
                if reserved_string_buffer_size < actual_string_buffer_size:
                    string_buffers = [
                        ctypes.create_string_buffer(actual_string_buffer_size) for _ in range(self.__num_inner_eval)
                    ]
                    ptr_string_buffers = (ctypes.c_char_p * self.__num_inner_eval)(
                        *map(ctypes.addressof, string_buffers)
                    )  # type: ignore[misc]
                    _safe_call(
                        _LIB.LGBM_BoosterGetEvalNames(
                            self._handle,
                            ctypes.c_int(self.__num_inner_eval),
                            ctypes.byref(tmp_out_len),
                            ctypes.c_size_t(actual_string_buffer_size),
                            ctypes.byref(required_string_buffer_size),
                            ptr_string_buffers,
                        )
                    )
                self.__name_inner_eval = [string_buffers[i].value.decode("utf-8") for i in range(self.__num_inner_eval)]
                self.__higher_better_inner_eval = [
                    name.startswith(("auc", "ndcg@", "map@", "average_precision")) for name in self.__name_inner_eval
                ]