"git@developer.sourcefind.cn:tianlh/lightgbm-dcu.git" did not exist on "ceb9986192e19a6fc41d75a58b5f5f824fd2bbff"
basic.py 194 KB
Newer Older
wxchan's avatar
wxchan committed
1
# coding: utf-8
2
"""Wrapper for C API of LightGBM."""
3
import abc
wxchan's avatar
wxchan committed
4
import ctypes
5
import inspect
6
import json
wxchan's avatar
wxchan committed
7
import warnings
8
from collections import OrderedDict
9
from copy import deepcopy
10
from enum import Enum
11
from functools import wraps
12
from os import SEEK_END, environ
13
14
from os.path import getsize
from pathlib import Path
15
from tempfile import NamedTemporaryFile
16
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Union
wxchan's avatar
wxchan committed
17
18
19
20

import numpy as np
import scipy.sparse

21
from .compat import (PANDAS_INSTALLED, PYARROW_INSTALLED, arrow_cffi, arrow_is_floating, arrow_is_integer, concat,
22
23
                     dt_DataTable, pa_Array, pa_chunked_array, pa_ChunkedArray, pa_compute, pa_Table,
                     pd_CategoricalDtype, pd_DataFrame, pd_Series)
from .libpath import find_lib_path

if TYPE_CHECKING:
    from typing import Literal

    # typing.TypeGuard was only introduced in Python 3.10
    try:
        from typing import TypeGuard
    except ImportError:
        from typing_extensions import TypeGuard


__all__ = [
    'Booster',
    'Dataset',
    'LGBMDeprecationWarning',
    'LightGBMError',
    'register_logger',
    'Sequence',
]


_BoosterHandle = ctypes.c_void_p
_DatasetHandle = ctypes.c_void_p
_ctypes_int_ptr = Union[
    "ctypes._Pointer[ctypes.c_int32]",
    "ctypes._Pointer[ctypes.c_int64]"
]
_ctypes_int_array = Union[
    "ctypes.Array[ctypes._Pointer[ctypes.c_int32]]",
    "ctypes.Array[ctypes._Pointer[ctypes.c_int64]]"
]
_ctypes_float_ptr = Union[
    "ctypes._Pointer[ctypes.c_float]",
    "ctypes._Pointer[ctypes.c_double]"
]
_ctypes_float_array = Union[
    "ctypes.Array[ctypes._Pointer[ctypes.c_float]]",
    "ctypes.Array[ctypes._Pointer[ctypes.c_double]]"
]
_LGBM_EvalFunctionResultType = Tuple[str, float, bool]
_LGBM_BoosterBestScoreType = Dict[str, Dict[str, float]]
_LGBM_BoosterEvalMethodResultType = Tuple[str, str, float, bool]
_LGBM_BoosterEvalMethodResultWithStandardDeviationType = Tuple[str, str, float, bool, float]
_LGBM_CategoricalFeatureConfiguration = Union[List[str], List[int], "Literal['auto']"]
_LGBM_FeatureNameConfiguration = Union[List[str], "Literal['auto']"]
_LGBM_GroupType = Union[
    List[float],
    List[int],
    np.ndarray,
    pd_Series,
    pa_Array,
    pa_ChunkedArray,
]
_LGBM_PositionType = Union[
    np.ndarray,
    pd_Series
]
_LGBM_InitScoreType = Union[
    List[float],
    List[List[float]],
    np.ndarray,
    pd_Series,
    pd_DataFrame,
    pa_Table,
    pa_Array,
    pa_ChunkedArray,
]
_LGBM_TrainDataType = Union[
    str,
    Path,
    np.ndarray,
    pd_DataFrame,
    dt_DataTable,
    scipy.sparse.spmatrix,
    "Sequence",
    List["Sequence"],
    List[np.ndarray],
    pa_Table
]
_LGBM_LabelType = Union[
    List[float],
    List[int],
    np.ndarray,
    pd_Series,
    pd_DataFrame,
    pa_Array,
    pa_ChunkedArray,
]
_LGBM_PredictDataType = Union[
    str,
    Path,
    np.ndarray,
    pd_DataFrame,
    dt_DataTable,
    scipy.sparse.spmatrix,
    pa_Table,
]
_LGBM_WeightType = Union[
    List[float],
    List[int],
    np.ndarray,
    pd_Series,
    pa_Array,
    pa_ChunkedArray,
]
ZERO_THRESHOLD = 1e-35


def _is_zero(x: float) -> bool:
    return -ZERO_THRESHOLD <= x <= ZERO_THRESHOLD


def _get_sample_count(total_nrow: int, params: str) -> int:
    sample_cnt = ctypes.c_int(0)
    _safe_call(_LIB.LGBM_GetSampleCount(
        ctypes.c_int32(total_nrow),
        _c_str(params),
        ctypes.byref(sample_cnt),
    ))
    return sample_cnt.value


class _MissingType(Enum):
    NONE = 'None'
    NAN = 'NaN'
    ZERO = 'Zero'


class _DummyLogger:
    def info(self, msg: str) -> None:
        print(msg)  # noqa: T201

    def warning(self, msg: str) -> None:
        warnings.warn(msg, stacklevel=3)


_LOGGER: Any = _DummyLogger()
_INFO_METHOD_NAME = "info"
_WARNING_METHOD_NAME = "warning"


def _has_method(logger: Any, method_name: str) -> bool:
    return callable(getattr(logger, method_name, None))


def register_logger(
    logger: Any, info_method_name: str = "info", warning_method_name: str = "warning"
) -> None:
    """Register custom logger.

    Parameters
    ----------
    logger : Any
        Custom logger.
    info_method_name : str, optional (default="info")
        Method used to log info messages.
    warning_method_name : str, optional (default="warning")
        Method used to log warning messages.
    """
    if not _has_method(logger, info_method_name) or not _has_method(logger, warning_method_name):
        raise TypeError(
            f"Logger must provide '{info_method_name}' and '{warning_method_name}' method"
        )

    global _LOGGER, _INFO_METHOD_NAME, _WARNING_METHOD_NAME
    _LOGGER = logger
    _INFO_METHOD_NAME = info_method_name
    _WARNING_METHOD_NAME = warning_method_name


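# Illustrative usage (not executed here): any object exposing the configured
# ``info`` and ``warning`` methods can be registered, e.g. Python's standard
# ``logging`` module:
#
#     import logging
#     logging.basicConfig(level=logging.INFO)
#     register_logger(logging.getLogger("lightgbm"))
#
# After registration, _log_info() and _log_warning() below route messages
# through the registered logger instead of the default _DummyLogger.
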
def _normalize_native_string(func: Callable[[str], None]) -> Callable[[str], None]:
    """Join log messages from native library which come by chunks."""
    msg_normalized: List[str] = []

    @wraps(func)
    def wrapper(msg: str) -> None:
        nonlocal msg_normalized
        if msg.strip() == '':
            msg = ''.join(msg_normalized)
            msg_normalized = []
            return func(msg)
        else:
            msg_normalized.append(msg)

    return wrapper


def _log_info(msg: str) -> None:
    getattr(_LOGGER, _INFO_METHOD_NAME)(msg)


def _log_warning(msg: str) -> None:
    getattr(_LOGGER, _WARNING_METHOD_NAME)(msg)


@_normalize_native_string
def _log_native(msg: str) -> None:
    getattr(_LOGGER, _INFO_METHOD_NAME)(msg)


def _log_callback(msg: bytes) -> None:
    """Redirect logs from native library into Python."""
    _log_native(str(msg.decode('utf-8')))


def _load_lib() -> ctypes.CDLL:
    """Load LightGBM library."""
    lib_path = find_lib_path()
    lib = ctypes.cdll.LoadLibrary(lib_path[0])
    lib.LGBM_GetLastError.restype = ctypes.c_char_p
    callback = ctypes.CFUNCTYPE(None, ctypes.c_char_p)
    lib.callback = callback(_log_callback)  # type: ignore[attr-defined]
    if lib.LGBM_RegisterLogCallback(lib.callback) != 0:
        raise LightGBMError(lib.LGBM_GetLastError().decode('utf-8'))
    return lib


# we don't need lib_lightgbm while building docs
_LIB: ctypes.CDLL
if environ.get('LIGHTGBM_BUILD_DOC', False):
    from unittest.mock import Mock  # isort: skip
    _LIB = Mock(ctypes.CDLL)  # type: ignore
else:
    _LIB = _load_lib()


_NUMERIC_TYPES = (int, float, bool)
_ArrayLike = Union[List, np.ndarray, pd_Series]


def _safe_call(ret: int) -> None:
    """Check the return value from C API call.

    Parameters
    ----------
    ret : int
        The return value from C API calls.
    """
    if ret != 0:
        raise LightGBMError(_LIB.LGBM_GetLastError().decode('utf-8'))


def _is_numeric(obj: Any) -> bool:
    """Check whether object is a number, including numpy numbers."""
    try:
        float(obj)
        return True
    except (TypeError, ValueError):
        # TypeError: obj is not a string or a number
        # ValueError: invalid literal
        return False


def _is_numpy_1d_array(data: Any) -> bool:
    """Check whether data is a numpy 1-D array."""
    return isinstance(data, np.ndarray) and len(data.shape) == 1


def _is_numpy_column_array(data: Any) -> bool:
    """Check whether data is a column numpy array."""
    if not isinstance(data, np.ndarray):
        return False
    shape = data.shape
    return len(shape) == 2 and shape[1] == 1


def _cast_numpy_array_to_dtype(array: np.ndarray, dtype: "np.typing.DTypeLike") -> np.ndarray:
    """Cast numpy array to given dtype."""
    if array.dtype == dtype:
        return array
    return array.astype(dtype=dtype, copy=False)


def _is_1d_list(data: Any) -> bool:
    """Check whether data is a 1-D list."""
    return isinstance(data, list) and (not data or _is_numeric(data[0]))


def _is_list_of_numpy_arrays(data: Any) -> "TypeGuard[List[np.ndarray]]":
    return (
        isinstance(data, list)
        and all(isinstance(x, np.ndarray) for x in data)
    )


def _is_list_of_sequences(data: Any) -> "TypeGuard[List[Sequence]]":
    return (
        isinstance(data, list)
        and all(isinstance(x, Sequence) for x in data)
    )


def _is_1d_collection(data: Any) -> bool:
    """Check whether data is a 1-D collection."""
    return (
        _is_numpy_1d_array(data)
        or _is_numpy_column_array(data)
        or _is_1d_list(data)
        or isinstance(data, pd_Series)
    )


def _list_to_1d_numpy(
    data: Any,
    dtype: "np.typing.DTypeLike",
    name: str
) -> np.ndarray:
    """Convert data to numpy 1-D array."""
    if _is_numpy_1d_array(data):
        return _cast_numpy_array_to_dtype(data, dtype)
    elif _is_numpy_column_array(data):
        _log_warning('Converting column-vector to 1d array')
        array = data.ravel()
        return _cast_numpy_array_to_dtype(array, dtype)
    elif _is_1d_list(data):
        return np.array(data, dtype=dtype, copy=False)
    elif isinstance(data, pd_Series):
        _check_for_bad_pandas_dtypes(data.to_frame().dtypes)
        return np.array(data, dtype=dtype, copy=False)  # SparseArray should be supported as well
    else:
        raise TypeError(f"Wrong type({type(data).__name__}) for {name}.\n"
                        "It should be list, numpy 1-D array or pandas Series")


def _is_numpy_2d_array(data: Any) -> bool:
    """Check whether data is a numpy 2-D array."""
    return isinstance(data, np.ndarray) and len(data.shape) == 2 and data.shape[1] > 1


def _is_2d_list(data: Any) -> bool:
    """Check whether data is a 2-D list."""
    return isinstance(data, list) and len(data) > 0 and _is_1d_list(data[0])


def _is_2d_collection(data: Any) -> bool:
    """Check whether data is a 2-D collection."""
    return (
        _is_numpy_2d_array(data)
        or _is_2d_list(data)
        or isinstance(data, pd_DataFrame)
    )


def _is_pyarrow_array(data: Any) -> "TypeGuard[Union[pa_Array, pa_ChunkedArray]]":
    """Check whether data is a PyArrow array."""
    return isinstance(data, (pa_Array, pa_ChunkedArray))


def _is_pyarrow_table(data: Any) -> bool:
    """Check whether data is a PyArrow table."""
    return isinstance(data, pa_Table)


class _ArrowCArray:
    """Simple wrapper around the C representation of an Arrow type."""

    n_chunks: int
    chunks: arrow_cffi.CData
    schema: arrow_cffi.CData

    def __init__(self, n_chunks: int, chunks: arrow_cffi.CData, schema: arrow_cffi.CData):
        self.n_chunks = n_chunks
        self.chunks = chunks
        self.schema = schema

    @property
    def chunks_ptr(self) -> int:
        """Returns the address of the pointer to the list of chunks making up the array."""
        return int(arrow_cffi.cast("uintptr_t", arrow_cffi.addressof(self.chunks[0])))

    @property
    def schema_ptr(self) -> int:
        """Returns the address of the pointer to the schema of the array."""
        return int(arrow_cffi.cast("uintptr_t", self.schema))


def _export_arrow_to_c(data: pa_Table) -> _ArrowCArray:
    """Export an Arrow type to its C representation."""
    # Obtain objects to export
    if isinstance(data, pa_Array):
        export_objects = [data]
    elif isinstance(data, pa_ChunkedArray):
        export_objects = data.chunks
    elif isinstance(data, pa_Table):
        export_objects = data.to_batches()
    else:
        raise ValueError(f"data of type '{type(data)}' cannot be exported to Arrow")

    # Prepare export
    chunks = arrow_cffi.new("struct ArrowArray[]", len(export_objects))
    schema = arrow_cffi.new("struct ArrowSchema*")

    # Export all objects
    for i, obj in enumerate(export_objects):
        chunk_ptr = int(arrow_cffi.cast("uintptr_t", arrow_cffi.addressof(chunks[i])))
        if i == 0:
            schema_ptr = int(arrow_cffi.cast("uintptr_t", schema))
            obj._export_to_c(chunk_ptr, schema_ptr)
        else:
            obj._export_to_c(chunk_ptr)

    return _ArrowCArray(len(chunks), chunks, schema)


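# Sketch of how ``_export_arrow_to_c`` is typically used (illustrative only,
# requires ``pyarrow``): the returned ``_ArrowCArray`` exposes raw addresses
# that can be handed to the LightGBM C API:
#
#     import pyarrow as pa
#     c_array = _export_arrow_to_c(pa.table({"x": [1.0, 2.0, 3.0]}))
#     c_array.n_chunks, c_array.chunks_ptr, c_array.schema_ptr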

def _data_to_2d_numpy(
    data: Any,
    dtype: "np.typing.DTypeLike",
    name: str
) -> np.ndarray:
    """Convert data to numpy 2-D array."""
    if _is_numpy_2d_array(data):
        return _cast_numpy_array_to_dtype(data, dtype)
    if _is_2d_list(data):
        return np.array(data, dtype=dtype)
    if isinstance(data, pd_DataFrame):
        _check_for_bad_pandas_dtypes(data.dtypes)
        return _cast_numpy_array_to_dtype(data.values, dtype)
    raise TypeError(f"Wrong type({type(data).__name__}) for {name}.\n"
                    "It should be list of lists, numpy 2-D array or pandas DataFrame")


def _cfloat32_array_to_numpy(*, cptr: "ctypes._Pointer", length: int) -> np.ndarray:
    """Convert a ctypes float pointer array to a numpy array."""
    if isinstance(cptr, ctypes.POINTER(ctypes.c_float)):
        return np.ctypeslib.as_array(cptr, shape=(length,)).copy()
    else:
        raise RuntimeError('Expected float pointer')


def _cfloat64_array_to_numpy(*, cptr: "ctypes._Pointer", length: int) -> np.ndarray:
    """Convert a ctypes double pointer array to a numpy array."""
    if isinstance(cptr, ctypes.POINTER(ctypes.c_double)):
        return np.ctypeslib.as_array(cptr, shape=(length,)).copy()
    else:
        raise RuntimeError('Expected double pointer')


def _cint32_array_to_numpy(*, cptr: "ctypes._Pointer", length: int) -> np.ndarray:
    """Convert a ctypes int32 pointer array to a numpy array."""
    if isinstance(cptr, ctypes.POINTER(ctypes.c_int32)):
        return np.ctypeslib.as_array(cptr, shape=(length,)).copy()
    else:
        raise RuntimeError('Expected int32 pointer')


def _cint64_array_to_numpy(*, cptr: "ctypes._Pointer", length: int) -> np.ndarray:
    """Convert a ctypes int64 pointer array to a numpy array."""
    if isinstance(cptr, ctypes.POINTER(ctypes.c_int64)):
        return np.ctypeslib.as_array(cptr, shape=(length,)).copy()
    else:
        raise RuntimeError('Expected int64 pointer')


def _c_str(string: str) -> ctypes.c_char_p:
    """Convert a Python string to C string."""
    return ctypes.c_char_p(string.encode('utf-8'))


def _c_array(ctype: type, values: List[Any]) -> ctypes.Array:
    """Convert a Python array to C array."""
    return (ctype * len(values))(*values)  # type: ignore[operator]


def _json_default_with_numpy(obj: Any) -> Any:
    """Convert numpy classes to JSON serializable objects."""
    if isinstance(obj, (np.integer, np.floating, np.bool_)):
        return obj.item()
    elif isinstance(obj, np.ndarray):
        return obj.tolist()
    else:
        return obj


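# Illustrative only: this helper is meant to be passed as the ``default`` hook
# of ``json.dumps`` so numpy scalars and arrays serialize cleanly:
#
#     json.dumps({"num_leaves": np.int32(31)}, default=_json_default_with_numpy)
#     # -> '{"num_leaves": 31}'
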
def _to_string(x: Union[int, float, str, List]) -> str:
    if isinstance(x, list):
        val_list = ",".join(str(val) for val in x)
        return f"[{val_list}]"
    else:
        return str(x)


def _param_dict_to_str(data: Optional[Dict[str, Any]]) -> str:
    """Convert Python dictionary to string, which is passed to C API."""
    if data is None or not data:
        return ""
    pairs = []
    for key, val in data.items():
        if isinstance(val, (list, tuple, set)) or _is_numpy_1d_array(val):
            pairs.append(f"{key}={','.join(map(_to_string, val))}")
        elif isinstance(val, (str, Path, _NUMERIC_TYPES)) or _is_numeric(val):
            pairs.append(f"{key}={val}")
        elif val is not None:
            raise TypeError(f'Unknown type of parameter:{key}, got:{type(val).__name__}')
    return ' '.join(pairs)

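# Illustrative only: how parameters are flattened for the C API,
# assuming the usual LightGBM parameter names:
#
#     _param_dict_to_str({"learning_rate": 0.1, "metric": ["l2", "l1"]})
#     # -> 'learning_rate=0.1 metric=l2,l1'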

class _TempFile:
    """Proxy class to workaround errors on Windows."""

    def __enter__(self):
        with NamedTemporaryFile(prefix="lightgbm_tmp_", delete=True) as f:
            self.name = f.name
            self.path = Path(self.name)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.path.is_file():
            self.path.unlink()


class LightGBMError(Exception):
    """Error thrown by LightGBM."""

    pass


# DeprecationWarning is not shown by default, so let's create our own with higher level
class LGBMDeprecationWarning(UserWarning):
    """Custom deprecation warning."""

    pass


class _ConfigAliases:
    # lazy evaluation to allow import without dynamic library, e.g., for docs generation
    aliases = None

    @staticmethod
    def _get_all_param_aliases() -> Dict[str, List[str]]:
        buffer_len = 1 << 20
        tmp_out_len = ctypes.c_int64(0)
        string_buffer = ctypes.create_string_buffer(buffer_len)
        ptr_string_buffer = ctypes.c_char_p(ctypes.addressof(string_buffer))
        _safe_call(_LIB.LGBM_DumpParamAliases(
            ctypes.c_int64(buffer_len),
            ctypes.byref(tmp_out_len),
            ptr_string_buffer))
        actual_len = tmp_out_len.value
        # if buffer length is not long enough, re-allocate a buffer
        if actual_len > buffer_len:
            string_buffer = ctypes.create_string_buffer(actual_len)
            ptr_string_buffer = ctypes.c_char_p(ctypes.addressof(string_buffer))
            _safe_call(_LIB.LGBM_DumpParamAliases(
                ctypes.c_int64(actual_len),
                ctypes.byref(tmp_out_len),
                ptr_string_buffer))
        return json.loads(
            string_buffer.value.decode('utf-8'),
            object_hook=lambda obj: {k: [k] + v for k, v in obj.items()}
        )

    @classmethod
    def get(cls, *args) -> Set[str]:
        if cls.aliases is None:
            cls.aliases = cls._get_all_param_aliases()
        ret = set()
        for i in args:
            ret.update(cls.get_sorted(i))
        return ret

    @classmethod
    def get_sorted(cls, name: str) -> List[str]:
        if cls.aliases is None:
            cls.aliases = cls._get_all_param_aliases()
        return cls.aliases.get(name, [name])

    @classmethod
    def get_by_alias(cls, *args) -> Set[str]:
        if cls.aliases is None:
            cls.aliases = cls._get_all_param_aliases()
        ret = set(args)
        for arg in args:
            for aliases in cls.aliases.values():
                if arg in aliases:
                    ret.update(aliases)
                    break
        return ret


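# Illustrative only (the alias table is dumped from the native library at
# runtime, so exact contents depend on the loaded lib_lightgbm):
#
#     _ConfigAliases.get("num_iterations")
#     # -> a set like {"num_iterations", "num_trees", "n_estimators", ...}
#
# ``get_by_alias`` goes the other way: given any alias, it returns the full
# alias group that alias belongs to.
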
def _choose_param_value(main_param_name: str, params: Dict[str, Any], default_value: Any) -> Dict[str, Any]:
    """Get a single parameter value, accounting for aliases.

    Parameters
    ----------
    main_param_name : str
        Name of the main parameter to get a value for. One of the keys of ``_ConfigAliases``.
    params : dict
        Dictionary of LightGBM parameters.
    default_value : Any
        Default value to use for the parameter, if none is found in ``params``.

    Returns
    -------
    params : dict
        A ``params`` dict with exactly one value for ``main_param_name``, and all aliases of ``main_param_name`` removed.
        If both ``main_param_name`` and one or more aliases for it are found, the value of ``main_param_name`` will be preferred.
    """
    # avoid side effects on passed-in parameters
    params = deepcopy(params)

    aliases = _ConfigAliases.get_sorted(main_param_name)
    aliases = [a for a in aliases if a != main_param_name]

    # if main_param_name was provided, keep that value and remove all aliases
    if main_param_name in params.keys():
        for param in aliases:
            params.pop(param, None)
        return params

    # if main param name was not found, search for an alias
    for param in aliases:
        if param in params.keys():
            params[main_param_name] = params[param]
            break

    if main_param_name in params.keys():
        for param in aliases:
            params.pop(param, None)
        return params

    # neither main_param_name nor any of its aliases were found
    params[main_param_name] = default_value

    return params


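# Illustrative only: alias resolution keeps the main parameter's value and
# drops its aliases (``num_trees`` is assumed here to be an alias of
# ``num_iterations``):
#
#     _choose_param_value("num_iterations", {"num_trees": 100}, default_value=50)
#     # -> {"num_iterations": 100}
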
_MAX_INT32 = (1 << 31) - 1

"""Macro definition of data type in C API of LightGBM"""
_C_API_DTYPE_FLOAT32 = 0
_C_API_DTYPE_FLOAT64 = 1
_C_API_DTYPE_INT32 = 2
_C_API_DTYPE_INT64 = 3

"""Matrix is row major in Python"""
_C_API_IS_ROW_MAJOR = 1

"""Macro definition of prediction type in C API of LightGBM"""
_C_API_PREDICT_NORMAL = 0
_C_API_PREDICT_RAW_SCORE = 1
_C_API_PREDICT_LEAF_INDEX = 2
_C_API_PREDICT_CONTRIB = 3

"""Macro definition of sparse matrix type"""
_C_API_MATRIX_TYPE_CSR = 0
_C_API_MATRIX_TYPE_CSC = 1

"""Macro definition of feature importance type"""
_C_API_FEATURE_IMPORTANCE_SPLIT = 0
_C_API_FEATURE_IMPORTANCE_GAIN = 1

"""Data type of data field"""
_FIELD_TYPE_MAPPER = {
    "label": _C_API_DTYPE_FLOAT32,
    "weight": _C_API_DTYPE_FLOAT32,
    "init_score": _C_API_DTYPE_FLOAT64,
    "group": _C_API_DTYPE_INT32,
    "position": _C_API_DTYPE_INT32
}

"""String name to int feature importance type mapper"""
_FEATURE_IMPORTANCE_TYPE_MAPPER = {
    "split": _C_API_FEATURE_IMPORTANCE_SPLIT,
    "gain": _C_API_FEATURE_IMPORTANCE_GAIN
}


def _convert_from_sliced_object(data: np.ndarray) -> np.ndarray:
    """Fix the memory of multi-dimensional sliced object."""
    if isinstance(data, np.ndarray) and isinstance(data.base, np.ndarray):
        if not data.flags.c_contiguous:
            _log_warning("Usage of np.ndarray subset (sliced data) is not recommended "
                         "because it will double the peak memory cost in LightGBM.")
            return np.copy(data)
    return data


def _c_float_array(
    data: np.ndarray
) -> Tuple[_ctypes_float_ptr, int, np.ndarray]:
    """Get pointer of float numpy array / list."""
    if _is_1d_list(data):
        data = np.array(data, copy=False)
    if _is_numpy_1d_array(data):
        data = _convert_from_sliced_object(data)
        assert data.flags.c_contiguous
        ptr_data: _ctypes_float_ptr
        if data.dtype == np.float32:
            ptr_data = data.ctypes.data_as(ctypes.POINTER(ctypes.c_float))
            type_data = _C_API_DTYPE_FLOAT32
        elif data.dtype == np.float64:
            ptr_data = data.ctypes.data_as(ctypes.POINTER(ctypes.c_double))
            type_data = _C_API_DTYPE_FLOAT64
        else:
            raise TypeError(f"Expected np.float32 or np.float64, got type({data.dtype})")
    else:
        raise TypeError(f"Unknown type({type(data).__name__})")
    return (ptr_data, type_data, data)  # also return `data` so the temporary copy is not freed prematurely


def _c_int_array(
    data: np.ndarray
) -> Tuple[_ctypes_int_ptr, int, np.ndarray]:
    """Get pointer of int numpy array / list."""
    if _is_1d_list(data):
        data = np.array(data, copy=False)
    if _is_numpy_1d_array(data):
        data = _convert_from_sliced_object(data)
        assert data.flags.c_contiguous
        ptr_data: _ctypes_int_ptr
        if data.dtype == np.int32:
            ptr_data = data.ctypes.data_as(ctypes.POINTER(ctypes.c_int32))
            type_data = _C_API_DTYPE_INT32
        elif data.dtype == np.int64:
            ptr_data = data.ctypes.data_as(ctypes.POINTER(ctypes.c_int64))
            type_data = _C_API_DTYPE_INT64
        else:
            raise TypeError(f"Expected np.int32 or np.int64, got type({data.dtype})")
    else:
        raise TypeError(f"Unknown type({type(data).__name__})")
    return (ptr_data, type_data, data)  # also return `data` so the temporary copy is not freed prematurely


def _is_allowed_numpy_dtype(dtype: type) -> bool:
    float128 = getattr(np, 'float128', type(None))
    return (
        issubclass(dtype, (np.integer, np.floating, np.bool_))
        and not issubclass(dtype, (np.timedelta64, float128))
    )


def _check_for_bad_pandas_dtypes(pandas_dtypes_series: pd_Series) -> None:
    bad_pandas_dtypes = [
        f'{column_name}: {pandas_dtype}'
        for column_name, pandas_dtype in pandas_dtypes_series.items()
        if not _is_allowed_numpy_dtype(pandas_dtype.type)
    ]
    if bad_pandas_dtypes:
        raise ValueError('pandas dtypes must be int, float or bool.\n'
                         f'Fields with bad pandas dtypes: {", ".join(bad_pandas_dtypes)}')


def _pandas_to_numpy(
    data: pd_DataFrame,
    target_dtype: "np.typing.DTypeLike"
) -> np.ndarray:
    _check_for_bad_pandas_dtypes(data.dtypes)
    try:
        # most common case (no nullable dtypes)
        return data.to_numpy(dtype=target_dtype, copy=False)
    except TypeError:
        # 1.0 <= pd version < 1.1 and nullable dtypes, least common case
        # raises error because array is casted to type(pd.NA) and there's no na_value argument
        return data.astype(target_dtype, copy=False).values
    except ValueError:
        # data has nullable dtypes, but we can specify na_value argument and copy will be made
        return data.to_numpy(dtype=target_dtype, na_value=np.nan)


def _data_from_pandas(
785
786
787
    data: pd_DataFrame,
    feature_name: _LGBM_FeatureNameConfiguration,
    categorical_feature: _LGBM_CategoricalFeatureConfiguration,
788
    pandas_categorical: Optional[List[List]]
789
) -> Tuple[np.ndarray, List[str], Union[List[str], List[int]], List[List]]:
790
791
792
    if len(data.shape) != 2 or data.shape[0] < 1:
        raise ValueError('Input data must be 2 dimensional and non empty.')

793
794
795
796
    # take shallow copy in case we modify categorical columns
    # whole column modifications don't change the original df
    data = data.copy(deep=False)

797
798
799
800
801
802
    # determine feature names
    if feature_name == 'auto':
        feature_name = [str(col) for col in data.columns]

    # determine categorical features
    cat_cols = [col for col, dtype in zip(data.columns, data.dtypes) if isinstance(dtype, pd_CategoricalDtype)]
803
    cat_cols_not_ordered: List[str] = [col for col in cat_cols if not data[col].cat.ordered]
804
805
    if pandas_categorical is None:  # train dataset
        pandas_categorical = [list(data[col].cat.categories) for col in cat_cols]
806
    else:
807
808
809
810
811
812
813
        if len(cat_cols) != len(pandas_categorical):
            raise ValueError('train and valid dataset categorical_feature do not match.')
        for col, category in zip(cat_cols, pandas_categorical):
            if list(data[col].cat.categories) != list(category):
                data[col] = data[col].cat.set_categories(category)
    if len(cat_cols):  # cat_cols is list
        data[cat_cols] = data[cat_cols].apply(lambda x: x.cat.codes).replace({-1: np.nan})
814
815
816

    # use cat cols from DataFrame
    if categorical_feature == 'auto':
817
818
819
        categorical_feature = cat_cols_not_ordered

    df_dtypes = [dtype.type for dtype in data.dtypes]
820
821
    # so that the target dtype considers floats
    df_dtypes.append(np.float32)
822
    target_dtype = np.result_type(*df_dtypes)
823
824
825
826
827
828
829

    return (
        _pandas_to_numpy(data, target_dtype=target_dtype),
        feature_name,
        categorical_feature,
        pandas_categorical
    )


def _dump_pandas_categorical(
    pandas_categorical: Optional[List[List]],
    file_name: Optional[Union[str, Path]] = None
) -> str:
    categorical_json = json.dumps(pandas_categorical, default=_json_default_with_numpy)
    pandas_str = f'\npandas_categorical:{categorical_json}\n'
    if file_name is not None:
        with open(file_name, 'a') as f:
            f.write(pandas_str)
    return pandas_str


def _load_pandas_categorical(
    file_name: Optional[Union[str, Path]] = None,
    model_str: Optional[str] = None
) -> Optional[List[List]]:
    pandas_key = 'pandas_categorical:'
    offset = -len(pandas_key)
    if file_name is not None:
        max_offset = -getsize(file_name)
        with open(file_name, 'rb') as f:
            while True:
                if offset < max_offset:
                    offset = max_offset
                f.seek(offset, SEEK_END)
                lines = f.readlines()
                if len(lines) >= 2:
                    break
                offset *= 2
        last_line = lines[-1].decode('utf-8').strip()
        if not last_line.startswith(pandas_key):
            last_line = lines[-2].decode('utf-8').strip()
    elif model_str is not None:
        idx = model_str.rfind('\n', 0, offset)
        last_line = model_str[idx:].strip()
    if last_line.startswith(pandas_key):
        return json.loads(last_line[len(pandas_key):])
    else:
        return None


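# Illustrative only: the footer written by _dump_pandas_categorical() and read
# back by _load_pandas_categorical() is a single line appended after the model
# text, e.g. for two categorical columns:
#
#     pandas_categorical:[["a", "b"], [1, 2, 3]]
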
class Sequence(abc.ABC):
    """
    Generic data access interface.

    Object should support the following operations:

    .. code-block::

        # Get total row number.
        >>> len(seq)
        # Random access by row index. Used for data sampling.
        >>> seq[10]
        # Range data access. Used to read data in batch when constructing Dataset.
        >>> seq[0:100]
        # Optionally specify batch_size to control range data read size.
        >>> seq.batch_size

    - With random access, **data sampling does not need to go through all data**.
    - With range data access, there's **no need to read all data into memory, which reduces memory usage**.

    .. versionadded:: 3.3.0

    Attributes
    ----------
    batch_size : int
        Default size of a batch.
    """

    batch_size = 4096  # Defaults to read 4K rows in each batch.

    @abc.abstractmethod
    def __getitem__(self, idx: Union[int, slice, List[int]]) -> np.ndarray:
        """Return data for given row index.

        A basic implementation should look like this:

        .. code-block:: python

            if isinstance(idx, numbers.Integral):
                return self._get_one_line(idx)
            elif isinstance(idx, slice):
                return np.stack([self._get_one_line(i) for i in range(idx.start, idx.stop)])
            elif isinstance(idx, list):
                # Only required if using ``Dataset.subset()``.
                return np.array([self._get_one_line(i) for i in idx])
            else:
                raise TypeError(f"Sequence index must be integer, slice or list, got {type(idx).__name__}")

        Parameters
        ----------
        idx : int, slice[int], list[int]
            Item index.

        Returns
        -------
        result : numpy 1-D array or numpy 2-D array
            1-D array if idx is int, 2-D array if idx is slice or list.
        """
        raise NotImplementedError("Sub-classes of lightgbm.Sequence must implement __getitem__()")

    @abc.abstractmethod
    def __len__(self) -> int:
        """Return row count of this sequence."""
        raise NotImplementedError("Sub-classes of lightgbm.Sequence must implement __len__()")


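# A minimal sketch of a concrete Sequence backed by an in-memory numpy array
# (illustrative only; a realistic implementation would read rows lazily,
# e.g. from disk or a database):
#
#     class NumpySequence(Sequence):
#         def __init__(self, ndarray, batch_size=4096):
#             self.ndarray = ndarray
#             self.batch_size = batch_size
#
#         def __getitem__(self, idx):
#             return self.ndarray[idx]  # handles int, slice and list indices
#
#         def __len__(self):
#             return len(self.ndarray)
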
class _InnerPredictor:
    """_InnerPredictor of LightGBM.

    Not exposed to user.
    Used only for prediction, usually used for continued training.

    .. note::

        Can be converted from Booster, but cannot be converted to Booster.
    """

    def __init__(
        self,
        booster_handle: _BoosterHandle,
        pandas_categorical: Optional[List[List]],
        pred_parameter: Dict[str, Any],
        manage_handle: bool
    ):
        """Initialize the _InnerPredictor.

        Parameters
        ----------
        booster_handle : object
            Handle of Booster.
        pandas_categorical : list of list, or None
            If provided, list of categories for ``pandas`` categorical columns.
            Where the ``i``th element of the list contains the categories for the ``i``th categorical feature.
        pred_parameter : dict
            Other parameters for the prediction.
        manage_handle : bool
            If ``True``, free the corresponding Booster on the C++ side when this Python object is deleted.
        """
        self._handle = booster_handle
        self.__is_manage_handle = manage_handle
        self.pandas_categorical = pandas_categorical
        self.pred_parameter = _param_dict_to_str(pred_parameter)

        out_num_class = ctypes.c_int(0)
        _safe_call(
            _LIB.LGBM_BoosterGetNumClasses(
                self._handle,
                ctypes.byref(out_num_class)
            )
        )
        self.num_class = out_num_class.value

    @classmethod
    def from_booster(
        cls,
        booster: "Booster",
        pred_parameter: Dict[str, Any]
    ) -> "_InnerPredictor":
        """Initialize an ``_InnerPredictor`` from a ``Booster``.

        Parameters
        ----------
        booster : Booster
            Booster.
        pred_parameter : dict
            Other parameters for the prediction.
        """
        out_cur_iter = ctypes.c_int(0)
        _safe_call(
            _LIB.LGBM_BoosterGetCurrentIteration(
                booster._handle,
                ctypes.byref(out_cur_iter)
            )
        )
        return cls(
            booster_handle=booster._handle,
            pandas_categorical=booster.pandas_categorical,
            pred_parameter=pred_parameter,
            manage_handle=False
        )

    @classmethod
    def from_model_file(
        cls,
        model_file: Union[str, Path],
        pred_parameter: Dict[str, Any]
    ) -> "_InnerPredictor":
        """Initialize an ``_InnerPredictor`` from a text file containing a LightGBM model.

        Parameters
        ----------
        model_file : str or pathlib.Path
            Path to the model file.
        pred_parameter : dict
            Other parameters for the prediction.
        """
        booster_handle = ctypes.c_void_p()
        out_num_iterations = ctypes.c_int(0)
        _safe_call(
            _LIB.LGBM_BoosterCreateFromModelfile(
                _c_str(str(model_file)),
                ctypes.byref(out_num_iterations),
                ctypes.byref(booster_handle)
            )
        )
        return cls(
            booster_handle=booster_handle,
            pandas_categorical=_load_pandas_categorical(file_name=model_file),
            pred_parameter=pred_parameter,
            manage_handle=True
        )

    def __del__(self) -> None:
        try:
            if self.__is_manage_handle:
                _safe_call(_LIB.LGBM_BoosterFree(self._handle))
        except AttributeError:
            pass

    def __getstate__(self) -> Dict[str, Any]:
        this = self.__dict__.copy()
        this.pop('handle', None)
        this.pop('_handle', None)
        return this

    def predict(
        self,
        data: _LGBM_PredictDataType,
        start_iteration: int = 0,
        num_iteration: int = -1,
        raw_score: bool = False,
        pred_leaf: bool = False,
        pred_contrib: bool = False,
        data_has_header: bool = False,
        validate_features: bool = False
    ) -> Union[np.ndarray, scipy.sparse.spmatrix, List[scipy.sparse.spmatrix]]:
        """Predict logic.

        Parameters
        ----------
        data : str, pathlib.Path, numpy array, pandas DataFrame, pyarrow Table, H2O DataTable's Frame or scipy.sparse
            Data source for prediction.
            If str or pathlib.Path, it represents the path to a text file (CSV, TSV, or LibSVM).
        start_iteration : int, optional (default=0)
            Start index of the iteration to predict.
        num_iteration : int, optional (default=-1)
            Iteration used for prediction.
        raw_score : bool, optional (default=False)
            Whether to predict raw scores.
        pred_leaf : bool, optional (default=False)
            Whether to predict leaf index.
        pred_contrib : bool, optional (default=False)
            Whether to predict feature contributions.
        data_has_header : bool, optional (default=False)
            Whether data has header.
            Used only for txt data.
        validate_features : bool, optional (default=False)
            If True, ensure that the features used to predict match the ones used to train.
            Used only if data is pandas DataFrame.

            .. versionadded:: 4.0.0

        Returns
        -------
        result : numpy array, scipy.sparse or list of scipy.sparse
            Prediction result.
            Can be sparse or a list of sparse objects (each element represents predictions for one class) for feature contributions (when ``pred_contrib=True``).
        """
        if isinstance(data, Dataset):
            raise TypeError("Cannot use Dataset instance for prediction, please use raw data instead")
        elif isinstance(data, pd_DataFrame) and validate_features:
            data_names = [str(x) for x in data.columns]
            ptr_names = (ctypes.c_char_p * len(data_names))()
            ptr_names[:] = [x.encode('utf-8') for x in data_names]
            _safe_call(
                _LIB.LGBM_BoosterValidateFeatureNames(
                    self._handle,
                    ptr_names,
                    ctypes.c_int(len(data_names)),
                )
            )

        if isinstance(data, pd_DataFrame):
            data = _data_from_pandas(
                data=data,
                feature_name="auto",
                categorical_feature="auto",
                pandas_categorical=self.pandas_categorical
            )[0]

        predict_type = _C_API_PREDICT_NORMAL
        if raw_score:
            predict_type = _C_API_PREDICT_RAW_SCORE
        if pred_leaf:
            predict_type = _C_API_PREDICT_LEAF_INDEX
        if pred_contrib:
            predict_type = _C_API_PREDICT_CONTRIB

        if isinstance(data, (str, Path)):
            with _TempFile() as f:
                _safe_call(_LIB.LGBM_BoosterPredictForFile(
                    self._handle,
                    _c_str(str(data)),
                    ctypes.c_int(data_has_header),
                    ctypes.c_int(predict_type),
                    ctypes.c_int(start_iteration),
                    ctypes.c_int(num_iteration),
                    _c_str(self.pred_parameter),
                    _c_str(f.name)))
                preds = np.loadtxt(f.name, dtype=np.float64)
                nrow = preds.shape[0]
        elif isinstance(data, scipy.sparse.csr_matrix):
            preds, nrow = self.__pred_for_csr(
                csr=data,
                start_iteration=start_iteration,
                num_iteration=num_iteration,
                predict_type=predict_type
            )
        elif isinstance(data, scipy.sparse.csc_matrix):
            preds, nrow = self.__pred_for_csc(
                csc=data,
                start_iteration=start_iteration,
                num_iteration=num_iteration,
                predict_type=predict_type
            )
        elif isinstance(data, np.ndarray):
            preds, nrow = self.__pred_for_np2d(
                mat=data,
                start_iteration=start_iteration,
                num_iteration=num_iteration,
                predict_type=predict_type
            )
        elif _is_pyarrow_table(data):
            preds, nrow = self.__pred_for_pyarrow_table(
                table=data,
                start_iteration=start_iteration,
                num_iteration=num_iteration,
                predict_type=predict_type
            )
        elif isinstance(data, list):
            try:
                data = np.array(data)
            except BaseException as err:
                raise ValueError('Cannot convert data list to numpy array.') from err
            preds, nrow = self.__pred_for_np2d(
                mat=data,
                start_iteration=start_iteration,
                num_iteration=num_iteration,
                predict_type=predict_type
            )
        elif isinstance(data, dt_DataTable):
            preds, nrow = self.__pred_for_np2d(
                mat=data.to_numpy(),
                start_iteration=start_iteration,
                num_iteration=num_iteration,
                predict_type=predict_type
            )
        else:
            try:
                _log_warning('Converting data to scipy sparse matrix.')
                csr = scipy.sparse.csr_matrix(data)
            except BaseException as err:
                raise TypeError(f'Cannot predict data for type {type(data).__name__}') from err
            preds, nrow = self.__pred_for_csr(
                csr=csr,
                start_iteration=start_iteration,
                num_iteration=num_iteration,
                predict_type=predict_type
            )
        if pred_leaf:
            preds = preds.astype(np.int32)
        is_sparse = isinstance(preds, scipy.sparse.spmatrix) or isinstance(preds, list)
        if not is_sparse and preds.size != nrow:
            if preds.size % nrow == 0:
                preds = preds.reshape(nrow, -1)
            else:
                raise ValueError(f'Length of predict result ({preds.size}) is not divisible by number of rows ({nrow})')
        return preds

    def __get_num_preds(
        self,
        start_iteration: int,
        num_iteration: int,
        nrow: int,
        predict_type: int
    ) -> int:
        """Get size of prediction result."""
        if nrow > _MAX_INT32:
            raise LightGBMError('LightGBM cannot perform prediction for data '
                                f'with number of rows greater than MAX_INT32 ({_MAX_INT32}).\n'
                                'You can split your data into chunks '
                                'and then concatenate predictions for them')
        n_preds = ctypes.c_int64(0)
        _safe_call(_LIB.LGBM_BoosterCalcNumPredict(
            self._handle,
            ctypes.c_int(nrow),
            ctypes.c_int(predict_type),
            ctypes.c_int(start_iteration),
            ctypes.c_int(num_iteration),
            ctypes.byref(n_preds)))
        return n_preds.value

    def __inner_predict_np2d(
        self,
        mat: np.ndarray,
        start_iteration: int,
        num_iteration: int,
        predict_type: int,
        preds: Optional[np.ndarray]
    ) -> Tuple[np.ndarray, int]:
        if mat.dtype == np.float32 or mat.dtype == np.float64:
            data = np.array(mat.reshape(mat.size), dtype=mat.dtype, copy=False)
        else:  # change non-float data to float data, need to copy
            data = np.array(mat.reshape(mat.size), dtype=np.float32)
        ptr_data, type_ptr_data, _ = _c_float_array(data)
        n_preds = self.__get_num_preds(
            start_iteration=start_iteration,
            num_iteration=num_iteration,
            nrow=mat.shape[0],
            predict_type=predict_type
        )
        if preds is None:
            preds = np.empty(n_preds, dtype=np.float64)
        elif len(preds.shape) != 1 or len(preds) != n_preds:
            raise ValueError("Wrong length of pre-allocated predict array")
        out_num_preds = ctypes.c_int64(0)
        _safe_call(_LIB.LGBM_BoosterPredictForMat(
            self._handle,
            ptr_data,
            ctypes.c_int(type_ptr_data),
            ctypes.c_int32(mat.shape[0]),
            ctypes.c_int32(mat.shape[1]),
            ctypes.c_int(_C_API_IS_ROW_MAJOR),
            ctypes.c_int(predict_type),
            ctypes.c_int(start_iteration),
            ctypes.c_int(num_iteration),
            _c_str(self.pred_parameter),
            ctypes.byref(out_num_preds),
            preds.ctypes.data_as(ctypes.POINTER(ctypes.c_double))))
        if n_preds != out_num_preds.value:
            raise ValueError("Wrong length for predict results")
        return preds, mat.shape[0]

    def __pred_for_np2d(
        self,
        mat: np.ndarray,
        start_iteration: int,
        num_iteration: int,
        predict_type: int
    ) -> Tuple[np.ndarray, int]:
        """Predict for a 2-D numpy matrix."""
        if len(mat.shape) != 2:
            raise ValueError('Input numpy.ndarray or list must be 2 dimensional')

        nrow = mat.shape[0]
        if nrow > _MAX_INT32:
            sections = np.arange(start=_MAX_INT32, stop=nrow, step=_MAX_INT32)
            # __get_num_preds() cannot work with nrow > MAX_INT32, so calculate overall number of predictions piecemeal
            n_preds = [self.__get_num_preds(start_iteration, num_iteration, i, predict_type) for i in np.diff([0] + list(sections) + [nrow])]
            n_preds_sections = np.array([0] + n_preds, dtype=np.intp).cumsum()
            preds = np.empty(sum(n_preds), dtype=np.float64)
            for chunk, (start_idx_pred, end_idx_pred) in zip(np.array_split(mat, sections),
                                                             zip(n_preds_sections, n_preds_sections[1:])):
                # avoid memory consumption by arrays concatenation operations
                self.__inner_predict_np2d(
                    mat=chunk,
                    start_iteration=start_iteration,
                    num_iteration=num_iteration,
                    predict_type=predict_type,
                    preds=preds[start_idx_pred:end_idx_pred]
                )
            return preds, nrow
        else:
            return self.__inner_predict_np2d(
                mat=mat,
                start_iteration=start_iteration,
                num_iteration=num_iteration,
                predict_type=predict_type,
                preds=None
            )

    def __create_sparse_native(
        self,
        cs: Union[scipy.sparse.csc_matrix, scipy.sparse.csr_matrix],
        out_shape: np.ndarray,
        out_ptr_indptr: "ctypes._Pointer",
        out_ptr_indices: "ctypes._Pointer",
        out_ptr_data: "ctypes._Pointer",
        indptr_type: int,
        data_type: int,
        is_csr: bool
    ) -> Union[List[scipy.sparse.csc_matrix], List[scipy.sparse.csr_matrix]]:
        # create numpy array from output arrays
        data_indices_len = out_shape[0]
        indptr_len = out_shape[1]
        if indptr_type == _C_API_DTYPE_INT32:
            out_indptr = _cint32_array_to_numpy(cptr=out_ptr_indptr, length=indptr_len)
        elif indptr_type == _C_API_DTYPE_INT64:
            out_indptr = _cint64_array_to_numpy(cptr=out_ptr_indptr, length=indptr_len)
        else:
            raise TypeError("Expected int32 or int64 type for indptr")
        if data_type == _C_API_DTYPE_FLOAT32:
            out_data = _cfloat32_array_to_numpy(cptr=out_ptr_data, length=data_indices_len)
        elif data_type == _C_API_DTYPE_FLOAT64:
            out_data = _cfloat64_array_to_numpy(cptr=out_ptr_data, length=data_indices_len)
        else:
            raise TypeError("Expected float32 or float64 type for data")
        out_indices = _cint32_array_to_numpy(cptr=out_ptr_indices, length=data_indices_len)
        # break up indptr based on number of rows (note more than one matrix in multiclass case)
        per_class_indptr_shape = cs.indptr.shape[0]
        # for CSC there is extra column added
        if not is_csr:
            per_class_indptr_shape += 1
        out_indptr_arrays = np.split(out_indptr, out_indptr.shape[0] / per_class_indptr_shape)
        # reformat output into a csr or csc matrix or list of csr or csc matrices
        cs_output_matrices = []
        offset = 0
        for cs_indptr in out_indptr_arrays:
            matrix_indptr_len = cs_indptr[cs_indptr.shape[0] - 1]
            cs_indices = out_indices[offset + cs_indptr[0]:offset + matrix_indptr_len]
            cs_data = out_data[offset + cs_indptr[0]:offset + matrix_indptr_len]
            offset += matrix_indptr_len
            # same shape as input csr or csc matrix except extra column for expected value
            cs_shape = [cs.shape[0], cs.shape[1] + 1]
            # note: make sure we copy data as it will be deallocated next
            if is_csr:
                cs_output_matrices.append(scipy.sparse.csr_matrix((cs_data, cs_indices, cs_indptr), cs_shape))
            else:
                cs_output_matrices.append(scipy.sparse.csc_matrix((cs_data, cs_indices, cs_indptr), cs_shape))
        # free the temporary native indptr, indices, and data
        _safe_call(_LIB.LGBM_BoosterFreePredictSparse(out_ptr_indptr, out_ptr_indices, out_ptr_data,
                                                      ctypes.c_int(indptr_type), ctypes.c_int(data_type)))
        if len(cs_output_matrices) == 1:
            return cs_output_matrices[0]
        return cs_output_matrices

    def __inner_predict_csr(
        self,
        csr: scipy.sparse.csr_matrix,
        start_iteration: int,
        num_iteration: int,
        predict_type: int,
        preds: Optional[np.ndarray]
    ) -> Tuple[np.ndarray, int]:
        nrow = len(csr.indptr) - 1
        n_preds = self.__get_num_preds(
            start_iteration=start_iteration,
            num_iteration=num_iteration,
            nrow=nrow,
            predict_type=predict_type
        )
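        # ``preds`` may be a pre-allocated slice of a larger output array (see
        # the chunked path in __pred_for_csr); otherwise it is allocated here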
        if preds is None:
            preds = np.empty(n_preds, dtype=np.float64)
        elif len(preds.shape) != 1 or len(preds) != n_preds:
            raise ValueError("Wrong length of pre-allocated predict array")
        out_num_preds = ctypes.c_int64(0)

        ptr_indptr, type_ptr_indptr, _ = _c_int_array(csr.indptr)
        ptr_data, type_ptr_data, _ = _c_float_array(csr.data)

        assert csr.shape[1] <= _MAX_INT32
        csr_indices = csr.indices.astype(np.int32, copy=False)

        _safe_call(_LIB.LGBM_BoosterPredictForCSR(
            self._handle,
            ptr_indptr,
            ctypes.c_int(type_ptr_indptr),
            csr_indices.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)),
            ptr_data,
            ctypes.c_int(type_ptr_data),
            ctypes.c_int64(len(csr.indptr)),
            ctypes.c_int64(len(csr.data)),
            ctypes.c_int64(csr.shape[1]),
            ctypes.c_int(predict_type),
            ctypes.c_int(start_iteration),
            ctypes.c_int(num_iteration),
            _c_str(self.pred_parameter),
            ctypes.byref(out_num_preds),
            preds.ctypes.data_as(ctypes.POINTER(ctypes.c_double))))
        if n_preds != out_num_preds.value:
            raise ValueError("Wrong length for predict results")
        return preds, nrow

    def __inner_predict_csr_sparse(
        self,
        csr: scipy.sparse.csr_matrix,
        start_iteration: int,
        num_iteration: int,
        predict_type: int
    ) -> Tuple[Union[List[scipy.sparse.csc_matrix], List[scipy.sparse.csr_matrix]], int]:
        ptr_indptr, type_ptr_indptr, __ = _c_int_array(csr.indptr)
        ptr_data, type_ptr_data, _ = _c_float_array(csr.data)
        csr_indices = csr.indices.astype(np.int32, copy=False)
        matrix_type = _C_API_MATRIX_TYPE_CSR
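        # the output buffers below are allocated by the C library; they are
        # released via LGBM_BoosterFreePredictSparse inside __create_sparse_native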
        out_ptr_indptr: _ctypes_int_ptr
        if type_ptr_indptr == _C_API_DTYPE_INT32:
            out_ptr_indptr = ctypes.POINTER(ctypes.c_int32)()
        else:
            out_ptr_indptr = ctypes.POINTER(ctypes.c_int64)()
        out_ptr_indices = ctypes.POINTER(ctypes.c_int32)()
        out_ptr_data: _ctypes_float_ptr
        if type_ptr_data == _C_API_DTYPE_FLOAT32:
            out_ptr_data = ctypes.POINTER(ctypes.c_float)()
        else:
            out_ptr_data = ctypes.POINTER(ctypes.c_double)()
        out_shape = np.empty(2, dtype=np.int64)
        _safe_call(_LIB.LGBM_BoosterPredictSparseOutput(
            self._handle,
            ptr_indptr,
            ctypes.c_int(type_ptr_indptr),
            csr_indices.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)),
            ptr_data,
            ctypes.c_int(type_ptr_data),
            ctypes.c_int64(len(csr.indptr)),
            ctypes.c_int64(len(csr.data)),
            ctypes.c_int64(csr.shape[1]),
            ctypes.c_int(predict_type),
            ctypes.c_int(start_iteration),
            ctypes.c_int(num_iteration),
            _c_str(self.pred_parameter),
            ctypes.c_int(matrix_type),
            out_shape.ctypes.data_as(ctypes.POINTER(ctypes.c_int64)),
            ctypes.byref(out_ptr_indptr),
            ctypes.byref(out_ptr_indices),
            ctypes.byref(out_ptr_data)))
        matrices = self.__create_sparse_native(
            cs=csr,
            out_shape=out_shape,
            out_ptr_indptr=out_ptr_indptr,
            out_ptr_indices=out_ptr_indices,
            out_ptr_data=out_ptr_data,
            indptr_type=type_ptr_indptr,
            data_type=type_ptr_data,
            is_csr=True
        )
        nrow = len(csr.indptr) - 1
        return matrices, nrow

    def __pred_for_csr(
        self,
        csr: scipy.sparse.csr_matrix,
        start_iteration: int,
        num_iteration: int,
        predict_type: int
    ) -> Tuple[np.ndarray, int]:
        """Predict for CSR data."""
        if predict_type == _C_API_PREDICT_CONTRIB:
            return self.__inner_predict_csr_sparse(
                csr=csr,
                start_iteration=start_iteration,
                num_iteration=num_iteration,
                predict_type=predict_type
            )
        nrow = len(csr.indptr) - 1
        if nrow > _MAX_INT32:
            sections = [0] + list(np.arange(start=_MAX_INT32, stop=nrow, step=_MAX_INT32)) + [nrow]
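            # e.g. if nrow were 2.5 * _MAX_INT32, sections would be
            # [0, _MAX_INT32, 2 * _MAX_INT32, nrow]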
            # __get_num_preds() cannot work with nrow > MAX_INT32, so calculate overall number of predictions piecemeal
            n_preds = [self.__get_num_preds(start_iteration, num_iteration, i, predict_type) for i in np.diff(sections)]
            n_preds_sections = np.array([0] + n_preds, dtype=np.intp).cumsum()
            preds = np.empty(sum(n_preds), dtype=np.float64)
            for (start_idx, end_idx), (start_idx_pred, end_idx_pred) in zip(zip(sections, sections[1:]),
                                                                            zip(n_preds_sections, n_preds_sections[1:])):
                # write into a slice of the pre-allocated array to avoid concatenating per-chunk results
                self.__inner_predict_csr(
                    csr=csr[start_idx:end_idx],
                    start_iteration=start_iteration,
                    num_iteration=num_iteration,
                    predict_type=predict_type,
                    preds=preds[start_idx_pred:end_idx_pred]
                )
            return preds, nrow
        else:
            return self.__inner_predict_csr(
                csr=csr,
                start_iteration=start_iteration,
                num_iteration=num_iteration,
                predict_type=predict_type,
                preds=None
            )

    def __inner_predict_sparse_csc(
        self,
        csc: scipy.sparse.csc_matrix,
        start_iteration: int,
        num_iteration: int,
        predict_type: int
    ):
        ptr_indptr, type_ptr_indptr, __ = _c_int_array(csc.indptr)
        ptr_data, type_ptr_data, _ = _c_float_array(csc.data)
        csc_indices = csc.indices.astype(np.int32, copy=False)
        matrix_type = _C_API_MATRIX_TYPE_CSC
        out_ptr_indptr: _ctypes_int_ptr
        if type_ptr_indptr == _C_API_DTYPE_INT32:
            out_ptr_indptr = ctypes.POINTER(ctypes.c_int32)()
        else:
            out_ptr_indptr = ctypes.POINTER(ctypes.c_int64)()
        out_ptr_indices = ctypes.POINTER(ctypes.c_int32)()
        out_ptr_data: _ctypes_float_ptr
        if type_ptr_data == _C_API_DTYPE_FLOAT32:
            out_ptr_data = ctypes.POINTER(ctypes.c_float)()
        else:
            out_ptr_data = ctypes.POINTER(ctypes.c_double)()
        out_shape = np.empty(2, dtype=np.int64)
        _safe_call(_LIB.LGBM_BoosterPredictSparseOutput(
            self._handle,
            ptr_indptr,
            ctypes.c_int(type_ptr_indptr),
            csc_indices.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)),
            ptr_data,
            ctypes.c_int(type_ptr_data),
            ctypes.c_int64(len(csc.indptr)),
            ctypes.c_int64(len(csc.data)),
            ctypes.c_int64(csc.shape[0]),
            ctypes.c_int(predict_type),
            ctypes.c_int(start_iteration),
            ctypes.c_int(num_iteration),
            _c_str(self.pred_parameter),
            ctypes.c_int(matrix_type),
            out_shape.ctypes.data_as(ctypes.POINTER(ctypes.c_int64)),
            ctypes.byref(out_ptr_indptr),
            ctypes.byref(out_ptr_indices),
            ctypes.byref(out_ptr_data)))
        matrices = self.__create_sparse_native(
            cs=csc,
            out_shape=out_shape,
            out_ptr_indptr=out_ptr_indptr,
            out_ptr_indices=out_ptr_indices,
            out_ptr_data=out_ptr_data,
            indptr_type=type_ptr_indptr,
            data_type=type_ptr_data,
            is_csr=False
        )
        nrow = csc.shape[0]
        return matrices, nrow

    def __pred_for_csc(
        self,
        csc: scipy.sparse.csc_matrix,
        start_iteration: int,
        num_iteration: int,
        predict_type: int
    ) -> Tuple[np.ndarray, int]:
        """Predict for CSC data."""
        nrow = csc.shape[0]
        if nrow > _MAX_INT32:
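            # the chunked fallback for very tall inputs is implemented for CSR
            # only, so convert before delegating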
            return self.__pred_for_csr(
                csr=csc.tocsr(),
                start_iteration=start_iteration,
                num_iteration=num_iteration,
                predict_type=predict_type
            )
        if predict_type == _C_API_PREDICT_CONTRIB:
            return self.__inner_predict_sparse_csc(
                csc=csc,
                start_iteration=start_iteration,
                num_iteration=num_iteration,
                predict_type=predict_type
            )
        n_preds = self.__get_num_preds(
            start_iteration=start_iteration,
            num_iteration=num_iteration,
            nrow=nrow,
            predict_type=predict_type
        )
        preds = np.empty(n_preds, dtype=np.float64)
        out_num_preds = ctypes.c_int64(0)

        ptr_indptr, type_ptr_indptr, __ = _c_int_array(csc.indptr)
        ptr_data, type_ptr_data, _ = _c_float_array(csc.data)

        assert csc.shape[0] <= _MAX_INT32
        csc_indices = csc.indices.astype(np.int32, copy=False)

        _safe_call(_LIB.LGBM_BoosterPredictForCSC(
            self._handle,
            ptr_indptr,
            ctypes.c_int(type_ptr_indptr),
            csc_indices.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)),
            ptr_data,
            ctypes.c_int(type_ptr_data),
            ctypes.c_int64(len(csc.indptr)),
            ctypes.c_int64(len(csc.data)),
            ctypes.c_int64(csc.shape[0]),
            ctypes.c_int(predict_type),
            ctypes.c_int(start_iteration),
            ctypes.c_int(num_iteration),
            _c_str(self.pred_parameter),
            ctypes.byref(out_num_preds),
            preds.ctypes.data_as(ctypes.POINTER(ctypes.c_double))))
        if n_preds != out_num_preds.value:
            raise ValueError("Wrong length for predict results")
        return preds, nrow
    
    def __pred_for_pyarrow_table(
        self,
        table: pa_Table,
        start_iteration: int,
        num_iteration: int,
        predict_type: int
    ) -> Tuple[np.ndarray, int]:
        """Predict for a PyArrow table."""
        if not PYARROW_INSTALLED:
            raise LightGBMError("Cannot predict from Arrow without `pyarrow` installed.")

        # Check that the input is valid: we only handle numbers (for now)
        if not all(arrow_is_integer(t) or arrow_is_floating(t) for t in table.schema.types):
            raise ValueError("Arrow table may only have integer or floating point datatypes")

        # Prepare prediction output array
        n_preds = self.__get_num_preds(
            start_iteration=start_iteration,
            num_iteration=num_iteration,
            nrow=table.num_rows,
            predict_type=predict_type
        )
        preds = np.empty(n_preds, dtype=np.float64)
        out_num_preds = ctypes.c_int64(0)

        # Export Arrow table to C and run prediction
        c_array = _export_arrow_to_c(table)
        _safe_call(_LIB.LGBM_BoosterPredictForArrow(
            self._handle,
            ctypes.c_int64(c_array.n_chunks),
            ctypes.c_void_p(c_array.chunks_ptr),
            ctypes.c_void_p(c_array.schema_ptr),
            ctypes.c_int(predict_type),
            ctypes.c_int(start_iteration),
            ctypes.c_int(num_iteration),
            _c_str(self.pred_parameter),
            ctypes.byref(out_num_preds),
            preds.ctypes.data_as(ctypes.POINTER(ctypes.c_double))))
        if n_preds != out_num_preds.value:
            raise ValueError("Wrong length for predict results")
        return preds, table.num_rows

    def current_iteration(self) -> int:
        """Get the index of the current iteration.

        Returns
        -------
        cur_iter : int
            The index of the current iteration.
        """
        out_cur_iter = ctypes.c_int(0)
        _safe_call(_LIB.LGBM_BoosterGetCurrentIteration(
            self._handle,
            ctypes.byref(out_cur_iter)))
        return out_cur_iter.value


class Dataset:
    """Dataset in LightGBM."""

    def __init__(
        self,
        data: _LGBM_TrainDataType,
        label: Optional[_LGBM_LabelType] = None,
        reference: Optional["Dataset"] = None,
        weight: Optional[_LGBM_WeightType] = None,
        group: Optional[_LGBM_GroupType] = None,
        init_score: Optional[_LGBM_InitScoreType] = None,
        feature_name: _LGBM_FeatureNameConfiguration = 'auto',
        categorical_feature: _LGBM_CategoricalFeatureConfiguration = 'auto',
        params: Optional[Dict[str, Any]] = None,
        free_raw_data: bool = True,
        position: Optional[_LGBM_PositionType] = None,
    ):
        """Initialize Dataset.

        Parameters
        ----------
        data : str, pathlib.Path, numpy array, pandas DataFrame, H2O DataTable's Frame, scipy.sparse, Sequence, list of Sequence, list of numpy array or pyarrow Table
            Data source of Dataset.
            If str or pathlib.Path, it represents the path to a text file (CSV, TSV, or LibSVM) or a LightGBM Dataset binary file.
        label : list, numpy 1-D array, pandas Series / one-column DataFrame, pyarrow Array, pyarrow ChunkedArray or None, optional (default=None)
            Label of the data.
        reference : Dataset or None, optional (default=None)
            If this is Dataset for validation, training data should be used as reference.
        weight : list, numpy 1-D array, pandas Series, pyarrow Array, pyarrow ChunkedArray or None, optional (default=None)
            Weight for each instance. Weights should be non-negative.
        group : list, numpy 1-D array, pandas Series, pyarrow Array, pyarrow ChunkedArray or None, optional (default=None)
            Group/query data.
            Only used in the learning-to-rank task.
            sum(group) = n_samples.
            For example, if you have a 100-document dataset with ``group = [10, 20, 40, 10, 10, 10]``, that means that you have 6 groups,
            where the first 10 records are in the first group, records 11-30 are in the second group, records 31-70 are in the third group, etc.
        init_score : list, list of lists (for multi-class task), numpy array, pandas Series, pandas DataFrame (for multi-class task), pyarrow Array, pyarrow ChunkedArray, pyarrow Table (for multi-class task) or None, optional (default=None)
            Init score for Dataset.
        feature_name : list of str, or 'auto', optional (default="auto")
            Feature names.
            If 'auto' and data is pandas DataFrame or pyarrow Table, data column names are used.
        categorical_feature : list of str or int, or 'auto', optional (default="auto")
            Categorical features.
            If list of int, interpreted as indices.
            If list of str, interpreted as feature names (need to specify ``feature_name`` as well).
            If 'auto' and data is pandas DataFrame, pandas unordered categorical columns are used.
            All values in categorical features will be cast to int32 and thus should be less than int32 max value (2147483647).
            Large values could be memory consuming. Consider using consecutive integers starting from zero.
            All negative values in categorical features will be treated as missing values.
            The output cannot be monotonically constrained with respect to a categorical feature.
            Floating point numbers in categorical features will be rounded towards 0.
        params : dict or None, optional (default=None)
            Other parameters for Dataset.
        free_raw_data : bool, optional (default=True)
            If True, raw data is freed after constructing inner Dataset.
        position : numpy 1-D array, pandas Series or None, optional (default=None)
            Position of items used in unbiased learning-to-rank task.
        """
        self._handle: Optional[_DatasetHandle] = None
        self.data = data
        self.label = label
        self.reference = reference
        self.weight = weight
        self.group = group
        self.position = position
        self.init_score = init_score
        self.feature_name: _LGBM_FeatureNameConfiguration = feature_name
        self.categorical_feature: _LGBM_CategoricalFeatureConfiguration = categorical_feature
        self.params = deepcopy(params)
        self.free_raw_data = free_raw_data
        self.used_indices: Optional[List[int]] = None
        self._need_slice = True
        self._predictor: Optional[_InnerPredictor] = None
        self.pandas_categorical: Optional[List[List]] = None
        self._params_back_up = None
        self.version = 0
        self._start_row = 0  # Used when pushing rows one by one.

    def __del__(self) -> None:
        try:
            self._free_handle()
        except AttributeError:
            pass

    def _create_sample_indices(self, total_nrow: int) -> np.ndarray:
        """Get an array of randomly chosen indices from this ``Dataset``.

        Indices are sampled without replacement.

        Parameters
        ----------
        total_nrow : int
            Total number of rows to sample from.
            If this value is greater than the value of parameter ``bin_construct_sample_cnt``, only ``bin_construct_sample_cnt`` indices will be used.
            If Dataset has multiple input data, this should be the sum of rows of every file.

        Returns
        -------
        indices : numpy array
            Indices for sampled data.
        """
        param_str = _param_dict_to_str(self.get_params())
        sample_cnt = _get_sample_count(total_nrow, param_str)
        indices = np.empty(sample_cnt, dtype=np.int32)
        ptr_data, _, _ = _c_int_array(indices)
        actual_sample_cnt = ctypes.c_int32(0)

        _safe_call(_LIB.LGBM_SampleIndices(
            ctypes.c_int32(total_nrow),
            _c_str(param_str),
            ptr_data,
            ctypes.byref(actual_sample_cnt),
        ))
        assert sample_cnt == actual_sample_cnt.value
        return indices

    def _init_from_ref_dataset(
        self,
        total_nrow: int,
        ref_dataset: _DatasetHandle
    ) -> 'Dataset':
        """Create dataset from a reference dataset.

        Parameters
        ----------
        total_nrow : int
            Number of rows expected to add to dataset.
        ref_dataset : object
            Handle of reference dataset to extract metadata from.

        Returns
        -------
        self : Dataset
            Constructed Dataset object.
        """
        self._handle = ctypes.c_void_p()
        _safe_call(_LIB.LGBM_DatasetCreateByReference(
            ref_dataset,
            ctypes.c_int64(total_nrow),
            ctypes.byref(self._handle),
        ))
        return self

    def _init_from_sample(
        self,
        sample_data: List[np.ndarray],
        sample_indices: List[np.ndarray],
        sample_cnt: int,
        total_nrow: int,
    ) -> "Dataset":
        """Create Dataset from sampled data structures.

        Parameters
        ----------
        sample_data : list of numpy array
            Sample data for each column.
        sample_indices : list of numpy array
            Sample data row index for each column.
        sample_cnt : int
            Number of samples.
        total_nrow : int
            Total number of rows for all input files.

        Returns
        -------
        self : Dataset
            Constructed Dataset object.
        """
        ncol = len(sample_indices)
        assert len(sample_data) == ncol, "#sample data column != #column indices"

        for i in range(ncol):
            if sample_data[i].dtype != np.double:
                raise ValueError(f"sample_data[{i}] type {sample_data[i].dtype} is not double")
            if sample_indices[i].dtype != np.int32:
                raise ValueError(f"sample_indices[{i}] type {sample_indices[i].dtype} is not int32")

        # c type: double**
        # each double* element points to start of each column of sample data.
        sample_col_ptr: _ctypes_float_array = (ctypes.POINTER(ctypes.c_double) * ncol)()
        # c type int**
        # each int* points to start of indices for each column
        indices_col_ptr: _ctypes_int_array = (ctypes.POINTER(ctypes.c_int32) * ncol)()
        for i in range(ncol):
            sample_col_ptr[i] = _c_float_array(sample_data[i])[0]
            indices_col_ptr[i] = _c_int_array(sample_indices[i])[0]

        num_per_col = np.array([len(d) for d in sample_indices], dtype=np.int32)
        num_per_col_ptr, _, _ = _c_int_array(num_per_col)

        self._handle = ctypes.c_void_p()
        params_str = _param_dict_to_str(self.get_params())
        _safe_call(_LIB.LGBM_DatasetCreateFromSampledColumn(
            ctypes.cast(sample_col_ptr, ctypes.POINTER(ctypes.POINTER(ctypes.c_double))),
            ctypes.cast(indices_col_ptr, ctypes.POINTER(ctypes.POINTER(ctypes.c_int32))),
            ctypes.c_int32(ncol),
            num_per_col_ptr,
            ctypes.c_int32(sample_cnt),
            ctypes.c_int32(total_nrow),
            ctypes.c_int64(total_nrow),
            _c_str(params_str),
            ctypes.byref(self._handle),
        ))
        return self

    def _push_rows(self, data: np.ndarray) -> 'Dataset':
        """Add rows to Dataset.

        Parameters
        ----------
        data : numpy 1-D array
            New data to add to the Dataset.

        Returns
        -------
        self : Dataset
            Dataset object.
        """
        nrow, ncol = data.shape
        data = data.reshape(data.size)
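        # flatten the 2-D chunk into a 1-D buffer (numpy keeps row-major order
        # by default) before handing it to the C API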
        data_ptr, data_type, _ = _c_float_array(data)

        _safe_call(_LIB.LGBM_DatasetPushRows(
            self._handle,
            data_ptr,
            data_type,
            ctypes.c_int32(nrow),
            ctypes.c_int32(ncol),
            ctypes.c_int32(self._start_row),
        ))
        self._start_row += nrow
        return self

    def get_params(self) -> Dict[str, Any]:
        """Get the used parameters in the Dataset.

        Returns
        -------
        params : dict
            The used parameters in this Dataset object.
        """
        if self.params is not None:
            # aliases related to min_data, nthreads and verbose are deliberately excluded here
            dataset_params = _ConfigAliases.get("bin_construct_sample_cnt",
                                                "categorical_feature",
                                                "data_random_seed",
                                                "enable_bundle",
                                                "feature_pre_filter",
                                                "forcedbins_filename",
                                                "group_column",
                                                "header",
                                                "ignore_column",
                                                "is_enable_sparse",
                                                "label_column",
                                                "linear_tree",
                                                "max_bin",
                                                "max_bin_by_feature",
                                                "min_data_in_bin",
                                                "pre_partition",
                                                "precise_float_parser",
                                                "two_round",
                                                "use_missing",
                                                "weight_column",
                                                "zero_as_missing")
            return {k: v for k, v in self.params.items() if k in dataset_params}
        else:
            return {}

    def _free_handle(self) -> "Dataset":
        if self._handle is not None:
            _safe_call(_LIB.LGBM_DatasetFree(self._handle))
            self._handle = None
        self._need_slice = True
        if self.used_indices is not None:
            self.data = None
        return self

    def _set_init_score_by_predictor(
        self,
        predictor: Optional[_InnerPredictor],
        data: _LGBM_TrainDataType,
        used_indices: Optional[Union[List[int], np.ndarray]]
    ) -> "Dataset":
        data_has_header = False
        if isinstance(data, (str, Path)) and self.params is not None:
            # check data has header or not
            data_has_header = any(self.params.get(alias, False) for alias in _ConfigAliases.get("header"))
        num_data = self.num_data()
        if predictor is not None:
            init_score: Union[np.ndarray, scipy.sparse.spmatrix] = predictor.predict(
                data=data,
                raw_score=True,
                data_has_header=data_has_header
            )
            init_score = init_score.ravel()
            if used_indices is not None:
                assert not self._need_slice
                if isinstance(data, (str, Path)):
                    sub_init_score = np.empty(num_data * predictor.num_class, dtype=np.float64)
                    assert num_data == len(used_indices)
                    for i in range(len(used_indices)):
                        for j in range(predictor.num_class):
                            sub_init_score[i * predictor.num_class + j] = init_score[used_indices[i] * predictor.num_class + j]
                    init_score = sub_init_score
            if predictor.num_class > 1:
                # need to regroup init_score
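                # from row-major [r0c0, r0c1, ..., r1c0, r1c1, ...]
                # to class-major [r0c0, r1c0, ..., r0c1, r1c1, ...]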
                new_init_score = np.empty(init_score.size, dtype=np.float64)
                for i in range(num_data):
                    for j in range(predictor.num_class):
                        new_init_score[j * num_data + i] = init_score[i * predictor.num_class + j]
                init_score = new_init_score
        elif self.init_score is not None:
            init_score = np.full_like(self.init_score, fill_value=0.0, dtype=np.float64)
        else:
            return self
        self.set_init_score(init_score)
        return self

    def _lazy_init(
        self,
        data: Optional[_LGBM_TrainDataType],
        label: Optional[_LGBM_LabelType],
        reference: Optional["Dataset"],
        weight: Optional[_LGBM_WeightType],
        group: Optional[_LGBM_GroupType],
        init_score: Optional[_LGBM_InitScoreType],
        predictor: Optional[_InnerPredictor],
        feature_name: _LGBM_FeatureNameConfiguration,
        categorical_feature: _LGBM_CategoricalFeatureConfiguration,
        params: Optional[Dict[str, Any]],
        position: Optional[_LGBM_PositionType]
    ) -> "Dataset":
        if data is None:
            self._handle = None
            return self
        if reference is not None:
            self.pandas_categorical = reference.pandas_categorical
            categorical_feature = reference.categorical_feature
        if isinstance(data, pd_DataFrame):
            data, feature_name, categorical_feature, self.pandas_categorical = _data_from_pandas(
                data=data,
                feature_name=feature_name,
                categorical_feature=categorical_feature,
                pandas_categorical=self.pandas_categorical
            )

        # process for args
        params = {} if params is None else params
        args_names = inspect.signature(self.__class__._lazy_init).parameters.keys()
        for key in params.keys():
            if key in args_names:
                _log_warning(f'{key} keyword has been found in `params` and will be ignored.\n'
                             f'Please use {key} argument of the Dataset constructor to pass this parameter.')
        # get categorical features
        if isinstance(categorical_feature, list):
            categorical_indices = set()
            feature_dict = {}
            if isinstance(feature_name, list):
                feature_dict = {name: i for i, name in enumerate(feature_name)}
            for name in categorical_feature:
                if isinstance(name, str) and name in feature_dict:
                    categorical_indices.add(feature_dict[name])
                elif isinstance(name, int):
                    categorical_indices.add(name)
                else:
                    raise TypeError(f"Wrong type({type(name).__name__}) or unknown name({name}) in categorical_feature")
            if categorical_indices:
                for cat_alias in _ConfigAliases.get("categorical_feature"):
                    if cat_alias in params:
                        # If the params[cat_alias] is equal to categorical_indices, do not report the warning.
                        if not (isinstance(params[cat_alias], list) and set(params[cat_alias]) == categorical_indices):
                            _log_warning(f'{cat_alias} in param dict is overridden.')
                        params.pop(cat_alias, None)
                params['categorical_column'] = sorted(categorical_indices)

        params_str = _param_dict_to_str(params)
        self.params = params
        # process for reference dataset
        ref_dataset = None
        if isinstance(reference, Dataset):
            ref_dataset = reference.construct()._handle
        elif reference is not None:
            raise TypeError('Reference dataset should be None or dataset instance')
        # start construct data
        if isinstance(data, (str, Path)):
            self._handle = ctypes.c_void_p()
            _safe_call(_LIB.LGBM_DatasetCreateFromFile(
                _c_str(str(data)),
                _c_str(params_str),
                ref_dataset,
                ctypes.byref(self._handle)))
        elif isinstance(data, scipy.sparse.csr_matrix):
            self.__init_from_csr(data, params_str, ref_dataset)
        elif isinstance(data, scipy.sparse.csc_matrix):
            self.__init_from_csc(data, params_str, ref_dataset)
        elif isinstance(data, np.ndarray):
            self.__init_from_np2d(data, params_str, ref_dataset)
        elif _is_pyarrow_table(data):
            self.__init_from_pyarrow_table(data, params_str, ref_dataset)
            feature_name = data.column_names
        elif isinstance(data, list) and len(data) > 0:
            if _is_list_of_numpy_arrays(data):
                self.__init_from_list_np2d(data, params_str, ref_dataset)
            elif _is_list_of_sequences(data):
                self.__init_from_seqs(data, ref_dataset)
            else:
                raise TypeError('Data list can only be of ndarray or Sequence')
        elif isinstance(data, Sequence):
            self.__init_from_seqs([data], ref_dataset)
        elif isinstance(data, dt_DataTable):
            self.__init_from_np2d(data.to_numpy(), params_str, ref_dataset)
        else:
            try:
                csr = scipy.sparse.csr_matrix(data)
                self.__init_from_csr(csr, params_str, ref_dataset)
            except BaseException as err:
                raise TypeError(f'Cannot initialize Dataset from {type(data).__name__}') from err
        if label is not None:
            self.set_label(label)
        if self.get_label() is None:
            raise ValueError("Label should not be None")
        if weight is not None:
            self.set_weight(weight)
        if group is not None:
            self.set_group(group)
        if position is not None:
            self.set_position(position)
        if isinstance(predictor, _InnerPredictor):
            if self._predictor is None and init_score is not None:
                _log_warning("The init_score will be overridden by the prediction of init_model.")
            self._set_init_score_by_predictor(
                predictor=predictor,
                data=data,
                used_indices=None
            )
        elif init_score is not None:
            self.set_init_score(init_score)
        elif predictor is not None:
            raise TypeError(f'Wrong predictor type {type(predictor).__name__}')
        # set feature names
        return self.set_feature_name(feature_name)

    @staticmethod
    def _yield_row_from_seqlist(seqs: List[Sequence], indices: Iterable[int]):
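        # walk the concatenated sequences in order, copying any row that is a
        # view into a Sequence's internal buffer so the sample owns its memory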
        offset = 0
        seq_id = 0
        seq = seqs[seq_id]
        for row_id in indices:
            assert row_id >= offset, "sample indices are expected to be monotonic"
            while row_id >= offset + len(seq):
                offset += len(seq)
                seq_id += 1
                seq = seqs[seq_id]
            id_in_seq = row_id - offset
            row = seq[id_in_seq]
            yield row if row.flags['OWNDATA'] else row.copy()

    def __sample(self, seqs: List[Sequence], total_nrow: int) -> Tuple[List[np.ndarray], List[np.ndarray]]:
        """Sample data from seqs.

        Mimics behavior in c_api.cpp:LGBM_DatasetCreateFromMats()

        Returns
        -------
            sampled_rows, sampled_row_indices
        """
        indices = self._create_sample_indices(total_nrow)

        # Select sampled rows, transpose to column order.
        sampled = np.array(list(self._yield_row_from_seqlist(seqs, indices)))
        sampled = sampled.T
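        # keep only non-zero (and NaN) entries per column, mirroring the sparse
        # column sampling done on the C side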

        filtered = []
        filtered_idx = []
        sampled_row_range = np.arange(len(indices), dtype=np.int32)
        for col in sampled:
            col_predicate = (np.abs(col) > ZERO_THRESHOLD) | np.isnan(col)
            filtered_col = col[col_predicate]
            filtered_row_idx = sampled_row_range[col_predicate]

            filtered.append(filtered_col)
            filtered_idx.append(filtered_row_idx)

        return filtered, filtered_idx

    def __init_from_seqs(
        self,
        seqs: List[Sequence],
        ref_dataset: Optional[_DatasetHandle]
    ) -> "Dataset":
        """
        Initialize data from list of Sequence objects.

        Sequence: Generic Data Access Object
            Supports random access and access by batch if properly defined by user

        Data scheme uniformity is trusted, not checked.
        """
        total_nrow = sum(len(seq) for seq in seqs)

        # create validation dataset from ref_dataset
        if ref_dataset is not None:
            self._init_from_ref_dataset(total_nrow, ref_dataset)
        else:
            param_str = _param_dict_to_str(self.get_params())
            sample_cnt = _get_sample_count(total_nrow, param_str)

            sample_data, col_indices = self.__sample(seqs, total_nrow)
            self._init_from_sample(sample_data, col_indices, sample_cnt, total_nrow)
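        # with the bin mappers built from the sample, stream the full data in
        # batches of batch_size rows per sequence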

        for seq in seqs:
            nrow = len(seq)
            batch_size = getattr(seq, 'batch_size', None) or Sequence.batch_size
            for start in range(0, nrow, batch_size):
                end = min(start + batch_size, nrow)
                self._push_rows(seq[start:end])
        return self

    def __init_from_np2d(
        self,
        mat: np.ndarray,
        params_str: str,
        ref_dataset: Optional[_DatasetHandle]
    ) -> "Dataset":
        """Initialize data from a 2-D numpy matrix."""
        if len(mat.shape) != 2:
            raise ValueError('Input numpy.ndarray must be 2 dimensional')

        self._handle = ctypes.c_void_p()
        if mat.dtype == np.float32 or mat.dtype == np.float64:
            data = np.array(mat.reshape(mat.size), dtype=mat.dtype, copy=False)
        else:  # change non-float data to float data, need to copy
            data = np.array(mat.reshape(mat.size), dtype=np.float32)

        ptr_data, type_ptr_data, _ = _c_float_array(data)
        _safe_call(_LIB.LGBM_DatasetCreateFromMat(
            ptr_data,
            ctypes.c_int(type_ptr_data),
            ctypes.c_int32(mat.shape[0]),
            ctypes.c_int32(mat.shape[1]),
            ctypes.c_int(_C_API_IS_ROW_MAJOR),
            _c_str(params_str),
            ref_dataset,
            ctypes.byref(self._handle)))
        return self

    def __init_from_list_np2d(
        self,
        mats: List[np.ndarray],
        params_str: str,
        ref_dataset: Optional[_DatasetHandle]
    ) -> "Dataset":
        """Initialize data from a list of 2-D numpy matrices."""
        ncol = mats[0].shape[1]
        nrow = np.empty((len(mats),), np.int32)
        ptr_data: _ctypes_float_array
        if mats[0].dtype == np.float64:
            ptr_data = (ctypes.POINTER(ctypes.c_double) * len(mats))()
        else:
            ptr_data = (ctypes.POINTER(ctypes.c_float) * len(mats))()

        holders = []
        type_ptr_data = -1

        for i, mat in enumerate(mats):
            if len(mat.shape) != 2:
                raise ValueError('Input numpy.ndarray must be 2 dimensional')

            if mat.shape[1] != ncol:
                raise ValueError('Input arrays must have same number of columns')

            nrow[i] = mat.shape[0]

            if mat.dtype == np.float32 or mat.dtype == np.float64:
                mats[i] = np.array(mat.reshape(mat.size), dtype=mat.dtype, copy=False)
            else:  # change non-float data to float data, need to copy
                mats[i] = np.array(mat.reshape(mat.size), dtype=np.float32)

            chunk_ptr_data, chunk_type_ptr_data, holder = _c_float_array(mats[i])
            if type_ptr_data != -1 and chunk_type_ptr_data != type_ptr_data:
                raise ValueError('Input chunks must have same type')
            ptr_data[i] = chunk_ptr_data
            type_ptr_data = chunk_type_ptr_data
            holders.append(holder)

        self._handle = ctypes.c_void_p()
        _safe_call(_LIB.LGBM_DatasetCreateFromMats(
            ctypes.c_int32(len(mats)),
            ctypes.cast(ptr_data, ctypes.POINTER(ctypes.POINTER(ctypes.c_double))),
            ctypes.c_int(type_ptr_data),
            nrow.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)),
            ctypes.c_int32(ncol),
            ctypes.c_int(_C_API_IS_ROW_MAJOR),
            _c_str(params_str),
            ref_dataset,
            ctypes.byref(self._handle)))
        return self

    def __init_from_csr(
        self,
        csr: scipy.sparse.csr_matrix,
        params_str: str,
        ref_dataset: Optional[_DatasetHandle]
    ) -> "Dataset":
        """Initialize data from a CSR matrix."""
        if len(csr.indices) != len(csr.data):
            raise ValueError(f'Length mismatch: {len(csr.indices)} vs {len(csr.data)}')
        self._handle = ctypes.c_void_p()

        ptr_indptr, type_ptr_indptr, __ = _c_int_array(csr.indptr)
        ptr_data, type_ptr_data, _ = _c_float_array(csr.data)

        assert csr.shape[1] <= _MAX_INT32
        csr_indices = csr.indices.astype(np.int32, copy=False)

        _safe_call(_LIB.LGBM_DatasetCreateFromCSR(
            ptr_indptr,
            ctypes.c_int(type_ptr_indptr),
            csr_indices.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)),
            ptr_data,
            ctypes.c_int(type_ptr_data),
            ctypes.c_int64(len(csr.indptr)),
            ctypes.c_int64(len(csr.data)),
            ctypes.c_int64(csr.shape[1]),
            _c_str(params_str),
            ref_dataset,
            ctypes.byref(self._handle)))
        return self

    def __init_from_csc(
        self,
        csc: scipy.sparse.csc_matrix,
        params_str: str,
        ref_dataset: Optional[_DatasetHandle]
    ) -> "Dataset":
        """Initialize data from a CSC matrix."""
        if len(csc.indices) != len(csc.data):
            raise ValueError(f'Length mismatch: {len(csc.indices)} vs {len(csc.data)}')
        self._handle = ctypes.c_void_p()

        ptr_indptr, type_ptr_indptr, __ = _c_int_array(csc.indptr)
        ptr_data, type_ptr_data, _ = _c_float_array(csc.data)

        assert csc.shape[0] <= _MAX_INT32
        csc_indices = csc.indices.astype(np.int32, copy=False)

        _safe_call(_LIB.LGBM_DatasetCreateFromCSC(
            ptr_indptr,
            ctypes.c_int(type_ptr_indptr),
            csc_indices.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)),
            ptr_data,
            ctypes.c_int(type_ptr_data),
            ctypes.c_int64(len(csc.indptr)),
            ctypes.c_int64(len(csc.data)),
            ctypes.c_int64(csc.shape[0]),
            _c_str(params_str),
            ref_dataset,
            ctypes.byref(self._handle)))
        return self

    def __init_from_pyarrow_table(
        self,
        table: pa_Table,
        params_str: str,
        ref_dataset: Optional[_DatasetHandle]
    ) -> "Dataset":
        """Initialize data from a PyArrow table."""
        if not PYARROW_INSTALLED:
            raise LightGBMError("Cannot init dataframe from Arrow without `pyarrow` installed.")

        # Check that the input is valid: we only handle numbers (for now)
        if not all(arrow_is_integer(t) or arrow_is_floating(t) for t in table.schema.types):
            raise ValueError("Arrow table may only have integer or floating point datatypes")

        # Export Arrow table to C
        c_array = _export_arrow_to_c(table)
        self._handle = ctypes.c_void_p()
        _safe_call(_LIB.LGBM_DatasetCreateFromArrow(
            ctypes.c_int64(c_array.n_chunks),
            ctypes.c_void_p(c_array.chunks_ptr),
            ctypes.c_void_p(c_array.schema_ptr),
            _c_str(params_str),
            ref_dataset,
            ctypes.byref(self._handle)))
        return self

    @staticmethod
    def _compare_params_for_warning(
        params: Dict[str, Any],
        other_params: Dict[str, Any],
        ignore_keys: Set[str]
    ) -> bool:
        """Compare two dictionaries with params ignoring some keys.

        It is used only to decide whether a warning should be emitted.

        Parameters
        ----------
        params : dict
            One dictionary with parameters to compare.
        other_params : dict
            Another dictionary with parameters to compare.
        ignore_keys : set
            Keys that should be ignored during comparing two dictionaries.

        Returns
        -------
        compare_result : bool
            Returns whether two dictionaries with params are equal.
        """
        for k in other_params:
            if k not in ignore_keys:
                if k not in params or params[k] != other_params[k]:
                    return False
        for k in params:
            if k not in ignore_keys:
                if k not in other_params or params[k] != other_params[k]:
                    return False
        return True

    def construct(self) -> "Dataset":
        """Lazy init.

        Returns
        -------
        self : Dataset
            Constructed Dataset object.
        """
        if self._handle is None:
            if self.reference is not None:
                reference_params = self.reference.get_params()
                params = self.get_params()
                if params != reference_params:
                    if not self._compare_params_for_warning(
                        params=params,
                        other_params=reference_params,
                        ignore_keys=_ConfigAliases.get("categorical_feature")
                    ):
                        _log_warning('Overriding the parameters from Reference Dataset.')
                    self._update_params(reference_params)
                if self.used_indices is None:
                    # create valid
                    self._lazy_init(data=self.data, label=self.label, reference=self.reference,
                                    weight=self.weight, group=self.group, position=self.position,
                                    init_score=self.init_score, predictor=self._predictor,
                                    feature_name=self.feature_name, categorical_feature='auto', params=self.params)
                else:
                    # construct subset
                    used_indices = _list_to_1d_numpy(self.used_indices, dtype=np.int32, name='used_indices')
                    assert used_indices.flags.c_contiguous
                    if self.reference.group is not None:
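                        # expand group sizes into per-row group ids, slice by the
                        # subset indices, then count rows per group again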
                        group_info = np.array(self.reference.group).astype(np.int32, copy=False)
                        _, self.group = np.unique(np.repeat(range(len(group_info)), repeats=group_info)[self.used_indices],
                                                  return_counts=True)
                    self._handle = ctypes.c_void_p()
                    params_str = _param_dict_to_str(self.params)
                    _safe_call(_LIB.LGBM_DatasetGetSubset(
                        self.reference.construct()._handle,
                        used_indices.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)),
                        ctypes.c_int32(used_indices.shape[0]),
                        _c_str(params_str),
                        ctypes.byref(self._handle)))
                    if not self.free_raw_data:
                        self.get_data()
                    if self.group is not None:
                        self.set_group(self.group)
                    if self.position is not None:
                        self.set_position(self.position)
                    if self.get_label() is None:
                        raise ValueError("Label should not be None.")
                    if isinstance(self._predictor, _InnerPredictor) and self._predictor is not self.reference._predictor:
                        self.get_data()
                        self._set_init_score_by_predictor(
                            predictor=self._predictor,
                            data=self.data,
                            used_indices=used_indices
                        )
            else:
                # create train
                self._lazy_init(data=self.data, label=self.label, reference=None,
                                weight=self.weight, group=self.group,
                                init_score=self.init_score, predictor=self._predictor,
                                feature_name=self.feature_name, categorical_feature=self.categorical_feature,
                                params=self.params, position=self.position)
            if self.free_raw_data:
                self.data = None
            self.feature_name = self.get_feature_name()
        return self

    def create_valid(
        self,
        data: _LGBM_TrainDataType,
        label: Optional[_LGBM_LabelType] = None,
        weight: Optional[_LGBM_WeightType] = None,
        group: Optional[_LGBM_GroupType] = None,
        init_score: Optional[_LGBM_InitScoreType] = None,
        params: Optional[Dict[str, Any]] = None,
        position: Optional[_LGBM_PositionType] = None
    ) -> "Dataset":
2482
        """Create validation data align with current Dataset.
wxchan's avatar
wxchan committed
2483
2484
2485

        Parameters
        ----------
2486
        data : str, pathlib.Path, numpy array, pandas DataFrame, H2O DataTable's Frame, scipy.sparse, Sequence, list of Sequence or list of numpy array
wxchan's avatar
wxchan committed
2487
            Data source of Dataset.
2488
            If str or pathlib.Path, it represents the path to a text file (CSV, TSV, or LibSVM) or a LightGBM Dataset binary file.
2489
        label : list, numpy 1-D array, pandas Series / one-column DataFrame, pyarrow Array, pyarrow ChunkedArray or None, optional (default=None)
2490
            Label of the data.
2491
        weight : list, numpy 1-D array, pandas Series, pyarrow Array, pyarrow ChunkedArray or None, optional (default=None)
2492
            Weight for each instance. Weights should be non-negative.
2493
        group : list, numpy 1-D array, pandas Series, pyarrow Array, pyarrow ChunkedArray or None, optional (default=None)
2494
2495
2496
            Group/query data.
            Only used in the learning-to-rank task.
            sum(group) = n_samples.
2497
2498
            For example, if you have a 100-document dataset with ``group = [10, 20, 40, 10, 10, 10]``, that means that you have 6 groups,
            where the first 10 records are in the first group, records 11-30 are in the second group, records 31-70 are in the third group, etc.
2499
        init_score : list, list of lists (for multi-class task), numpy array, pandas Series, pandas DataFrame (for multi-class task), pyarrow Array, pyarrow ChunkedArray, pyarrow Table (for multi-class task) or None, optional (default=None)
2500
            Init score for Dataset.
Nikita Titov's avatar
Nikita Titov committed
2501
        params : dict or None, optional (default=None)
2502
            Other parameters for validation Dataset.
2503
2504
        position : numpy 1-D array, pandas Series or None, optional (default=None)
            Position of items used in unbiased learning-to-rank task.
2505
2506
2507

        Returns
        -------
Nikita Titov's avatar
Nikita Titov committed
2508
2509
        valid : Dataset
            Validation Dataset with reference to self.
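
        Examples
        --------
        A minimal sketch (``X_train``, ``y_train``, ``X_val`` and ``y_val`` stand for
        user-supplied numpy arrays; they are not defined by this module):

        >>> train_data = Dataset(X_train, label=y_train)
        >>> valid_data = train_data.create_valid(X_val, label=y_val)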
        """
        ret = Dataset(data, label=label, reference=self,
                      weight=weight, group=group, position=position, init_score=init_score,
                      params=params, free_raw_data=self.free_raw_data)
        ret._predictor = self._predictor
        ret.pandas_categorical = self.pandas_categorical
        return ret

    def subset(
        self,
        used_indices: List[int],
        params: Optional[Dict[str, Any]] = None
    ) -> "Dataset":
        """Get subset of current Dataset.

        Parameters
        ----------
        used_indices : list of int
            Indices used to create the subset.
        params : dict or None, optional (default=None)
            These parameters will be passed to Dataset constructor.

        Returns
        -------
        subset : Dataset
            Subset of the current Dataset.
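
        Examples
        --------
        A minimal sketch (``X`` and ``y`` stand for user-supplied numpy arrays):

        >>> full_data = Dataset(X, label=y, free_raw_data=False)
        >>> first_ten = full_data.subset(used_indices=list(range(10)))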
        """
        if params is None:
            params = self.params
        ret = Dataset(None, reference=self, feature_name=self.feature_name,
                      categorical_feature=self.categorical_feature, params=params,
                      free_raw_data=self.free_raw_data)
        ret._predictor = self._predictor
        ret.pandas_categorical = self.pandas_categorical
        ret.used_indices = sorted(used_indices)
        return ret

    def save_binary(self, filename: Union[str, Path]) -> "Dataset":
        """Save Dataset to a binary file.

        .. note::

            Please note that `init_score` is not saved in binary file.
            If you need it, please set it again after loading Dataset.

        Parameters
        ----------
        filename : str or pathlib.Path
            Name of the output file.

        Returns
        -------
        self : Dataset
            Returns self.
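
        Examples
        --------
        A minimal sketch (``X`` and ``y`` stand for user-supplied numpy arrays):

        >>> Dataset(X, label=y).save_binary('train.bin')
        >>> reloaded = Dataset('train.bin')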
        """
        _safe_call(_LIB.LGBM_DatasetSaveBinary(
            self.construct()._handle,
            _c_str(str(filename))))
        return self

    def _update_params(self, params: Optional[Dict[str, Any]]) -> "Dataset":
        if not params:
            return self
        params = deepcopy(params)

        def update():
            if not self.params:
                self.params = params
            else:
                self._params_back_up = deepcopy(self.params)
                self.params.update(params)

        if self._handle is None:
            update()
        elif params is not None:
            ret = _LIB.LGBM_DatasetUpdateParamChecking(
                _c_str(_param_dict_to_str(self.params)),
                _c_str(_param_dict_to_str(params)))
            if ret != 0:
                # new params are incompatible with the constructed handle;
                # they can still be applied if the raw data has not been freed
                if self.data is not None:
                    update()
                    self._free_handle()
                else:
                    raise LightGBMError(_LIB.LGBM_GetLastError().decode('utf-8'))
        return self

    def _reverse_update_params(self) -> "Dataset":
        if self._handle is None:
            self.params = deepcopy(self._params_back_up)
            self._params_back_up = None
        return self

    def set_field(
        self,
        field_name: str,
        data: Optional[Union[List[List[float]], List[List[int]], List[float], List[int], np.ndarray, pd_Series, pd_DataFrame, pa_Table, pa_Array, pa_ChunkedArray]]
    ) -> "Dataset":
        """Set property into the Dataset.

        Parameters
        ----------
        field_name : str
            The field name of the information.
        data : list, list of lists (for multi-class task), numpy array, pandas Series, pandas DataFrame (for multi-class task), pyarrow Array, pyarrow ChunkedArray, pyarrow Table (only for 'init_score') or None
            The data to be set.

        Returns
        -------
        self : Dataset
            Dataset with set property.
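
        Examples
        --------
        A minimal sketch (``X``, ``y`` and ``w`` stand for user-supplied numpy arrays;
        the Dataset must be constructed before fields can be set):

        >>> train_data = Dataset(X, label=y).construct()
        >>> train_data.set_field('weight', w)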
        """
        if self._handle is None:
            raise Exception(f"Cannot set {field_name} before constructing the Dataset")
        if data is None:
            # set to None
            _safe_call(_LIB.LGBM_DatasetSetField(
                self._handle,
                _c_str(field_name),
                None,
                ctypes.c_int(0),
                ctypes.c_int(_FIELD_TYPE_MAPPER[field_name])))
            return self

        # If the data is pyarrow data, we can pass it straight to C
        if _is_pyarrow_array(data) or _is_pyarrow_table(data):
            # If a table is being passed, we concatenate the columns. This is only valid for
            # 'init_score'.
            if _is_pyarrow_table(data):
                if field_name != "init_score":
                    raise ValueError(f"pyarrow tables are not supported for field '{field_name}'")
                data = pa_chunked_array([
                    chunk for array in data.columns for chunk in array.chunks  # type: ignore
                ])

            c_array = _export_arrow_to_c(data)
            _safe_call(_LIB.LGBM_DatasetSetFieldFromArrow(
                self._handle,
                _c_str(field_name),
                ctypes.c_int64(c_array.n_chunks),
                ctypes.c_void_p(c_array.chunks_ptr),
                ctypes.c_void_p(c_array.schema_ptr),
            ))
            self.version += 1
            return self

        dtype: "np.typing.DTypeLike"
        if field_name == 'init_score':
            dtype = np.float64
            if _is_1d_collection(data):
                data = _list_to_1d_numpy(data, dtype=dtype, name=field_name)
            elif _is_2d_collection(data):
                data = _data_to_2d_numpy(data, dtype=dtype, name=field_name)
                data = data.ravel(order='F')
            else:
                raise TypeError(
                    'init_score must be list, numpy 1-D array or pandas Series.\n'
                    'In multiclass classification init_score can also be a list of lists, numpy 2-D array or pandas DataFrame.'
                )
        else:
            dtype = np.int32 if (field_name == 'group' or field_name == 'position') else np.float32
            data = _list_to_1d_numpy(data, dtype=dtype, name=field_name)

        ptr_data: Union[_ctypes_float_ptr, _ctypes_int_ptr]
        if data.dtype == np.float32 or data.dtype == np.float64:
            ptr_data, type_data, _ = _c_float_array(data)
        elif data.dtype == np.int32:
            ptr_data, type_data, _ = _c_int_array(data)
        else:
            raise TypeError(f"Expected np.float32/64 or np.int32, met type({data.dtype})")
        if type_data != _FIELD_TYPE_MAPPER[field_name]:
            raise TypeError("Input type error for set_field")
        _safe_call(_LIB.LGBM_DatasetSetField(
            self._handle,
            _c_str(field_name),
            ptr_data,
            ctypes.c_int(len(data)),
            ctypes.c_int(type_data)))
        self.version += 1
        return self

    def get_field(self, field_name: str) -> Optional[np.ndarray]:
        """Get property from the Dataset.

        Can only be run on a constructed Dataset.

        Unlike ``get_group()``, ``get_init_score()``, ``get_label()``, ``get_position()``, and ``get_weight()``,
        this method ignores any raw data passed into ``lgb.Dataset()`` on the Python side, and will only read
        data from the constructed C++ ``Dataset`` object.

        Parameters
        ----------
        field_name : str
            The field name of the information.

        Returns
        -------
        info : numpy array or None
            A numpy array with information from the Dataset.
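
        Examples
        --------
        A minimal sketch (``X`` and ``y`` stand for user-supplied numpy arrays):

        >>> train_data = Dataset(X, label=y).construct()
        >>> label_from_cpp = train_data.get_field('label')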
        """
        if self._handle is None:
            raise Exception(f"Cannot get {field_name} before constructing the Dataset")
        tmp_out_len = ctypes.c_int(0)
        out_type = ctypes.c_int(0)
        ret = ctypes.POINTER(ctypes.c_void_p)()
        _safe_call(_LIB.LGBM_DatasetGetField(
            self._handle,
            _c_str(field_name),
            ctypes.byref(tmp_out_len),
            ctypes.byref(ret),
            ctypes.byref(out_type)))
        if out_type.value != _FIELD_TYPE_MAPPER[field_name]:
            raise TypeError("Return type error for get_field")
        if tmp_out_len.value == 0:
            return None
        if out_type.value == _C_API_DTYPE_INT32:
            arr = _cint32_array_to_numpy(
                cptr=ctypes.cast(ret, ctypes.POINTER(ctypes.c_int32)),
                length=tmp_out_len.value
            )
        elif out_type.value == _C_API_DTYPE_FLOAT32:
            arr = _cfloat32_array_to_numpy(
                cptr=ctypes.cast(ret, ctypes.POINTER(ctypes.c_float)),
                length=tmp_out_len.value
            )
        elif out_type.value == _C_API_DTYPE_FLOAT64:
            arr = _cfloat64_array_to_numpy(
                cptr=ctypes.cast(ret, ctypes.POINTER(ctypes.c_double)),
                length=tmp_out_len.value
            )
        else:
            raise TypeError("Unknown type")
        if field_name == 'init_score':
            num_data = self.num_data()
            num_classes = arr.size // num_data
            if num_classes > 1:
                arr = arr.reshape((num_data, num_classes), order='F')
        return arr

    def set_categorical_feature(
        self,
        categorical_feature: _LGBM_CategoricalFeatureConfiguration
    ) -> "Dataset":
        """Set categorical features.

        Parameters
        ----------
        categorical_feature : list of str or int, or 'auto'
            Names or indices of categorical features.

        Returns
        -------
        self : Dataset
            Dataset with set categorical features.
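
        Examples
        --------
        A minimal sketch (``df`` stands for a user-supplied pandas DataFrame that
        contains a column named ``'cat_col'``, and ``y`` for the labels):

        >>> train_data = Dataset(df, label=y, free_raw_data=False)
        >>> train_data.set_categorical_feature(['cat_col'])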
        """
        if self.categorical_feature == categorical_feature:
            return self
        if self.data is not None:
            if self.categorical_feature is None:
                self.categorical_feature = categorical_feature
                return self._free_handle()
            elif categorical_feature == 'auto':
                return self
            else:
                if self.categorical_feature != 'auto':
                    _log_warning('categorical_feature in Dataset is overridden.\n'
                                 f'New categorical_feature is {list(categorical_feature)}')
                self.categorical_feature = categorical_feature
                return self._free_handle()
        else:
            raise LightGBMError("Cannot set categorical feature after freed raw data, "
                                "set free_raw_data=False when constructing the Dataset to avoid this.")

    def _set_predictor(
        self,
        predictor: Optional[_InnerPredictor]
    ) -> "Dataset":
        """Set predictor for continued training.

        It is not recommended for users to call this function.
        Please use init_model argument in engine.train() or engine.cv() instead.
        """
        if predictor is None and self._predictor is None:
            return self
        elif isinstance(predictor, _InnerPredictor) and isinstance(self._predictor, _InnerPredictor):
            if (predictor == self._predictor) and (predictor.current_iteration() == self._predictor.current_iteration()):
                return self
        if self._handle is None:
            self._predictor = predictor
        elif self.data is not None:
            self._predictor = predictor
            self._set_init_score_by_predictor(
                predictor=self._predictor,
                data=self.data,
                used_indices=None
            )
        elif self.used_indices is not None and self.reference is not None and self.reference.data is not None:
            self._predictor = predictor
            self._set_init_score_by_predictor(
                predictor=self._predictor,
                data=self.reference.data,
                used_indices=self.used_indices
            )
        else:
            raise LightGBMError("Cannot set predictor after freed raw data, "
                                "set free_raw_data=False when constructing the Dataset to avoid this.")
        return self

    def set_reference(self, reference: "Dataset") -> "Dataset":
        """Set reference Dataset.

        Parameters
        ----------
        reference : Dataset
            Reference that is used as a template to construct the current Dataset.

        Returns
        -------
        self : Dataset
            Dataset with set reference.
        """
        self.set_categorical_feature(reference.categorical_feature) \
            .set_feature_name(reference.feature_name) \
            ._set_predictor(reference._predictor)
        # we're done if self and reference share a common upstream reference
        if self.get_ref_chain().intersection(reference.get_ref_chain()):
            return self
        if self.data is not None:
            self.reference = reference
            return self._free_handle()
        else:
            raise LightGBMError("Cannot set reference after freed raw data, "
                                "set free_raw_data=False when constructing the Dataset to avoid this.")

    def set_feature_name(self, feature_name: _LGBM_FeatureNameConfiguration) -> "Dataset":
        """Set feature name.

        Parameters
        ----------
        feature_name : list of str
            Feature names.

        Returns
        -------
        self : Dataset
            Dataset with set feature name.
        """
        if feature_name != 'auto':
            self.feature_name = feature_name
        if self._handle is not None and feature_name is not None and feature_name != 'auto':
            if len(feature_name) != self.num_feature():
                raise ValueError(f"Length of feature_name({len(feature_name)}) and num_feature({self.num_feature()}) don't match")
            c_feature_name = [_c_str(name) for name in feature_name]
            _safe_call(_LIB.LGBM_DatasetSetFeatureNames(
                self._handle,
                _c_array(ctypes.c_char_p, c_feature_name),
                ctypes.c_int(len(feature_name))))
        return self

    def set_label(self, label: Optional[_LGBM_LabelType]) -> "Dataset":
        """Set label of Dataset.

        Parameters
        ----------
        label : list, numpy 1-D array, pandas Series / one-column DataFrame, pyarrow Array, pyarrow ChunkedArray or None
            The label information to be set into Dataset.

        Returns
        -------
        self : Dataset
            Dataset with set label.
        """
        self.label = label
        if self._handle is not None and label is not None:
            if isinstance(label, pd_DataFrame):
                if len(label.columns) > 1:
                    raise ValueError('DataFrame for label cannot have multiple columns')
                label_array = np.ravel(_pandas_to_numpy(label, target_dtype=np.float32))
            elif _is_pyarrow_array(label):
                label_array = label
            else:
                label_array = _list_to_1d_numpy(label, dtype=np.float32, name='label')
            self.set_field('label', label_array)
            self.label = self.get_field('label')  # original values can be modified on the C++ side
        return self

    def set_weight(
        self,
        weight: Optional[_LGBM_WeightType]
    ) -> "Dataset":
        """Set weight of each instance.

        Parameters
        ----------
        weight : list, numpy 1-D array, pandas Series, pyarrow Array, pyarrow ChunkedArray or None
            Weight to be set for each data point. Weights should be non-negative.

        Returns
        -------
        self : Dataset
            Dataset with set weight.
        """
        # If all weights are 1, drop them: that is equivalent to an unweighted dataset
        if weight is not None:
            if _is_pyarrow_array(weight):
                if pa_compute.all(pa_compute.equal(weight, 1)).as_py():
                    weight = None
            elif np.all(weight == 1):
                weight = None
        self.weight = weight

        # Set field
        if self._handle is not None and weight is not None:
            if not _is_pyarrow_array(weight):
                weight = _list_to_1d_numpy(weight, dtype=np.float32, name='weight')
            self.set_field('weight', weight)
            self.weight = self.get_field('weight')  # original values can be modified on the C++ side
        return self

    def set_init_score(
        self,
        init_score: Optional[_LGBM_InitScoreType]
    ) -> "Dataset":
        """Set init score of Booster to start from.

        Parameters
        ----------
        init_score : list, list of lists (for multi-class task), numpy array, pandas Series, pandas DataFrame (for multi-class task), pyarrow Array, pyarrow ChunkedArray, pyarrow Table (for multi-class task) or None
            Init score for Booster.

        Returns
        -------
        self : Dataset
            Dataset with set init score.
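
        Examples
        --------
        A minimal sketch (``X`` and ``y`` stand for user-supplied numpy arrays; a
        constant init score shifts the starting point of boosting for every row):

        >>> train_data = Dataset(X, label=y).construct()
        >>> train_data.set_init_score(np.full(train_data.num_data(), 0.5))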
        """
        self.init_score = init_score
        if self._handle is not None and init_score is not None:
            self.set_field('init_score', init_score)
            self.init_score = self.get_field('init_score')  # original values can be modified on the C++ side
        return self

    def set_group(
        self,
        group: Optional[_LGBM_GroupType]
    ) -> "Dataset":
        """Set group size of Dataset (used for ranking).

        Parameters
        ----------
        group : list, numpy 1-D array, pandas Series, pyarrow Array, pyarrow ChunkedArray or None
            Group/query data.
            Only used in the learning-to-rank task.
            sum(group) = n_samples.
            For example, if you have a 100-document dataset with ``group = [10, 20, 40, 10, 10, 10]``, that means that you have 6 groups,
            where the first 10 records are in the first group, records 11-30 are in the second group, records 31-70 are in the third group, etc.

        Returns
        -------
        self : Dataset
            Dataset with set group.
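
        Examples
        --------
        A minimal sketch (``X`` and ``y`` stand for user-supplied arrays with 100 rows;
        the group sizes must sum to the number of rows):

        >>> train_data = Dataset(X, label=y).construct()
        >>> train_data.set_group([10, 20, 40, 10, 10, 10])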
        """
        self.group = group
        if self._handle is not None and group is not None:
            if not _is_pyarrow_array(group):
                group = _list_to_1d_numpy(group, dtype=np.int32, name='group')
            self.set_field('group', group)
            # original values can be modified on the C++ side
            constructed_group = self.get_field('group')
            if constructed_group is not None:
                self.group = np.diff(constructed_group)
        return self

    def set_position(
        self,
        position: Optional[_LGBM_PositionType]
    ) -> "Dataset":
        """Set position of Dataset (used for ranking).

        Parameters
        ----------
        position : numpy 1-D array, pandas Series or None, optional (default=None)
            Position of items used in unbiased learning-to-rank task.

        Returns
        -------
        self : Dataset
            Dataset with set position.
        """
        self.position = position
        if self._handle is not None and position is not None:
            position = _list_to_1d_numpy(position, dtype=np.int32, name='position')
            self.set_field('position', position)
        return self

    def get_feature_name(self) -> List[str]:
        """Get the names of columns (features) in the Dataset.

        Returns
        -------
        feature_names : list of str
            The names of columns (features) in the Dataset.
        """
        if self._handle is None:
            raise LightGBMError("Cannot get feature_name before constructing the Dataset")
        num_feature = self.num_feature()
        tmp_out_len = ctypes.c_int(0)
        reserved_string_buffer_size = 255
        required_string_buffer_size = ctypes.c_size_t(0)
        string_buffers = [ctypes.create_string_buffer(reserved_string_buffer_size) for _ in range(num_feature)]
        ptr_string_buffers = (ctypes.c_char_p * num_feature)(*map(ctypes.addressof, string_buffers))  # type: ignore[misc]
        _safe_call(_LIB.LGBM_DatasetGetFeatureNames(
            self._handle,
            ctypes.c_int(num_feature),
            ctypes.byref(tmp_out_len),
            ctypes.c_size_t(reserved_string_buffer_size),
            ctypes.byref(required_string_buffer_size),
            ptr_string_buffers))
        if num_feature != tmp_out_len.value:
            raise ValueError("Length of feature names doesn't equal num_feature")
        actual_string_buffer_size = required_string_buffer_size.value
        # if buffer length is not long enough, reallocate buffers
        if reserved_string_buffer_size < actual_string_buffer_size:
            string_buffers = [ctypes.create_string_buffer(actual_string_buffer_size) for _ in range(num_feature)]
            ptr_string_buffers = (ctypes.c_char_p * num_feature)(*map(ctypes.addressof, string_buffers))  # type: ignore[misc]
            _safe_call(_LIB.LGBM_DatasetGetFeatureNames(
                self._handle,
                ctypes.c_int(num_feature),
                ctypes.byref(tmp_out_len),
                ctypes.c_size_t(actual_string_buffer_size),
                ctypes.byref(required_string_buffer_size),
                ptr_string_buffers))
        return [string_buffers[i].value.decode('utf-8') for i in range(num_feature)]

    def get_label(self) -> Optional[_LGBM_LabelType]:
        """Get the label of the Dataset.

        Returns
        -------
        label : list, numpy 1-D array, pandas Series / one-column DataFrame or None
            The label information from the Dataset.
            For a constructed ``Dataset``, this will only return a numpy array.
        """
        if self.label is None:
            self.label = self.get_field('label')
        return self.label

    def get_weight(self) -> Optional[_LGBM_WeightType]:
        """Get the weight of the Dataset.

        Returns
        -------
        weight : list, numpy 1-D array, pandas Series or None
            Weight for each data point from the Dataset. Weights should be non-negative.
            For a constructed ``Dataset``, this will only return ``None`` or a numpy array.
        """
        if self.weight is None:
            self.weight = self.get_field('weight')
        return self.weight

    def get_init_score(self) -> Optional[_LGBM_InitScoreType]:
        """Get the initial score of the Dataset.

        Returns
        -------
        init_score : list, list of lists (for multi-class task), numpy array, pandas Series, pandas DataFrame (for multi-class task), or None
            Init score of Booster.
            For a constructed ``Dataset``, this will only return ``None`` or a numpy array.
        """
        if self.init_score is None:
            self.init_score = self.get_field('init_score')
        return self.init_score

    def get_data(self) -> Optional[_LGBM_TrainDataType]:
        """Get the raw data of the Dataset.

        Returns
        -------
        data : str, pathlib.Path, numpy array, pandas DataFrame, H2O DataTable's Frame, scipy.sparse, Sequence, list of Sequence, list of numpy array or None
            Raw data used in the Dataset construction.
        """
        if self._handle is None:
            raise Exception("Cannot get data before constructing the Dataset")
        if self._need_slice and self.used_indices is not None and self.reference is not None:
            self.data = self.reference.data
            if self.data is not None:
                if isinstance(self.data, np.ndarray) or isinstance(self.data, scipy.sparse.spmatrix):
                    self.data = self.data[self.used_indices, :]
                elif isinstance(self.data, pd_DataFrame):
                    self.data = self.data.iloc[self.used_indices].copy()
                elif isinstance(self.data, dt_DataTable):
                    self.data = self.data[self.used_indices, :]
                elif isinstance(self.data, Sequence):
                    self.data = self.data[self.used_indices]
                elif _is_list_of_sequences(self.data) and len(self.data) > 0:
                    self.data = np.array(list(self._yield_row_from_seqlist(self.data, self.used_indices)))
                else:
                    _log_warning(f"Cannot subset {type(self.data).__name__} type of raw data.\n"
                                 "Returning original raw data")
            self._need_slice = False
        if self.data is None:
            raise LightGBMError("Cannot call `get_data` after freed raw data, "
                                "set free_raw_data=False when constructing the Dataset to avoid this.")
        return self.data

    def get_group(self) -> Optional[_LGBM_GroupType]:
        """Get the group of the Dataset.

        Returns
        -------
        group : list, numpy 1-D array, pandas Series or None
            Group/query data.
            Only used in the learning-to-rank task.
            sum(group) = n_samples.
            For example, if you have a 100-document dataset with ``group = [10, 20, 40, 10, 10, 10]``, that means that you have 6 groups,
            where the first 10 records are in the first group, records 11-30 are in the second group, records 31-70 are in the third group, etc.
            For a constructed ``Dataset``, this will only return ``None`` or a numpy array.
        """
        if self.group is None:
            self.group = self.get_field('group')
            if self.group is not None:
                # group data from LightGBM is boundary data; convert it to group sizes
                self.group = np.diff(self.group)
        return self.group

    def get_position(self) -> Optional[_LGBM_PositionType]:
        """Get the position of the Dataset.

        Returns
        -------
        position : numpy 1-D array, pandas Series or None
            Position of items used in unbiased learning-to-rank task.
            For a constructed ``Dataset``, this will only return ``None`` or a numpy array.
        """
        if self.position is None:
            self.position = self.get_field('position')
        return self.position

    def num_data(self) -> int:
        """Get the number of rows in the Dataset.

        Returns
        -------
        number_of_rows : int
            The number of rows in the Dataset.
        """
        if self._handle is not None:
            ret = ctypes.c_int(0)
            _safe_call(_LIB.LGBM_DatasetGetNumData(self._handle,
                                                   ctypes.byref(ret)))
            return ret.value
        else:
            raise LightGBMError("Cannot get num_data before constructing the Dataset")

    def num_feature(self) -> int:
        """Get the number of columns (features) in the Dataset.

        Returns
        -------
        number_of_columns : int
            The number of columns (features) in the Dataset.
        """
        if self._handle is not None:
            ret = ctypes.c_int(0)
            _safe_call(_LIB.LGBM_DatasetGetNumFeature(self._handle,
                                                      ctypes.byref(ret)))
            return ret.value
        else:
            raise LightGBMError("Cannot get num_feature before constructing the Dataset")

    def feature_num_bin(self, feature: Union[int, str]) -> int:
        """Get the number of bins for a feature.

        .. versionadded:: 4.0.0

        Parameters
        ----------
        feature : int or str
            Index or name of the feature.

        Returns
        -------
        number_of_bins : int
            The number of constructed bins for the feature in the Dataset.
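
        Examples
        --------
        A minimal sketch (``X`` and ``y`` stand for user-supplied numpy arrays):

        >>> train_data = Dataset(X, label=y).construct()
        >>> n_bins = train_data.feature_num_bin(0)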
        """
        if self._handle is not None:
            if isinstance(feature, str):
                feature_index = self.feature_name.index(feature)
            else:
                feature_index = feature
            ret = ctypes.c_int(0)
            _safe_call(_LIB.LGBM_DatasetGetFeatureNumBin(self._handle,
                                                         ctypes.c_int(feature_index),
                                                         ctypes.byref(ret)))
            return ret.value
        else:
            raise LightGBMError("Cannot get feature_num_bin before constructing the Dataset")

    def get_ref_chain(self, ref_limit: int = 100) -> Set["Dataset"]:
        """Get a chain of Dataset objects.

        Starts with self, then follows self.reference (if it exists),
        then self.reference.reference, etc.,
        until ``ref_limit`` references have been collected or a reference loop is hit.

        Parameters
        ----------
        ref_limit : int, optional (default=100)
            The limit number of references.

        Returns
        -------
        ref_chain : set of Dataset
            Chain of references of the Datasets.
        """
        head = self
        ref_chain: Set[Dataset] = set()
        while len(ref_chain) < ref_limit:
            if isinstance(head, Dataset):
                ref_chain.add(head)
                if (head.reference is not None) and (head.reference not in ref_chain):
                    head = head.reference
                else:
                    break
            else:
                break
        return ref_chain

    def add_features_from(self, other: "Dataset") -> "Dataset":
        """Add features from other Dataset to the current Dataset.

        Both Datasets must be constructed before calling this method.

        Parameters
        ----------
        other : Dataset
            The Dataset to take features from.

        Returns
        -------
        self : Dataset
            Dataset with the new features added.
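
        Examples
        --------
        A minimal sketch (``X1``, ``X2`` and ``y`` stand for user-supplied numpy
        arrays with the same number of rows):

        >>> d1 = Dataset(X1, label=y, free_raw_data=False).construct()
        >>> d2 = Dataset(X2, free_raw_data=False).construct()
        >>> d1.add_features_from(d2)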
        """
        if self._handle is None or other._handle is None:
            raise ValueError('Both source and target Datasets must be constructed before adding features')
        _safe_call(_LIB.LGBM_DatasetAddFeaturesFrom(self._handle, other._handle))
        was_none = self.data is None
        old_self_data_type = type(self.data).__name__
        if other.data is None:
            self.data = None
        elif self.data is not None:
            if isinstance(self.data, np.ndarray):
                if isinstance(other.data, np.ndarray):
                    self.data = np.hstack((self.data, other.data))
                elif isinstance(other.data, scipy.sparse.spmatrix):
                    self.data = np.hstack((self.data, other.data.toarray()))
                elif isinstance(other.data, pd_DataFrame):
                    self.data = np.hstack((self.data, other.data.values))
                elif isinstance(other.data, dt_DataTable):
                    self.data = np.hstack((self.data, other.data.to_numpy()))
                else:
                    self.data = None
            elif isinstance(self.data, scipy.sparse.spmatrix):
                sparse_format = self.data.getformat()
                if isinstance(other.data, np.ndarray) or isinstance(other.data, scipy.sparse.spmatrix):
                    self.data = scipy.sparse.hstack((self.data, other.data), format=sparse_format)
                elif isinstance(other.data, pd_DataFrame):
                    self.data = scipy.sparse.hstack((self.data, other.data.values), format=sparse_format)
                elif isinstance(other.data, dt_DataTable):
                    self.data = scipy.sparse.hstack((self.data, other.data.to_numpy()), format=sparse_format)
                else:
                    self.data = None
            elif isinstance(self.data, pd_DataFrame):
                if not PANDAS_INSTALLED:
                    raise LightGBMError("Cannot add features to DataFrame type of raw data "
                                        "without pandas installed. "
                                        "Install pandas and restart your session.")
                if isinstance(other.data, np.ndarray):
                    self.data = concat((self.data, pd_DataFrame(other.data)),
                                       axis=1, ignore_index=True)
                elif isinstance(other.data, scipy.sparse.spmatrix):
                    self.data = concat((self.data, pd_DataFrame(other.data.toarray())),
                                       axis=1, ignore_index=True)
                elif isinstance(other.data, pd_DataFrame):
                    self.data = concat((self.data, other.data),
                                       axis=1, ignore_index=True)
                elif isinstance(other.data, dt_DataTable):
                    self.data = concat((self.data, pd_DataFrame(other.data.to_numpy())),
                                       axis=1, ignore_index=True)
                else:
                    self.data = None
            elif isinstance(self.data, dt_DataTable):
                if isinstance(other.data, np.ndarray):
                    self.data = dt_DataTable(np.hstack((self.data.to_numpy(), other.data)))
                elif isinstance(other.data, scipy.sparse.spmatrix):
                    self.data = dt_DataTable(np.hstack((self.data.to_numpy(), other.data.toarray())))
                elif isinstance(other.data, pd_DataFrame):
                    self.data = dt_DataTable(np.hstack((self.data.to_numpy(), other.data.values)))
                elif isinstance(other.data, dt_DataTable):
                    self.data = dt_DataTable(np.hstack((self.data.to_numpy(), other.data.to_numpy())))
                else:
                    self.data = None
            else:
                self.data = None
        if self.data is None:
            err_msg = (f"Cannot add features from {type(other.data).__name__} type of raw data to "
                       f"{old_self_data_type} type of raw data.\n")
            err_msg += ("Set free_raw_data=False when constructing the Dataset to avoid this"
                        if was_none else "Freeing raw data")
            _log_warning(err_msg)
        self.feature_name = self.get_feature_name()
        _log_warning("Resetting categorical features.\n"
                     "You can set new categorical features via ``set_categorical_feature`` method")
        self.categorical_feature = "auto"
        self.pandas_categorical = None
        return self

    def _dump_text(self, filename: Union[str, Path]) -> "Dataset":
        """Save Dataset to a text file.

        This format cannot be loaded back in by LightGBM, but is useful for debugging purposes.

        Parameters
        ----------
        filename : str or pathlib.Path
            Name of the output file.

        Returns
        -------
        self : Dataset
            Returns self.
        """
        _safe_call(_LIB.LGBM_DatasetDumpText(
            self.construct()._handle,
            _c_str(str(filename))))
        return self


_LGBM_CustomObjectiveFunction = Callable[
    [np.ndarray, Dataset],
    Tuple[np.ndarray, np.ndarray]
]
_LGBM_CustomEvalFunction = Union[
    Callable[
        [np.ndarray, Dataset],
        _LGBM_EvalFunctionResultType
    ],
    Callable[
        [np.ndarray, Dataset],
        List[_LGBM_EvalFunctionResultType]
    ]
]
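

# A minimal sketch of a function conforming to ``_LGBM_CustomObjectiveFunction``
# (illustrative only, not part of the library API): for a squared-error objective,
# the gradient is ``preds - labels`` and the hessian is constant.
def _example_l2_objective(preds: np.ndarray, dataset: Dataset) -> Tuple[np.ndarray, np.ndarray]:
    labels = dataset.get_label()
    grad = preds - labels  # first derivative of 0.5 * (preds - labels) ** 2
    hess = np.ones_like(preds)  # second derivative is 1 everywhere
    return grad, hess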


class Booster:
    """Booster in LightGBM."""

    def __init__(
        self,
        params: Optional[Dict[str, Any]] = None,
        train_set: Optional[Dataset] = None,
        model_file: Optional[Union[str, Path]] = None,
        model_str: Optional[str] = None
    ):
        """Initialize the Booster.

        Parameters
        ----------
        params : dict or None, optional (default=None)
            Parameters for Booster.
        train_set : Dataset or None, optional (default=None)
            Training dataset.
        model_file : str, pathlib.Path or None, optional (default=None)
            Path to the model file.
        model_str : str or None, optional (default=None)
            Model will be loaded from this string.
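
        Examples
        --------
        A minimal sketch (``X`` and ``y`` stand for user-supplied numpy arrays):

        >>> train_data = Dataset(X, label=y)
        >>> booster = Booster(params={'objective': 'regression'}, train_set=train_data)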
        """
        self._handle = ctypes.c_void_p()
        self._network = False
        self.__need_reload_eval_info = True
        self._train_data_name = "training"
        self.__set_objective_to_none = False
        self.best_iteration = -1
        self.best_score: _LGBM_BoosterBestScoreType = {}
        params = {} if params is None else deepcopy(params)
        if train_set is not None:
            # Training task
            if not isinstance(train_set, Dataset):
                raise TypeError(f'Training data should be Dataset instance, met {type(train_set).__name__}')
            params = _choose_param_value(
                main_param_name="machines",
                params=params,
                default_value=None
            )
            # if "machines" is given, assume user wants to do distributed learning, and set up network
            if params["machines"] is None:
                params.pop("machines", None)
            else:
                machines = params["machines"]
                if isinstance(machines, str):
                    num_machines_from_machine_list = len(machines.split(','))
                elif isinstance(machines, (list, set)):
                    num_machines_from_machine_list = len(machines)
                    machines = ','.join(machines)
                else:
                    raise ValueError("Invalid machines in params.")

                params = _choose_param_value(
                    main_param_name="num_machines",
                    params=params,
                    default_value=num_machines_from_machine_list
                )
                params = _choose_param_value(
                    main_param_name="local_listen_port",
                    params=params,
                    default_value=12400
                )
                self.set_network(
                    machines=machines,
                    local_listen_port=params["local_listen_port"],
                    listen_time_out=params.get("time_out", 120),
                    num_machines=params["num_machines"]
                )
            # construct booster object
            train_set.construct()
            # copy the parameters from train_set
            params.update(train_set.get_params())
            params_str = _param_dict_to_str(params)
            _safe_call(_LIB.LGBM_BoosterCreate(
                train_set._handle,
                _c_str(params_str),
                ctypes.byref(self._handle)))
            # save reference to data
            self.train_set = train_set
            self.valid_sets: List[Dataset] = []
            self.name_valid_sets: List[str] = []
            self.__num_dataset = 1
            self.__init_predictor = train_set._predictor
            if self.__init_predictor is not None:
                _safe_call(_LIB.LGBM_BoosterMerge(
                    self._handle,
                    self.__init_predictor._handle))
            out_num_class = ctypes.c_int(0)
            _safe_call(_LIB.LGBM_BoosterGetNumClasses(
                self._handle,
                ctypes.byref(out_num_class)))
            self.__num_class = out_num_class.value
            # buffer for inner predict
            self.__inner_predict_buffer: List[Optional[np.ndarray]] = [None]
            self.__is_predicted_cur_iter = [False]
            self.__get_eval_info()
            self.pandas_categorical = train_set.pandas_categorical
            self.train_set_version = train_set.version
        elif model_file is not None:
            # Prediction task
            out_num_iterations = ctypes.c_int(0)
            _safe_call(_LIB.LGBM_BoosterCreateFromModelfile(
                _c_str(str(model_file)),
                ctypes.byref(out_num_iterations),
                ctypes.byref(self._handle)))
            out_num_class = ctypes.c_int(0)
            _safe_call(_LIB.LGBM_BoosterGetNumClasses(
                self._handle,
                ctypes.byref(out_num_class)))
            self.__num_class = out_num_class.value
            self.pandas_categorical = _load_pandas_categorical(file_name=model_file)
            if params:
                _log_warning('Ignoring params argument, using parameters from model file.')
            params = self._get_loaded_param()
        elif model_str is not None:
            self.model_from_string(model_str)
        else:
            raise TypeError('Need at least one training dataset or model file or model string '
                            'to create Booster instance')
        self.params = params

    def __del__(self) -> None:
        try:
            if self._network:
                self.free_network()
        except AttributeError:
            pass
        try:
            if self._handle is not None:
                _safe_call(_LIB.LGBM_BoosterFree(self._handle))
        except AttributeError:
            pass

    def __copy__(self) -> "Booster":
        return self.__deepcopy__(None)

    def __deepcopy__(self, _) -> "Booster":
        model_str = self.model_to_string(num_iteration=-1)
        return Booster(model_str=model_str)

    def __getstate__(self) -> Dict[str, Any]:
        this = self.__dict__.copy()
        handle = this['_handle']
        this.pop('train_set', None)
        this.pop('valid_sets', None)
        if handle is not None:
            this["_handle"] = self.model_to_string(num_iteration=-1)
        return this

    def __setstate__(self, state: Dict[str, Any]) -> None:
        model_str = state.get('_handle', state.get('handle', None))
        if model_str is not None:
            handle = ctypes.c_void_p()
            out_num_iterations = ctypes.c_int(0)
            _safe_call(_LIB.LGBM_BoosterLoadModelFromString(
                _c_str(model_str),
                ctypes.byref(out_num_iterations),
                ctypes.byref(handle)))
            state['_handle'] = handle
        self.__dict__.update(state)

    def _get_loaded_param(self) -> Dict[str, Any]:
        buffer_len = 1 << 20
        tmp_out_len = ctypes.c_int64(0)
        string_buffer = ctypes.create_string_buffer(buffer_len)
        ptr_string_buffer = ctypes.c_char_p(ctypes.addressof(string_buffer))
        _safe_call(_LIB.LGBM_BoosterGetLoadedParam(
            self._handle,
            ctypes.c_int64(buffer_len),
            ctypes.byref(tmp_out_len),
            ptr_string_buffer))
        actual_len = tmp_out_len.value
        # if buffer length is not long enough, re-allocate a buffer
        if actual_len > buffer_len:
            string_buffer = ctypes.create_string_buffer(actual_len)
            ptr_string_buffer = ctypes.c_char_p(ctypes.addressof(string_buffer))
            _safe_call(_LIB.LGBM_BoosterGetLoadedParam(
                self._handle,
                ctypes.c_int64(actual_len),
                ctypes.byref(tmp_out_len),
                ptr_string_buffer))
        return json.loads(string_buffer.value.decode('utf-8'))

    def free_dataset(self) -> "Booster":
        """Free Booster's Datasets.

        Returns
        -------
        self : Booster
            Booster without Datasets.
        """
        self.__dict__.pop('train_set', None)
        self.__dict__.pop('valid_sets', None)
        self.__num_dataset = 0
        return self

    def _free_buffer(self) -> "Booster":
        self.__inner_predict_buffer = []
        self.__is_predicted_cur_iter = []
        return self
3564

    def set_network(
        self,
        machines: Union[List[str], Set[str], str],
        local_listen_port: int = 12400,
        listen_time_out: int = 120,
        num_machines: int = 1
    ) -> "Booster":
        """Set the network configuration.

        Parameters
        ----------
        machines : list, set or str
            Names of machines.
        local_listen_port : int, optional (default=12400)
            TCP listen port for local machines.
        listen_time_out : int, optional (default=120)
            Socket time-out in minutes.
        num_machines : int, optional (default=1)
            The number of machines for distributed learning application.

        Returns
        -------
        self : Booster
            Booster with set network.
        """
        if isinstance(machines, (list, set)):
            machines = ','.join(machines)
        _safe_call(_LIB.LGBM_NetworkInit(_c_str(machines),
                                         ctypes.c_int(local_listen_port),
                                         ctypes.c_int(listen_time_out),
                                         ctypes.c_int(num_machines)))
        self._network = True
        return self
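
    # Usage sketch (illustrative, not part of the original module): wiring two
    # machines for socket-based distributed learning; the host names and port are
    # hypothetical and must match the `machines` list used on every worker.
    #
    #     bst.set_network(
    #         machines=["worker-a:12400", "worker-b:12400"],
    #         local_listen_port=12400,
    #         num_machines=2,
    #     )
    #     ...  # distributed training happens here
    #     bst.free_network()  # release the sockets when done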

    def free_network(self) -> "Booster":
        """Free Booster's network.

        Returns
        -------
        self : Booster
            Booster with freed network.
        """
        _safe_call(_LIB.LGBM_NetworkFree())
        self._network = False
        return self

    def trees_to_dataframe(self) -> pd_DataFrame:
        """Parse the fitted model and return in an easy-to-read pandas DataFrame.

        The returned DataFrame has the following columns.

            - ``tree_index`` : int64, which tree a node belongs to. 0-based, so a value of ``6``, for example, means "this node is in the 7th tree".
            - ``node_depth`` : int64, how far a node is from the root of the tree. The root node has a value of ``1``, its direct children are ``2``, etc.
            - ``node_index`` : str, unique identifier for a node.
            - ``left_child`` : str, ``node_index`` of the child node to the left of a split. ``None`` for leaf nodes.
            - ``right_child`` : str, ``node_index`` of the child node to the right of a split. ``None`` for leaf nodes.
            - ``parent_index`` : str, ``node_index`` of this node's parent. ``None`` for the root node.
            - ``split_feature`` : str, name of the feature used for splitting. ``None`` for leaf nodes.
            - ``split_gain`` : float64, gain from adding this split to the tree. ``NaN`` for leaf nodes.
            - ``threshold`` : float64, value of the feature used to decide which side of the split a record will go down. ``NaN`` for leaf nodes.
            - ``decision_type`` : str, logical operator describing how to compare a value to ``threshold``.
              For example, ``split_feature = "Column_10", threshold = 15, decision_type = "<="`` means that
              records where ``Column_10 <= 15`` follow the left side of the split, otherwise they follow the right side. ``None`` for leaf nodes.
            - ``missing_direction`` : str, split direction that missing values should go to. ``None`` for leaf nodes.
            - ``missing_type`` : str, describes what types of values are treated as missing.
            - ``value`` : float64, predicted value for this leaf node, multiplied by the learning rate.
            - ``weight`` : float64 or int64, sum of Hessian (second-order derivative of objective), summed over observations that fall in this node.
            - ``count`` : int64, number of records in the training data that fall into this node.

        Returns
        -------
        result : pandas DataFrame
            Returns a pandas DataFrame of the parsed model.
        """
        if not PANDAS_INSTALLED:
            raise LightGBMError('This method cannot be run without pandas installed. '
                                'You must install pandas and restart your session to use this method.')

        if self.num_trees() == 0:
            raise LightGBMError('There are no trees in this Booster and thus nothing to parse')

        def _is_split_node(tree: Dict[str, Any]) -> bool:
            return 'split_index' in tree.keys()

        def create_node_record(
            tree: Dict[str, Any],
            node_depth: int = 1,
            tree_index: Optional[int] = None,
            feature_names: Optional[List[str]] = None,
            parent_node: Optional[str] = None
        ) -> Dict[str, Any]:

            def _get_node_index(
                tree: Dict[str, Any],
                tree_index: Optional[int]
            ) -> str:
                tree_num = f'{tree_index}-' if tree_index is not None else ''
                is_split = _is_split_node(tree)
                node_type = 'S' if is_split else 'L'
                # a single-node tree won't have a `leaf_index`, so default to 0
                node_num = tree.get('split_index' if is_split else 'leaf_index', 0)
                return f"{tree_num}{node_type}{node_num}"

            def _get_split_feature(
                tree: Dict[str, Any],
                feature_names: Optional[List[str]]
            ) -> Optional[str]:
                if _is_split_node(tree):
                    if feature_names is not None:
                        feature_name = feature_names[tree['split_feature']]
                    else:
                        feature_name = tree['split_feature']
                else:
                    feature_name = None
                return feature_name

            def _is_single_node_tree(tree: Dict[str, Any]) -> bool:
                return set(tree.keys()) == {'leaf_value'}

            # Create the node record, and populate universal data members
            node: Dict[str, Union[int, str, None]] = OrderedDict()
            node['tree_index'] = tree_index
            node['node_depth'] = node_depth
            node['node_index'] = _get_node_index(tree, tree_index)
            node['left_child'] = None
            node['right_child'] = None
            node['parent_index'] = parent_node
            node['split_feature'] = _get_split_feature(tree, feature_names)
            node['split_gain'] = None
            node['threshold'] = None
            node['decision_type'] = None
            node['missing_direction'] = None
            node['missing_type'] = None
            node['value'] = None
            node['weight'] = None
            node['count'] = None

            # Update values to reflect node type (leaf or split)
            if _is_split_node(tree):
                node['left_child'] = _get_node_index(tree['left_child'], tree_index)
                node['right_child'] = _get_node_index(tree['right_child'], tree_index)
                node['split_gain'] = tree['split_gain']
                node['threshold'] = tree['threshold']
                node['decision_type'] = tree['decision_type']
                node['missing_direction'] = 'left' if tree['default_left'] else 'right'
                node['missing_type'] = tree['missing_type']
                node['value'] = tree['internal_value']
                node['weight'] = tree['internal_weight']
                node['count'] = tree['internal_count']
            else:
                node['value'] = tree['leaf_value']
                if not _is_single_node_tree(tree):
                    node['weight'] = tree['leaf_weight']
                    node['count'] = tree['leaf_count']

            return node

        def tree_dict_to_node_list(
            tree: Dict[str, Any],
            node_depth: int = 1,
            tree_index: Optional[int] = None,
            feature_names: Optional[List[str]] = None,
            parent_node: Optional[str] = None
        ) -> List[Dict[str, Any]]:

            node = create_node_record(tree=tree,
                                      node_depth=node_depth,
                                      tree_index=tree_index,
                                      feature_names=feature_names,
                                      parent_node=parent_node)

            res = [node]

            if _is_split_node(tree):
                # traverse the next level of the tree
                children = ['left_child', 'right_child']
                for child in children:
                    subtree_list = tree_dict_to_node_list(
                        tree=tree[child],
                        node_depth=node_depth + 1,
                        tree_index=tree_index,
                        feature_names=feature_names,
                        parent_node=node['node_index']
                    )
                    # "subtree_list" is a list of node records (dicts),
                    # so extend the result with all of them
                    res.extend(subtree_list)
            return res

        model_dict = self.dump_model()
        feature_names = model_dict['feature_names']
        model_list = []
        for tree in model_dict['tree_info']:
            model_list.extend(tree_dict_to_node_list(tree=tree['tree_structure'],
                                                     tree_index=tree['tree_index'],
                                                     feature_names=feature_names))

        return pd_DataFrame(model_list, columns=model_list[0].keys())
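
    # Usage sketch (illustrative, not part of the original module): inspecting the
    # structure of a trained Booster `bst`; the filter below keeps split nodes only.
    #
    #     df = bst.trees_to_dataframe()
    #     splits = df[df['split_feature'].notnull()]
    #     print(splits[['tree_index', 'node_index', 'split_feature', 'split_gain']].head())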

    def set_train_data_name(self, name: str) -> "Booster":
        """Set the name to the training Dataset.

        Parameters
        ----------
        name : str
            Name for the training Dataset.

        Returns
        -------
        self : Booster
            Booster with set training Dataset name.
        """
        self._train_data_name = name
        return self

    def add_valid(self, data: Dataset, name: str) -> "Booster":
        """Add validation data.

        Parameters
        ----------
        data : Dataset
            Validation data.
        name : str
            Name of validation data.

        Returns
        -------
        self : Booster
            Booster with set validation data.
        """
        if not isinstance(data, Dataset):
            raise TypeError(f'Validation data should be Dataset instance, met {type(data).__name__}')
        if data._predictor is not self.__init_predictor:
            raise LightGBMError("Add validation data failed, "
                                "you should use the same predictor for these data")
        _safe_call(_LIB.LGBM_BoosterAddValidData(
            self._handle,
            data.construct()._handle))
        self.valid_sets.append(data)
        self.name_valid_sets.append(name)
        self.__num_dataset += 1
        self.__inner_predict_buffer.append(None)
        self.__is_predicted_cur_iter.append(False)
        return self

    def reset_parameter(self, params: Dict[str, Any]) -> "Booster":
        """Reset parameters of Booster.

        Parameters
        ----------
        params : dict
            New parameters for Booster.

        Returns
        -------
        self : Booster
            Booster with new parameters.
        """
        params_str = _param_dict_to_str(params)
        if params_str:
            _safe_call(_LIB.LGBM_BoosterResetParameter(
                self._handle,
                _c_str(params_str)))
        self.params.update(params)
        return self
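
    # Usage sketch (illustrative, not part of the original module): lowering the
    # learning rate partway through a manual training loop on a Booster `bst`.
    #
    #     for i in range(100):
    #         if i == 50:
    #             bst.reset_parameter({"learning_rate": 0.05})
    #         bst.update()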

    def update(
        self,
        train_set: Optional[Dataset] = None,
        fobj: Optional[_LGBM_CustomObjectiveFunction] = None
    ) -> bool:
        """Update Booster for one iteration.

        Parameters
        ----------
        train_set : Dataset or None, optional (default=None)
            Training data.
            If None, last training data is used.
        fobj : callable or None, optional (default=None)
            Customized objective function.
            Should accept two parameters: preds, train_data,
            and return (grad, hess).

                preds : numpy 1-D array or numpy 2-D array (for multi-class task)
                    The predicted values.
                    Predicted values are returned before any transformation,
                    e.g. they are raw margin instead of probability of positive class for binary task.
                train_data : Dataset
                    The training dataset.
                grad : numpy 1-D array or numpy 2-D array (for multi-class task)
                    The value of the first order derivative (gradient) of the loss
                    with respect to the elements of preds for each sample point.
                hess : numpy 1-D array or numpy 2-D array (for multi-class task)
                    The value of the second order derivative (Hessian) of the loss
                    with respect to the elements of preds for each sample point.

            For multi-class task, preds are numpy 2-D array of shape = [n_samples, n_classes],
            and grad and hess should be returned in the same format.

        Returns
        -------
        is_finished : bool
            Whether the update was successfully finished.
        """
        # need to reset training data
        if train_set is None and self.train_set_version != self.train_set.version:
            train_set = self.train_set
            is_the_same_train_set = False
        else:
            is_the_same_train_set = train_set is self.train_set and self.train_set_version == train_set.version
        if train_set is not None and not is_the_same_train_set:
            if not isinstance(train_set, Dataset):
                raise TypeError(f'Training data should be Dataset instance, met {type(train_set).__name__}')
            if train_set._predictor is not self.__init_predictor:
                raise LightGBMError("Replace training data failed, "
                                    "you should use the same predictor for these data")
            self.train_set = train_set
            _safe_call(_LIB.LGBM_BoosterResetTrainingData(
                self._handle,
                self.train_set.construct()._handle))
            self.__inner_predict_buffer[0] = None
            self.train_set_version = self.train_set.version
        is_finished = ctypes.c_int(0)
        if fobj is None:
            if self.__set_objective_to_none:
                raise LightGBMError('Cannot update due to null objective function.')
            _safe_call(_LIB.LGBM_BoosterUpdateOneIter(
                self._handle,
                ctypes.byref(is_finished)))
            self.__is_predicted_cur_iter = [False for _ in range(self.__num_dataset)]
            return is_finished.value == 1
        else:
            if not self.__set_objective_to_none:
                self.reset_parameter({"objective": "none"}).__set_objective_to_none = True
            grad, hess = fobj(self.__inner_predict(0), self.train_set)
            return self.__boost(grad, hess)
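
    # Usage sketch (illustrative, not part of the original module): one manual
    # boosting step with a custom squared-error objective; `bst` is an assumed
    # Booster built on a regression Dataset.
    #
    #     import numpy as np
    #
    #     def squared_error(preds, train_data):
    #         residual = preds - train_data.get_label()
    #         grad = residual                  # dL/dpred for L = 0.5 * residual**2
    #         hess = np.ones_like(residual)    # d2L/dpred2
    #         return grad, hess
    #
    #     bst.update(fobj=squared_error)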

    def __boost(
        self,
        grad: np.ndarray,
        hess: np.ndarray
    ) -> bool:
        """Boost Booster for one iteration with customized gradient statistics.

        .. note::

            Score is returned before any transformation,
            e.g. it is raw margin instead of probability of positive class for binary task.
            For multi-class task, score is a numpy 2-D array of shape = [n_samples, n_classes],
            and grad and hess should be returned in the same format.

        Parameters
        ----------
        grad : numpy 1-D array or numpy 2-D array (for multi-class task)
            The value of the first order derivative (gradient) of the loss
            with respect to the elements of score for each sample point.
        hess : numpy 1-D array or numpy 2-D array (for multi-class task)
            The value of the second order derivative (Hessian) of the loss
            with respect to the elements of score for each sample point.

        Returns
        -------
        is_finished : bool
            Whether the boost was successfully finished.
        """
        if self.__num_class > 1:
            grad = grad.ravel(order='F')
            hess = hess.ravel(order='F')
        grad = _list_to_1d_numpy(grad, dtype=np.float32, name='gradient')
        hess = _list_to_1d_numpy(hess, dtype=np.float32, name='hessian')
        assert grad.flags.c_contiguous
        assert hess.flags.c_contiguous
        if len(grad) != len(hess):
            raise ValueError(f"Lengths of gradient ({len(grad)}) and Hessian ({len(hess)}) don't match")
        num_train_data = self.train_set.num_data()
        if len(grad) != num_train_data * self.__num_class:
            raise ValueError(
                f"Lengths of gradient ({len(grad)}) and Hessian ({len(hess)}) "
                f"don't match training data length ({num_train_data}) * "
                f"number of models per one iteration ({self.__num_class})"
            )
        is_finished = ctypes.c_int(0)
        _safe_call(_LIB.LGBM_BoosterUpdateOneIterCustom(
            self._handle,
            grad.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
            hess.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
            ctypes.byref(is_finished)))
        self.__is_predicted_cur_iter = [False for _ in range(self.__num_dataset)]
        return is_finished.value == 1

    def rollback_one_iter(self) -> "Booster":
        """Rollback one iteration.

        Returns
        -------
        self : Booster
            Booster with rolled back one iteration.
        """
        _safe_call(_LIB.LGBM_BoosterRollbackOneIter(
            self._handle))
        self.__is_predicted_cur_iter = [False for _ in range(self.__num_dataset)]
        return self

    def current_iteration(self) -> int:
        """Get the index of the current iteration.

        Returns
        -------
        cur_iter : int
            The index of the current iteration.
        """
        out_cur_iter = ctypes.c_int(0)
        _safe_call(_LIB.LGBM_BoosterGetCurrentIteration(
            self._handle,
            ctypes.byref(out_cur_iter)))
        return out_cur_iter.value

    def num_model_per_iteration(self) -> int:
        """Get number of models per iteration.

        Returns
        -------
        model_per_iter : int
            The number of models per iteration.
        """
        model_per_iter = ctypes.c_int(0)
        _safe_call(_LIB.LGBM_BoosterNumModelPerIteration(
            self._handle,
            ctypes.byref(model_per_iter)))
        return model_per_iter.value

    def num_trees(self) -> int:
        """Get number of weak sub-models.

        Returns
        -------
        num_trees : int
            The number of weak sub-models.
        """
        num_trees = ctypes.c_int(0)
        _safe_call(_LIB.LGBM_BoosterNumberOfTotalModel(
            self._handle,
            ctypes.byref(num_trees)))
        return num_trees.value

    def upper_bound(self) -> float:
        """Get upper bound value of a model.

        Returns
        -------
        upper_bound : float
            Upper bound value of the model.
        """
        ret = ctypes.c_double(0)
        _safe_call(_LIB.LGBM_BoosterGetUpperBoundValue(
            self._handle,
            ctypes.byref(ret)))
        return ret.value

    def lower_bound(self) -> float:
        """Get lower bound value of a model.

        Returns
        -------
        lower_bound : float
            Lower bound value of the model.
        """
        ret = ctypes.c_double(0)
        _safe_call(_LIB.LGBM_BoosterGetLowerBoundValue(
            self._handle,
            ctypes.byref(ret)))
        return ret.value

    def eval(
        self,
        data: Dataset,
        name: str,
        feval: Optional[Union[_LGBM_CustomEvalFunction, List[_LGBM_CustomEvalFunction]]] = None
    ) -> List[_LGBM_BoosterEvalMethodResultType]:
        """Evaluate for data.

        Parameters
        ----------
        data : Dataset
            Data for evaluation.
        name : str
            Name of the data.
        feval : callable, list of callable, or None, optional (default=None)
            Customized evaluation function.
            Each evaluation function should accept two parameters: preds, eval_data,
            and return (eval_name, eval_result, is_higher_better) or list of such tuples.

                preds : numpy 1-D array or numpy 2-D array (for multi-class task)
                    The predicted values.
                    For multi-class task, preds are numpy 2-D array of shape = [n_samples, n_classes].
                    If custom objective function is used, predicted values are returned before any transformation,
                    e.g. they are raw margin instead of probability of positive class for binary task in this case.
                eval_data : Dataset
                    A ``Dataset`` to evaluate.
                eval_name : str
                    The name of the evaluation function (without whitespace).
                eval_result : float
                    The eval result.
                is_higher_better : bool
                    Whether a higher eval result is better, e.g. AUC is ``is_higher_better``.

        Returns
        -------
        result : list
            List with (dataset_name, eval_name, eval_result, is_higher_better) tuples.
        """
        if not isinstance(data, Dataset):
            raise TypeError("Can only eval for Dataset instance")
        data_idx = -1
        if data is self.train_set:
            data_idx = 0
        else:
            for i in range(len(self.valid_sets)):
                if data is self.valid_sets[i]:
                    data_idx = i + 1
                    break
        # need to push new valid data
        if data_idx == -1:
            self.add_valid(data, name)
            data_idx = self.__num_dataset - 1

        return self.__inner_eval(name, data_idx, feval)

    def eval_train(
        self,
        feval: Optional[Union[_LGBM_CustomEvalFunction, List[_LGBM_CustomEvalFunction]]] = None
    ) -> List[_LGBM_BoosterEvalMethodResultType]:
        """Evaluate for training data.

        Parameters
        ----------
        feval : callable, list of callable, or None, optional (default=None)
            Customized evaluation function.
            Each evaluation function should accept two parameters: preds, eval_data,
            and return (eval_name, eval_result, is_higher_better) or list of such tuples.

                preds : numpy 1-D array or numpy 2-D array (for multi-class task)
                    The predicted values.
                    For multi-class task, preds are numpy 2-D array of shape = [n_samples, n_classes].
                    If custom objective function is used, predicted values are returned before any transformation,
                    e.g. they are raw margin instead of probability of positive class for binary task in this case.
                eval_data : Dataset
                    The training dataset.
                eval_name : str
                    The name of the evaluation function (without whitespace).
                eval_result : float
                    The eval result.
                is_higher_better : bool
                    Whether a higher eval result is better, e.g. AUC is ``is_higher_better``.

        Returns
        -------
        result : list
            List with (train_dataset_name, eval_name, eval_result, is_higher_better) tuples.
        """
        return self.__inner_eval(self._train_data_name, 0, feval)

    def eval_valid(
        self,
        feval: Optional[Union[_LGBM_CustomEvalFunction, List[_LGBM_CustomEvalFunction]]] = None
    ) -> List[_LGBM_BoosterEvalMethodResultType]:
        """Evaluate for validation data.

        Parameters
        ----------
        feval : callable, list of callable, or None, optional (default=None)
            Customized evaluation function.
            Each evaluation function should accept two parameters: preds, eval_data,
            and return (eval_name, eval_result, is_higher_better) or list of such tuples.

                preds : numpy 1-D array or numpy 2-D array (for multi-class task)
                    The predicted values.
                    For multi-class task, preds are numpy 2-D array of shape = [n_samples, n_classes].
                    If custom objective function is used, predicted values are returned before any transformation,
                    e.g. they are raw margin instead of probability of positive class for binary task in this case.
                eval_data : Dataset
                    The validation dataset.
                eval_name : str
                    The name of the evaluation function (without whitespace).
                eval_result : float
                    The eval result.
                is_higher_better : bool
                    Whether a higher eval result is better, e.g. AUC is ``is_higher_better``.

        Returns
        -------
        result : list
            List with (validation_dataset_name, eval_name, eval_result, is_higher_better) tuples.
        """
        return [item for i in range(1, self.__num_dataset)
                for item in self.__inner_eval(self.name_valid_sets[i - 1], i, feval)]
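
    # Usage sketch (illustrative, not part of the original module): a custom metric
    # passed as `feval`; `bst` is an assumed Booster with at least one validation set.
    #
    #     import numpy as np
    #
    #     def mean_absolute_error(preds, eval_data):
    #         mae = float(np.mean(np.abs(preds - eval_data.get_label())))
    #         return 'mae', mae, False  # lower is better
    #
    #     for dataset_name, eval_name, value, is_higher_better in bst.eval_valid(feval=mean_absolute_error):
    #         print(dataset_name, eval_name, value)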

    def save_model(
        self,
        filename: Union[str, Path],
        num_iteration: Optional[int] = None,
        start_iteration: int = 0,
        importance_type: str = 'split'
    ) -> "Booster":
        """Save Booster to file.

        Parameters
        ----------
        filename : str or pathlib.Path
            Filename to save Booster.
        num_iteration : int or None, optional (default=None)
            Index of the iteration that should be saved.
            If None, if the best iteration exists, it is saved; otherwise, all iterations are saved.
            If <= 0, all iterations are saved.
        start_iteration : int, optional (default=0)
            Start index of the iteration that should be saved.
        importance_type : str, optional (default="split")
            What type of feature importance should be saved.
            If "split", result contains numbers of times the feature is used in a model.
            If "gain", result contains total gains of splits which use the feature.

        Returns
        -------
        self : Booster
            Returns self.
        """
        if num_iteration is None:
            num_iteration = self.best_iteration
        importance_type_int = _FEATURE_IMPORTANCE_TYPE_MAPPER[importance_type]
        _safe_call(_LIB.LGBM_BoosterSaveModel(
            self._handle,
            ctypes.c_int(start_iteration),
            ctypes.c_int(num_iteration),
            ctypes.c_int(importance_type_int),
            _c_str(str(filename))))
        _dump_pandas_categorical(self.pandas_categorical, filename)
        return self
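
    # Usage sketch (illustrative, not part of the original module): persisting a
    # Booster to disk and loading it back; 'model.txt' is a hypothetical path and
    # `X` an assumed feature matrix.
    #
    #     bst.save_model('model.txt', num_iteration=bst.best_iteration)
    #     loaded = Booster(model_file='model.txt')
    #     preds = loaded.predict(X)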

    def shuffle_models(
        self,
        start_iteration: int = 0,
        end_iteration: int = -1
    ) -> "Booster":
        """Shuffle models.

        Parameters
        ----------
        start_iteration : int, optional (default=0)
            The first iteration that will be shuffled.
        end_iteration : int, optional (default=-1)
            The last iteration that will be shuffled.
            If <= 0, means the last available iteration.

        Returns
        -------
        self : Booster
            Booster with shuffled models.
        """
        _safe_call(_LIB.LGBM_BoosterShuffleModels(
            self._handle,
            ctypes.c_int(start_iteration),
            ctypes.c_int(end_iteration)))
        return self

    def model_from_string(self, model_str: str) -> "Booster":
        """Load Booster from a string.

        Parameters
        ----------
        model_str : str
            Model will be loaded from this string.

        Returns
        -------
        self : Booster
            Loaded Booster object.
        """
        # ensure that existing Booster is freed before replacing it
        # with a new one created from the given string
        _safe_call(_LIB.LGBM_BoosterFree(self._handle))
        self._free_buffer()
        self._handle = ctypes.c_void_p()
        out_num_iterations = ctypes.c_int(0)
        _safe_call(_LIB.LGBM_BoosterLoadModelFromString(
            _c_str(model_str),
            ctypes.byref(out_num_iterations),
            ctypes.byref(self._handle)))
        out_num_class = ctypes.c_int(0)
        _safe_call(_LIB.LGBM_BoosterGetNumClasses(
            self._handle,
            ctypes.byref(out_num_class)))
        self.__num_class = out_num_class.value
        self.pandas_categorical = _load_pandas_categorical(model_str=model_str)
        return self

    def model_to_string(
        self,
        num_iteration: Optional[int] = None,
        start_iteration: int = 0,
        importance_type: str = 'split'
    ) -> str:
        """Save Booster to string.

        Parameters
        ----------
        num_iteration : int or None, optional (default=None)
            Index of the iteration that should be saved.
            If None, if the best iteration exists, it is saved; otherwise, all iterations are saved.
            If <= 0, all iterations are saved.
        start_iteration : int, optional (default=0)
            Start index of the iteration that should be saved.
        importance_type : str, optional (default="split")
            What type of feature importance should be saved.
            If "split", result contains numbers of times the feature is used in a model.
            If "gain", result contains total gains of splits which use the feature.

        Returns
        -------
        str_repr : str
            String representation of Booster.
        """
        if num_iteration is None:
            num_iteration = self.best_iteration
        importance_type_int = _FEATURE_IMPORTANCE_TYPE_MAPPER[importance_type]
        buffer_len = 1 << 20
        tmp_out_len = ctypes.c_int64(0)
        string_buffer = ctypes.create_string_buffer(buffer_len)
        ptr_string_buffer = ctypes.c_char_p(ctypes.addressof(string_buffer))
        _safe_call(_LIB.LGBM_BoosterSaveModelToString(
            self._handle,
            ctypes.c_int(start_iteration),
            ctypes.c_int(num_iteration),
            ctypes.c_int(importance_type_int),
            ctypes.c_int64(buffer_len),
            ctypes.byref(tmp_out_len),
            ptr_string_buffer))
        actual_len = tmp_out_len.value
        # if buffer length is not long enough, re-allocate a buffer
        if actual_len > buffer_len:
            string_buffer = ctypes.create_string_buffer(actual_len)
            ptr_string_buffer = ctypes.c_char_p(ctypes.addressof(string_buffer))
            _safe_call(_LIB.LGBM_BoosterSaveModelToString(
                self._handle,
                ctypes.c_int(start_iteration),
                ctypes.c_int(num_iteration),
                ctypes.c_int(importance_type_int),
                ctypes.c_int64(actual_len),
                ctypes.byref(tmp_out_len),
                ptr_string_buffer))
        ret = string_buffer.value.decode('utf-8')
        ret += _dump_pandas_categorical(self.pandas_categorical)
        return ret
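
    # Usage sketch (illustrative, not part of the original module): an in-memory
    # round trip through the text representation, with no file involved.
    #
    #     model_str = bst.model_to_string(num_iteration=-1)
    #     clone = Booster(model_str=model_str)
    #     assert clone.num_trees() == bst.num_trees()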

    def dump_model(
        self,
        num_iteration: Optional[int] = None,
        start_iteration: int = 0,
        importance_type: str = 'split',
        object_hook: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None
    ) -> Dict[str, Any]:
        """Dump Booster to JSON format.

        Parameters
        ----------
        num_iteration : int or None, optional (default=None)
            Index of the iteration that should be dumped.
            If None, if the best iteration exists, it is dumped; otherwise, all iterations are dumped.
            If <= 0, all iterations are dumped.
        start_iteration : int, optional (default=0)
            Start index of the iteration that should be dumped.
        importance_type : str, optional (default="split")
            What type of feature importance should be dumped.
            If "split", result contains numbers of times the feature is used in a model.
            If "gain", result contains total gains of splits which use the feature.
        object_hook : callable or None, optional (default=None)
            If not None, ``object_hook`` is a function called while parsing the json
            string returned by the C API. It may be used to alter the json, to store
            specific values while building the json structure. It avoids
            walking through the structure again. It saves a significant amount
            of time if the number of trees is huge.
            Signature is ``def object_hook(node: dict) -> dict``.
            None is equivalent to ``lambda node: node``.
            See documentation of ``json.loads()`` for further details.

        Returns
        -------
        json_repr : dict
            JSON format of Booster.
        """
        if num_iteration is None:
            num_iteration = self.best_iteration
        importance_type_int = _FEATURE_IMPORTANCE_TYPE_MAPPER[importance_type]
        buffer_len = 1 << 20
        tmp_out_len = ctypes.c_int64(0)
        string_buffer = ctypes.create_string_buffer(buffer_len)
        ptr_string_buffer = ctypes.c_char_p(ctypes.addressof(string_buffer))
        _safe_call(_LIB.LGBM_BoosterDumpModel(
            self._handle,
            ctypes.c_int(start_iteration),
            ctypes.c_int(num_iteration),
            ctypes.c_int(importance_type_int),
            ctypes.c_int64(buffer_len),
            ctypes.byref(tmp_out_len),
            ptr_string_buffer))
        actual_len = tmp_out_len.value
        # if buffer length is not long enough, reallocate a buffer
        if actual_len > buffer_len:
            string_buffer = ctypes.create_string_buffer(actual_len)
            ptr_string_buffer = ctypes.c_char_p(ctypes.addressof(string_buffer))
            _safe_call(_LIB.LGBM_BoosterDumpModel(
                self._handle,
                ctypes.c_int(start_iteration),
                ctypes.c_int(num_iteration),
                ctypes.c_int(importance_type_int),
                ctypes.c_int64(actual_len),
                ctypes.byref(tmp_out_len),
                ptr_string_buffer))
        ret = json.loads(string_buffer.value.decode('utf-8'), object_hook=object_hook)
        ret['pandas_categorical'] = json.loads(json.dumps(self.pandas_categorical,
                                                          default=_json_default_with_numpy))
        return ret
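
    # Usage sketch (illustrative, not part of the original module): using
    # `object_hook` to collect leaf values in a single parsing pass instead of
    # walking the dumped structure again afterwards.
    #
    #     leaf_values = []
    #
    #     def collect_leaves(node):
    #         if 'leaf_value' in node:
    #             leaf_values.append(node['leaf_value'])
    #         return node
    #
    #     model_json = bst.dump_model(object_hook=collect_leaves)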

    def predict(
        self,
        data: _LGBM_PredictDataType,
        start_iteration: int = 0,
        num_iteration: Optional[int] = None,
        raw_score: bool = False,
        pred_leaf: bool = False,
        pred_contrib: bool = False,
        data_has_header: bool = False,
        validate_features: bool = False,
        **kwargs: Any
    ) -> Union[np.ndarray, scipy.sparse.spmatrix, List[scipy.sparse.spmatrix]]:
        """Make a prediction.

        Parameters
        ----------
        data : str, pathlib.Path, numpy array, pandas DataFrame, pyarrow Table, H2O DataTable's Frame or scipy.sparse
            Data source for prediction.
            If str or pathlib.Path, it represents the path to a text file (CSV, TSV, or LibSVM).
        start_iteration : int, optional (default=0)
            Start index of the iteration to predict.
            If <= 0, starts from the first iteration.
        num_iteration : int or None, optional (default=None)
            Total number of iterations used in the prediction.
            If None, if the best iteration exists and start_iteration <= 0, the best iteration is used;
            otherwise, all iterations from ``start_iteration`` are used (no limits).
            If <= 0, all iterations from ``start_iteration`` are used (no limits).
        raw_score : bool, optional (default=False)
            Whether to predict raw scores.
        pred_leaf : bool, optional (default=False)
            Whether to predict leaf index.
        pred_contrib : bool, optional (default=False)
            Whether to predict feature contributions.

            .. note::

                If you want to get more explanations for your model's predictions using SHAP values,
                like SHAP interaction values,
                you can install the shap package (https://github.com/slundberg/shap).
                Note that unlike the shap package, with ``pred_contrib`` we return a matrix with an extra
                column, where the last column is the expected value.

        data_has_header : bool, optional (default=False)
            Whether the data has a header.
            Used only if data is str.
        validate_features : bool, optional (default=False)
            If True, ensure that the features used to predict match the ones used to train.
            Used only if data is pandas DataFrame.
        **kwargs
            Other parameters for the prediction.

        Returns
        -------
        result : numpy array, scipy.sparse or list of scipy.sparse
            Prediction result.
            Can be sparse or a list of sparse objects (each element represents predictions for one class) for feature contributions (when ``pred_contrib=True``).
        """
        predictor = _InnerPredictor.from_booster(
            booster=self,
            pred_parameter=deepcopy(kwargs),
        )
        if num_iteration is None:
            if start_iteration <= 0:
                num_iteration = self.best_iteration
            else:
                num_iteration = -1
        return predictor.predict(
            data=data,
            start_iteration=start_iteration,
            num_iteration=num_iteration,
            raw_score=raw_score,
            pred_leaf=pred_leaf,
            pred_contrib=pred_contrib,
            data_has_header=data_has_header,
            validate_features=validate_features
        )
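
    # Usage sketch (illustrative, not part of the original module): raw scores and
    # SHAP-style feature contributions for an assumed feature matrix `X` and an
    # assumed regression or binary model.
    #
    #     raw = bst.predict(X, raw_score=True)
    #     contrib = bst.predict(X, pred_contrib=True)
    #     # one column per feature plus a final expected-value column
    #     assert contrib.shape[1] == bst.num_feature() + 1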

    def refit(
        self,
        data: _LGBM_TrainDataType,
        label: _LGBM_LabelType,
        decay_rate: float = 0.9,
        reference: Optional[Dataset] = None,
        weight: Optional[_LGBM_WeightType] = None,
        group: Optional[_LGBM_GroupType] = None,
        init_score: Optional[_LGBM_InitScoreType] = None,
        feature_name: _LGBM_FeatureNameConfiguration = 'auto',
        categorical_feature: _LGBM_CategoricalFeatureConfiguration = 'auto',
        dataset_params: Optional[Dict[str, Any]] = None,
        free_raw_data: bool = True,
        validate_features: bool = False,
        **kwargs
    ) -> "Booster":
        """Refit the existing Booster by new data.

        Parameters
        ----------
        data : str, pathlib.Path, numpy array, pandas DataFrame, H2O DataTable's Frame, scipy.sparse, Sequence, list of Sequence or list of numpy array
            Data source for refit.
            If str or pathlib.Path, it represents the path to a text file (CSV, TSV, or LibSVM).
        label : list, numpy 1-D array, pandas Series / one-column DataFrame, pyarrow Array or pyarrow ChunkedArray
            Label for refit.
        decay_rate : float, optional (default=0.9)
            Decay rate of refit,
            will use ``leaf_output = decay_rate * old_leaf_output + (1.0 - decay_rate) * new_leaf_output`` to refit trees.
        reference : Dataset or None, optional (default=None)
            Reference for ``data``.

            .. versionadded:: 4.0.0

        weight : list, numpy 1-D array, pandas Series, pyarrow Array, pyarrow ChunkedArray or None, optional (default=None)
            Weight for each ``data`` instance. Weights should be non-negative.

            .. versionadded:: 4.0.0

        group : list, numpy 1-D array, pandas Series, pyarrow Array, pyarrow ChunkedArray or None, optional (default=None)
            Group/query size for ``data``.
            Only used in the learning-to-rank task.
            sum(group) = n_samples.
            For example, if you have a 100-document dataset with ``group = [10, 20, 40, 10, 10, 10]``, that means that you have 6 groups,
            where the first 10 records are in the first group, records 11-30 are in the second group, records 31-70 are in the third group, etc.

            .. versionadded:: 4.0.0

        init_score : list, list of lists (for multi-class task), numpy array, pandas Series, pandas DataFrame (for multi-class task), pyarrow Array, pyarrow ChunkedArray, pyarrow Table (for multi-class task) or None, optional (default=None)
            Init score for ``data``.

            .. versionadded:: 4.0.0

        feature_name : list of str, or 'auto', optional (default="auto")
            Feature names for ``data``.
            If 'auto' and data is pandas DataFrame, data column names are used.

            .. versionadded:: 4.0.0

        categorical_feature : list of str or int, or 'auto', optional (default="auto")
            Categorical features for ``data``.
            If list of int, interpreted as indices.
            If list of str, interpreted as feature names (need to specify ``feature_name`` as well).
            If 'auto' and data is pandas DataFrame, pandas unordered categorical columns are used.
            All values in categorical features will be cast to int32 and thus should be less than int32 max value (2147483647).
            Large values could be memory consuming. Consider using consecutive integers starting from zero.
            All negative values in categorical features will be treated as missing values.
            The output cannot be monotonically constrained with respect to a categorical feature.
            Floating point numbers in categorical features will be rounded towards 0.

            .. versionadded:: 4.0.0

        dataset_params : dict or None, optional (default=None)
            Other parameters for Dataset ``data``.

            .. versionadded:: 4.0.0

        free_raw_data : bool, optional (default=True)
            If True, raw data is freed after constructing inner Dataset for ``data``.

            .. versionadded:: 4.0.0

        validate_features : bool, optional (default=False)
            If True, ensure that the features used to refit the model match the original ones.
            Used only if data is pandas DataFrame.

            .. versionadded:: 4.0.0

        **kwargs
            Other parameters for refit.
            These parameters will be passed to the ``predict`` method.

        Returns
        -------
        result : Booster
            Refitted Booster.
        """
        if self.__set_objective_to_none:
            raise LightGBMError('Cannot refit due to null objective function.')
        if dataset_params is None:
            dataset_params = {}
        predictor = _InnerPredictor.from_booster(
            booster=self,
            pred_parameter=deepcopy(kwargs)
        )
        leaf_preds: np.ndarray = predictor.predict(  # type: ignore[assignment]
            data=data,
            start_iteration=-1,
            pred_leaf=True,
            validate_features=validate_features
        )
        nrow, ncol = leaf_preds.shape
        out_is_linear = ctypes.c_int(0)
        _safe_call(_LIB.LGBM_BoosterGetLinear(
            self._handle,
            ctypes.byref(out_is_linear)))
        new_params = _choose_param_value(
            main_param_name="linear_tree",
            params=self.params,
            default_value=None
        )
        new_params["linear_tree"] = bool(out_is_linear.value)
        new_params.update(dataset_params)
        train_set = Dataset(
            data=data,
            label=label,
            reference=reference,
            weight=weight,
            group=group,
            init_score=init_score,
            feature_name=feature_name,
            categorical_feature=categorical_feature,
            params=new_params,
            free_raw_data=free_raw_data,
        )
        new_params['refit_decay_rate'] = decay_rate
        new_booster = Booster(new_params, train_set)
        # Copy models
        _safe_call(_LIB.LGBM_BoosterMerge(
            new_booster._handle,
            predictor._handle))
        leaf_preds = leaf_preds.reshape(-1)
        ptr_data, _, _ = _c_int_array(leaf_preds)
        _safe_call(_LIB.LGBM_BoosterRefit(
            new_booster._handle,
            ptr_data,
            ctypes.c_int32(nrow),
            ctypes.c_int32(ncol)))
        new_booster._network = self._network
        return new_booster
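
    # Usage sketch (illustrative, not part of the original module): adapting leaf
    # values to newer data without changing tree structure; `X_new` and `y_new`
    # are hypothetical arrays with the same feature layout as the training data.
    #
    #     refitted = bst.refit(X_new, y_new, decay_rate=0.9)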

    def get_leaf_output(self, tree_id: int, leaf_id: int) -> float:
        """Get the output of a leaf.

        Parameters
        ----------
        tree_id : int
            The index of the tree.
        leaf_id : int
            The index of the leaf in the tree.

        Returns
        -------
        result : float
            The output of the leaf.
        """
        ret = ctypes.c_double(0)
        _safe_call(_LIB.LGBM_BoosterGetLeafValue(
            self._handle,
            ctypes.c_int(tree_id),
            ctypes.c_int(leaf_id),
            ctypes.byref(ret)))
        return ret.value

    def set_leaf_output(
        self,
        tree_id: int,
        leaf_id: int,
        value: float,
    ) -> 'Booster':
        """Set the output of a leaf.

        .. versionadded:: 4.0.0

        Parameters
        ----------
        tree_id : int
            The index of the tree.
        leaf_id : int
            The index of the leaf in the tree.
        value : float
            Value to set as the output of the leaf.

        Returns
        -------
        self : Booster
            Booster with the leaf output set.
        """
        _safe_call(
            _LIB.LGBM_BoosterSetLeafValue(
                self._handle,
                ctypes.c_int(tree_id),
                ctypes.c_int(leaf_id),
                ctypes.c_double(value)
            )
        )
        return self
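
    # Usage sketch (illustrative, not part of the original module): reading a leaf's
    # value and overriding it, e.g. to clamp an extreme leaf in the first tree.
    #
    #     value = bst.get_leaf_output(tree_id=0, leaf_id=0)
    #     bst.set_leaf_output(tree_id=0, leaf_id=0, value=min(value, 1.0))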

    def num_feature(self) -> int:
        """Get number of features.

        Returns
        -------
        num_feature : int
            The number of features.
        """
        out_num_feature = ctypes.c_int(0)
        _safe_call(_LIB.LGBM_BoosterGetNumFeature(
            self._handle,
            ctypes.byref(out_num_feature)))
        return out_num_feature.value

    def feature_name(self) -> List[str]:
        """Get names of features.

        Returns
        -------
        result : list of str
            List with names of features.
        """
        num_feature = self.num_feature()
        # Get name of features
        tmp_out_len = ctypes.c_int(0)
        reserved_string_buffer_size = 255
        required_string_buffer_size = ctypes.c_size_t(0)
        string_buffers = [ctypes.create_string_buffer(reserved_string_buffer_size) for _ in range(num_feature)]
        ptr_string_buffers = (ctypes.c_char_p * num_feature)(*map(ctypes.addressof, string_buffers))  # type: ignore[misc]
        _safe_call(_LIB.LGBM_BoosterGetFeatureNames(
            self._handle,
            ctypes.c_int(num_feature),
            ctypes.byref(tmp_out_len),
            ctypes.c_size_t(reserved_string_buffer_size),
            ctypes.byref(required_string_buffer_size),
            ptr_string_buffers))
        if num_feature != tmp_out_len.value:
            raise ValueError("Length of feature names doesn't equal num_feature")
        actual_string_buffer_size = required_string_buffer_size.value
        # if buffer length is not long enough, reallocate buffers
        if reserved_string_buffer_size < actual_string_buffer_size:
            string_buffers = [ctypes.create_string_buffer(actual_string_buffer_size) for _ in range(num_feature)]
            ptr_string_buffers = (ctypes.c_char_p * num_feature)(*map(ctypes.addressof, string_buffers))  # type: ignore[misc]
            _safe_call(_LIB.LGBM_BoosterGetFeatureNames(
                self._handle,
                ctypes.c_int(num_feature),
                ctypes.byref(tmp_out_len),
                ctypes.c_size_t(actual_string_buffer_size),
                ctypes.byref(required_string_buffer_size),
                ptr_string_buffers))
        return [string_buffers[i].value.decode('utf-8') for i in range(num_feature)]
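    # Hedged sketch (assuming a trained Booster ``bst``): the two getters above
    # agree by construction:
    #
    #     assert len(bst.feature_name()) == bst.num_feature()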

    def feature_importance(
        self,
        importance_type: str = 'split',
        iteration: Optional[int] = None
    ) -> np.ndarray:
        """Get feature importances.

        Parameters
        ----------
        importance_type : str, optional (default="split")
            How the importance is calculated.
            If "split", result contains numbers of times the feature is used in a model.
            If "gain", result contains total gains of splits which use the feature.
        iteration : int or None, optional (default=None)
            Limit number of iterations in the feature importance calculation.
            If None, the best iteration is used when it exists; otherwise, all trees are used.
            If <= 0, all trees are used (no limits).

        Returns
        -------
        result : numpy array
            Array with feature importances.
        """
        if iteration is None:
            iteration = self.best_iteration
        importance_type_int = _FEATURE_IMPORTANCE_TYPE_MAPPER[importance_type]
        result = np.empty(self.num_feature(), dtype=np.float64)
        _safe_call(_LIB.LGBM_BoosterFeatureImportance(
            self._handle,
            ctypes.c_int(iteration),
            ctypes.c_int(importance_type_int),
            result.ctypes.data_as(ctypes.POINTER(ctypes.c_double))))
        if importance_type_int == _C_API_FEATURE_IMPORTANCE_SPLIT:
            return result.astype(np.int32)
        else:
            return result
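    # Hedged usage sketch (assuming a trained Booster ``bst``): split counts come
    # back as an int32 array, total gains as float64:
    #
    #     n_splits = bst.feature_importance(importance_type="split")
    #     gains = bst.feature_importance(importance_type="gain")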

    def get_split_value_histogram(
        self,
        feature: Union[int, str],
        bins: Optional[Union[int, str]] = None,
        xgboost_style: bool = False
    ) -> Union[Tuple[np.ndarray, np.ndarray], np.ndarray, pd_DataFrame]:
        """Get split value histogram for the specified feature.

        Parameters
        ----------
        feature : int or str
            The feature name or index the histogram is calculated for.
            If int, interpreted as index.
            If str, interpreted as name.

            .. warning::

                Categorical features are not supported.

        bins : int, str or None, optional (default=None)
            The maximum number of bins.
            If None, or int and > number of unique split values and ``xgboost_style=True``,
            the number of bins equals number of unique split values.
            If str, it should be one of the values supported by the ``numpy.histogram()`` function.
        xgboost_style : bool, optional (default=False)
            Whether the returned result should be in the same form as it is in XGBoost.
            If False, the returned value is a tuple of 2 numpy arrays, as in the ``numpy.histogram()`` function.
            If True, the returned value is a matrix, in which the first column is the right edges of non-empty bins
            and the second one is the histogram values.

        Returns
        -------
        result_tuple : tuple of 2 numpy arrays
            If ``xgboost_style=False``, the values of the histogram of used splitting values for the specified feature
            and the bin edges.
        result_array_like : numpy array or pandas DataFrame (if pandas is installed)
            If ``xgboost_style=True``, the histogram of used splitting values for the specified feature.
        """
        def add(root: Dict[str, Any]) -> None:
            """Recursively add thresholds."""
            if 'split_index' in root:  # non-leaf
                if feature_names is not None and isinstance(feature, str):
                    split_feature = feature_names[root['split_feature']]
                else:
                    split_feature = root['split_feature']
                if split_feature == feature:
                    if isinstance(root['threshold'], str):
                        raise LightGBMError('Cannot compute split value histogram for the categorical feature')
                    else:
                        values.append(root['threshold'])
                add(root['left_child'])
                add(root['right_child'])

        model = self.dump_model()
        feature_names = model.get('feature_names')
        tree_infos = model['tree_info']
        values: List[float] = []
        for tree_info in tree_infos:
            add(tree_info['tree_structure'])

        if bins is None or (isinstance(bins, int) and xgboost_style):
            n_unique = len(np.unique(values))
            bins = max(min(n_unique, bins) if bins is not None else n_unique, 1)
        hist, bin_edges = np.histogram(values, bins=bins)
        if xgboost_style:
            ret = np.column_stack((bin_edges[1:], hist))
            ret = ret[ret[:, 1] > 0]
            if PANDAS_INSTALLED:
                return pd_DataFrame(ret, columns=['SplitValue', 'Count'])
            else:
                return ret
        else:
            return hist, bin_edges
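    # Hedged usage sketch (assuming a trained Booster ``bst`` and a hypothetical
    # feature name "feature_0"): inspect the thresholds the model chose when
    # splitting on one feature:
    #
    #     hist, bin_edges = bst.get_split_value_histogram("feature_0", bins=10)
    #     df = bst.get_split_value_histogram("feature_0", xgboost_style=True)  # DataFrame if pandas is installed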

    def __inner_eval(
        self,
        data_name: str,
        data_idx: int,
        feval: Optional[Union[_LGBM_CustomEvalFunction, List[_LGBM_CustomEvalFunction]]]
    ) -> List[_LGBM_BoosterEvalMethodResultType]:
        """Evaluate training or validation data."""
        if data_idx >= self.__num_dataset:
            raise ValueError("data_idx should be smaller than number of datasets")
        self.__get_eval_info()
        ret = []
        if self.__num_inner_eval > 0:
            result = np.empty(self.__num_inner_eval, dtype=np.float64)
            tmp_out_len = ctypes.c_int(0)
            _safe_call(_LIB.LGBM_BoosterGetEval(
                self._handle,
                ctypes.c_int(data_idx),
                ctypes.byref(tmp_out_len),
                result.ctypes.data_as(ctypes.POINTER(ctypes.c_double))))
            if tmp_out_len.value != self.__num_inner_eval:
                raise ValueError("Wrong length of eval results")
            for i in range(self.__num_inner_eval):
                ret.append((data_name, self.__name_inner_eval[i],
                            result[i], self.__higher_better_inner_eval[i]))
        if callable(feval):
            feval = [feval]
        if feval is not None:
            if data_idx == 0:
                cur_data = self.train_set
            else:
                cur_data = self.valid_sets[data_idx - 1]
            for eval_function in feval:
                if eval_function is None:
                    continue
                feval_ret = eval_function(self.__inner_predict(data_idx), cur_data)
                if isinstance(feval_ret, list):
                    for eval_name, val, is_higher_better in feval_ret:
                        ret.append((data_name, eval_name, val, is_higher_better))
                else:
                    eval_name, val, is_higher_better = feval_ret
                    ret.append((data_name, eval_name, val, is_higher_better))
        return ret
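    # Hedged sketch (hypothetical metric name): a ``feval`` callable receives the
    # raw predictions and the corresponding Dataset, and must return
    # ``(eval_name, eval_result, is_higher_better)`` or a list of such tuples:
    #
    #     def my_mae(preds, dataset):
    #         labels = dataset.get_label()
    #         return 'my_mae', float(np.mean(np.abs(preds - labels))), False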

    def __inner_predict(self, data_idx: int) -> np.ndarray:
        """Predict for training and validation dataset."""
        if data_idx >= self.__num_dataset:
            raise ValueError("data_idx should be smaller than number of datasets")
        if self.__inner_predict_buffer[data_idx] is None:
            if data_idx == 0:
                n_preds = self.train_set.num_data() * self.__num_class
            else:
                n_preds = self.valid_sets[data_idx - 1].num_data() * self.__num_class
            self.__inner_predict_buffer[data_idx] = np.empty(n_preds, dtype=np.float64)
        # avoid predicting multiple times in one iteration
        if not self.__is_predicted_cur_iter[data_idx]:
            tmp_out_len = ctypes.c_int64(0)
            data_ptr = self.__inner_predict_buffer[data_idx].ctypes.data_as(ctypes.POINTER(ctypes.c_double))  # type: ignore[union-attr]
            _safe_call(_LIB.LGBM_BoosterGetPredict(
                self._handle,
                ctypes.c_int(data_idx),
                ctypes.byref(tmp_out_len),
                data_ptr))
            if tmp_out_len.value != len(self.__inner_predict_buffer[data_idx]):  # type: ignore[arg-type]
                raise ValueError(f"Wrong length of predict results for data {data_idx}")
            self.__is_predicted_cur_iter[data_idx] = True
        result: np.ndarray = self.__inner_predict_buffer[data_idx]  # type: ignore[assignment]
        if self.__num_class > 1:
            num_data = result.size // self.__num_class
            result = result.reshape(num_data, self.__num_class, order='F')
        return result
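    # Hedged sketch of the buffer layout handled above (illustrative sizes): the
    # C API fills a flat column-major buffer, so 4 rows x 3 classes arrive as
    # 12 doubles and reshape to (4, 3) with ``order='F'``:
    #
    #     flat = np.arange(12, dtype=np.float64)
    #     per_class = flat.reshape(4, 3, order='F')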

    def __get_eval_info(self) -> None:
        """Get inner evaluation count and names."""
        if self.__need_reload_eval_info:
            self.__need_reload_eval_info = False
            out_num_eval = ctypes.c_int(0)
            # Get num of inner evals
            _safe_call(_LIB.LGBM_BoosterGetEvalCounts(
                self._handle,
                ctypes.byref(out_num_eval)))
            self.__num_inner_eval = out_num_eval.value
            if self.__num_inner_eval > 0:
                # Get name of eval metrics
                tmp_out_len = ctypes.c_int(0)
                reserved_string_buffer_size = 255
                required_string_buffer_size = ctypes.c_size_t(0)
                string_buffers = [
                    ctypes.create_string_buffer(reserved_string_buffer_size) for _ in range(self.__num_inner_eval)
                ]
                ptr_string_buffers = (ctypes.c_char_p * self.__num_inner_eval)(*map(ctypes.addressof, string_buffers))  # type: ignore[misc]
                _safe_call(_LIB.LGBM_BoosterGetEvalNames(
                    self._handle,
                    ctypes.c_int(self.__num_inner_eval),
                    ctypes.byref(tmp_out_len),
                    ctypes.c_size_t(reserved_string_buffer_size),
                    ctypes.byref(required_string_buffer_size),
                    ptr_string_buffers))
                if self.__num_inner_eval != tmp_out_len.value:
                    raise ValueError("Length of eval names doesn't equal num_evals")
                actual_string_buffer_size = required_string_buffer_size.value
                # if buffer length is not long enough, reallocate buffers
                if reserved_string_buffer_size < actual_string_buffer_size:
                    string_buffers = [
                        ctypes.create_string_buffer(actual_string_buffer_size) for _ in range(self.__num_inner_eval)
                    ]
                    ptr_string_buffers = (ctypes.c_char_p * self.__num_inner_eval)(*map(ctypes.addressof, string_buffers))  # type: ignore[misc]
                    _safe_call(_LIB.LGBM_BoosterGetEvalNames(
                        self._handle,
                        ctypes.c_int(self.__num_inner_eval),
                        ctypes.byref(tmp_out_len),
                        ctypes.c_size_t(actual_string_buffer_size),
                        ctypes.byref(required_string_buffer_size),
                        ptr_string_buffers))
                self.__name_inner_eval = [
                    string_buffers[i].value.decode('utf-8') for i in range(self.__num_inner_eval)
                ]
                self.__higher_better_inner_eval = [
                    name.startswith(('auc', 'ndcg@', 'map@', 'average_precision')) for name in self.__name_inner_eval
                ]
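    # Hedged sketch of the two-pass string-buffer pattern used above and in
    # ``feature_name()`` (illustrative names/sizes): reserve a default-sized buffer
    # per string, let the C call report the size it actually needs, and reallocate
    # and call again only when the reservation was too small:
    #
    #     reserved = 255
    #     required = ctypes.c_size_t(0)
    #     buffers = [ctypes.create_string_buffer(reserved) for _ in range(n)]
    #     # ... first C call fills ``required`` ...
    #     if reserved < required.value:
    #         buffers = [ctypes.create_string_buffer(required.value) for _ in range(n)]
    #         # ... second C call with the enlarged buffers ...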